
Commit b6c40b3

feat(aggregation): Add getters and setters to all remaining aggregators (#660)
1 parent be1f194 commit b6c40b3

21 files changed: 681 additions & 141 deletions

CHANGELOG.md

Lines changed: 9 additions & 3 deletions

@@ -10,9 +10,15 @@ changelog does not include internal changes that do not affect the user.
 
 ### Added
 
-- Added `pref_vector`, `norm_eps`, and `reg_eps` getters and setters to `UPGrad` and
-  `UPGradWeighting`. The setters for `norm_eps` and `reg_eps` validate that the assigned value is
-  non-negative.
+- Added getters and setters for the constructor parameters of all aggregators and weightings, so
+  that they can be changed after initialization. This includes: `pref_vector`,
+  `norm_eps` and `reg_eps` in `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting`;
+  `pref_vector` and `scale_mode` in `AlignedMTL` and `AlignedMTLWeighting`; `c` and `norm_eps` in
+  `CAGrad` and `CAGradWeighting`; `pref_vector` in `ConFIG`; `leak` in `GradDrop`, `n_byzantine` and
+  `n_selected` in `Krum` and `KrumWeighting`; `epsilon` and `max_iters` in `MGDA` and
+  `MGDAWeighting`; `n_tasks`, `max_norm`, `update_weights_every` and `optim_niter` in `NashMTL`;
+  `trim_number` in `TrimmedMean`. Setters validate their inputs matching the existing constructor
+  checks. Note that setters for `GradVac` and `GradVacWeighting` already existed.
 
 ## [0.10.0] - 2026-04-16

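For illustration (not part of this commit), a minimal sketch of the new usage pattern, assuming the aggregators are importable from `torchjd.aggregation` as in the project's documentation; the chosen values are arbitrary:

from torchjd.aggregation import DualProj

aggregator = DualProj()     # all constructor parameters have defaults
print(aggregator.norm_eps)  # getter added by this commit
aggregator.norm_eps = 1e-3  # setter added by this commit; negative values are rejected
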
src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 31 additions & 9 deletions

@@ -35,17 +35,25 @@ def __init__(
         scale_mode: SUPPORTED_SCALE_MODE = "min",
     ) -> None:
         super().__init__()
-        self._pref_vector = pref_vector
-        self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
-        self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
+        self.pref_vector = pref_vector
+        self.scale_mode: SUPPORTED_SCALE_MODE = scale_mode
 
     def forward(self, gramian: PSDMatrix, /) -> Tensor:
         w = self.weighting(gramian)
-        B = self._compute_balance_transformation(gramian, self._scale_mode)
+        B = self._compute_balance_transformation(gramian, self.scale_mode)
         alpha = B @ w
 
         return alpha
 
+    @property
+    def pref_vector(self) -> Tensor | None:
+        return self._pref_vector
+
+    @pref_vector.setter
+    def pref_vector(self, value: Tensor | None) -> None:
+        self.weighting = pref_vector_to_weighting(value, default=MeanWeighting())
+        self._pref_vector = value
+
     @staticmethod
     def _compute_balance_transformation(
         M: Tensor,
@@ -103,15 +111,29 @@ def __init__(
         pref_vector: Tensor | None = None,
         scale_mode: SUPPORTED_SCALE_MODE = "min",
     ) -> None:
-        self._pref_vector = pref_vector
-        self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
         super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))
 
+    @property
+    def pref_vector(self) -> Tensor | None:
+        return self.gramian_weighting.pref_vector
+
+    @pref_vector.setter
+    def pref_vector(self, value: Tensor | None) -> None:
+        self.gramian_weighting.pref_vector = value
+
+    @property
+    def scale_mode(self) -> SUPPORTED_SCALE_MODE:
+        return self.gramian_weighting.scale_mode
+
+    @scale_mode.setter
+    def scale_mode(self, value: SUPPORTED_SCALE_MODE) -> None:
+        self.gramian_weighting.scale_mode = value
+
     def __repr__(self) -> str:
         return (
-            f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
-            f"scale_mode={repr(self._scale_mode)})"
+            f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, "
+            f"scale_mode={repr(self.scale_mode)})"
         )
 
     def __str__(self) -> str:
-        return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}"
+        return f"AlignedMTL{pref_vector_to_str_suffix(self.pref_vector)}"

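The aggregator-level properties forward to the wrapped weighting (`gramian_weighting`), so both views stay in sync. An illustrative sketch of the expected behavior, not part of the commit, assuming `AlignedMTL` is importable from `torchjd.aggregation`:

import torch
from torchjd.aggregation import AlignedMTL

aggregator = AlignedMTL()  # pref_vector defaults to None, scale_mode to "min"
aggregator.pref_vector = torch.tensor([1.0, 2.0])

# The property delegates to the underlying AlignedMTLWeighting.
assert aggregator.gramian_weighting.pref_vector is aggregator.pref_vector
print(repr(aggregator))  # repr now reflects the updated pref_vector
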
src/torchjd/aggregation/_cagrad.py

Lines changed: 40 additions & 8 deletions

@@ -37,10 +37,6 @@ class CAGradWeighting(GramianWeighting):
 
     def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
         super().__init__()
-
-        if c < 0.0:
-            raise ValueError(f"Parameter `c` should be a non-negative float. Found `c = {c}`.")
-
         self.c = c
         self.norm_eps = norm_eps
 
@@ -73,6 +69,28 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
 
         return weights
 
+    @property
+    def c(self) -> float:
+        return self._c
+
+    @c.setter
+    def c(self, value: float) -> None:
+        if value < 0:
+            raise ValueError(f"c must be non-negative, but got {value}.")
+
+        self._c = value
+
+    @property
+    def norm_eps(self) -> float:
+        return self._norm_eps
+
+    @norm_eps.setter
+    def norm_eps(self, value: float) -> None:
+        if value < 0:
+            raise ValueError(f"norm_eps must be non-negative, but got {value}.")
+
+        self._norm_eps = value
+
 
 class CAGrad(GramianWeightedAggregator):
     """
@@ -94,15 +112,29 @@ class CAGrad(GramianWeightedAggregator):
 
     def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
         super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
-        self._c = c
-        self._norm_eps = norm_eps
 
         # This prevents considering the computed weights as constant w.r.t. the matrix.
        self.register_full_backward_pre_hook(raise_non_differentiable_error)
 
+    @property
+    def c(self) -> float:
+        return self.gramian_weighting.c
+
+    @c.setter
+    def c(self, value: float) -> None:
+        self.gramian_weighting.c = value
+
+    @property
+    def norm_eps(self) -> float:
+        return self.gramian_weighting.norm_eps
+
+    @norm_eps.setter
+    def norm_eps(self, value: float) -> None:
+        self.gramian_weighting.norm_eps = value
+
     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"
+        return f"{self.__class__.__name__}(c={self.c}, norm_eps={self.norm_eps})"
 
     def __str__(self) -> str:
-        c_str = str(self._c).rstrip("0")
+        c_str = str(self.c).rstrip("0")
         return f"CAGrad{c_str}"

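On `CAGrad`, the setters delegate to `CAGradWeighting`, whose setters re-apply the non-negativity check that previously lived in the constructor. An illustrative sketch (not part of the commit), assuming `CAGrad` is importable from `torchjd.aggregation`:

from torchjd.aggregation import CAGrad

aggregator = CAGrad(c=0.5)
try:
    aggregator.c = -1.0  # delegated to CAGradWeighting.c, which rejects negative values
except ValueError as error:
    print(error)  # "c must be non-negative, but got -1.0."
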
src/torchjd/aggregation/_config.py

Lines changed: 12 additions & 4 deletions

@@ -29,8 +29,7 @@ class ConFIG(Aggregator):
 
     def __init__(self, pref_vector: Tensor | None = None) -> None:
         super().__init__()
-        self.weighting = pref_vector_to_weighting(pref_vector, default=SumWeighting())
-        self._pref_vector = pref_vector
+        self.pref_vector = pref_vector
 
         # This prevents computing gradients that can be very wrong.
         self.register_full_backward_pre_hook(raise_non_differentiable_error)
@@ -46,8 +45,17 @@ def forward(self, matrix: Matrix, /) -> Tensor:
 
         return length * unit_target_vector
 
+    @property
+    def pref_vector(self) -> Tensor | None:
+        return self._pref_vector
+
+    @pref_vector.setter
+    def pref_vector(self, value: Tensor | None) -> None:
+        self.weighting = pref_vector_to_weighting(value, default=SumWeighting())
+        self._pref_vector = value
+
     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
+        return f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)})"
 
     def __str__(self) -> str:
-        return f"ConFIG{pref_vector_to_str_suffix(self._pref_vector)}"
+        return f"ConFIG{pref_vector_to_str_suffix(self.pref_vector)}"

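Because `__init__` now assigns through the property, the setter is the single place where the inner weighting is (re)built. An illustrative sketch (not part of the commit), assuming `ConFIG` is importable from `torchjd.aggregation`:

import torch
from torchjd.aggregation import ConFIG

aggregator = ConFIG()                 # pref_vector=None, so a default SumWeighting is used
weighting_before = aggregator.weighting
aggregator.pref_vector = torch.tensor([2.0, 1.0])
assert aggregator.weighting is not weighting_before  # the setter rebuilt the weighting
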
src/torchjd/aggregation/_dualproj.py

Lines changed: 59 additions & 8 deletions

@@ -33,8 +33,7 @@ def __init__(
         solver: SUPPORTED_SOLVER = "quadprog",
     ) -> None:
         super().__init__()
-        self._pref_vector = pref_vector
-        self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
+        self.pref_vector = pref_vector
         self.norm_eps = norm_eps
         self.reg_eps = reg_eps
         self.solver: SUPPORTED_SOLVER = solver
@@ -45,6 +44,37 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
         w = project_weights(u, G, self.solver)
         return w
 
+    @property
+    def pref_vector(self) -> Tensor | None:
+        return self._pref_vector
+
+    @pref_vector.setter
+    def pref_vector(self, value: Tensor | None) -> None:
+        self.weighting = pref_vector_to_weighting(value, default=MeanWeighting())
+        self._pref_vector = value
+
+    @property
+    def norm_eps(self) -> float:
+        return self._norm_eps
+
+    @norm_eps.setter
+    def norm_eps(self, value: float) -> None:
+        if value < 0:
+            raise ValueError(f"norm_eps must be non-negative, but got {value}.")
+
+        self._norm_eps = value
+
+    @property
+    def reg_eps(self) -> float:
+        return self._reg_eps
+
+    @reg_eps.setter
+    def reg_eps(self, value: float) -> None:
+        if value < 0:
+            raise ValueError(f"reg_eps must be non-negative, but got {value}.")
+
+        self._reg_eps = value
+
 
 class DualProj(GramianWeightedAggregator):
     r"""
@@ -72,9 +102,6 @@ def __init__(
         reg_eps: float = 0.0001,
         solver: SUPPORTED_SOLVER = "quadprog",
     ) -> None:
-        self._pref_vector = pref_vector
-        self._norm_eps = norm_eps
-        self._reg_eps = reg_eps
         self._solver: SUPPORTED_SOLVER = solver
 
         super().__init__(
@@ -84,11 +111,35 @@ def __init__(
         # This prevents considering the computed weights as constant w.r.t. the matrix.
         self.register_full_backward_pre_hook(raise_non_differentiable_error)
 
+    @property
+    def pref_vector(self) -> Tensor | None:
+        return self.gramian_weighting.pref_vector
+
+    @pref_vector.setter
+    def pref_vector(self, value: Tensor | None) -> None:
+        self.gramian_weighting.pref_vector = value
+
+    @property
+    def norm_eps(self) -> float:
+        return self.gramian_weighting.norm_eps
+
+    @norm_eps.setter
+    def norm_eps(self, value: float) -> None:
+        self.gramian_weighting.norm_eps = value
+
+    @property
+    def reg_eps(self) -> float:
+        return self.gramian_weighting.reg_eps
+
+    @reg_eps.setter
+    def reg_eps(self, value: float) -> None:
+        self.gramian_weighting.reg_eps = value
+
     def __repr__(self) -> str:
         return (
-            f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps="
-            f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})"
+            f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps="
+            f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})"
        )
 
     def __str__(self) -> str:
-        return f"DualProj{pref_vector_to_str_suffix(self._pref_vector)}"
+        return f"DualProj{pref_vector_to_str_suffix(self.pref_vector)}"

src/torchjd/aggregation/_graddrop.py

Lines changed: 13 additions & 6 deletions

@@ -27,12 +27,6 @@ class GradDrop(Aggregator):
     """
 
     def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None:
-        if leak is not None and leak.dim() != 1:
-            raise ValueError(
-                "Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = "
-                f"{leak.shape}`.",
-            )
-
         super().__init__()
         self.f = f
         self.leak = leak
@@ -59,6 +53,19 @@ def forward(self, matrix: Matrix, /) -> Tensor:
 
         return vector
 
+    @property
+    def leak(self) -> Tensor | None:
+        return self._leak
+
+    @leak.setter
+    def leak(self, value: Tensor | None) -> None:
+        if value is not None and value.dim() != 1:
+            raise ValueError(
+                f"leak must be a 1-dimensional tensor. Found leak.shape = {value.shape}.",
+            )
+
+        self._leak = value
+
     def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
         n_rows = matrix.shape[0]
         if self.leak is not None and n_rows != len(self.leak):

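The dimensionality check on `leak` moves from the constructor into the setter, so it also applies when `leak` is reassigned later. An illustrative sketch (not part of the commit), assuming `GradDrop` is importable from `torchjd.aggregation`:

import torch
from torchjd.aggregation import GradDrop

aggregator = GradDrop()          # leak defaults to None
aggregator.leak = torch.ones(3)  # accepted: 1-dimensional
try:
    aggregator.leak = torch.ones(3, 3)  # rejected: not 1-dimensional
except ValueError as error:
    print(error)
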
src/torchjd/aggregation/_gradvac.py

Lines changed: 2 additions & 7 deletions

@@ -41,13 +41,8 @@ class GradVacWeighting(GramianWeighting, Stateful):
 
     def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
         super().__init__()
-        if not (0.0 <= beta <= 1.0):
-            raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
-        if eps < 0.0:
-            raise ValueError(f"Parameter `eps` must be non-negative. Found eps={eps!r}.")
-
-        self._beta = beta
-        self._eps = eps
+        self.beta = beta
+        self.eps = eps
         self._phi_t: Tensor | None = None
         self._state_key: tuple[int, torch.dtype] | None = None

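Since `GradVacWeighting` already had validating setters for `beta` and `eps`, assigning `self.beta` and `self.eps` in `__init__` routes through them, which is why the explicit constructor checks could be dropped. An illustrative sketch of the expected behavior (not part of the commit), assuming `GradVacWeighting` is importable from `torchjd.aggregation` and that its pre-existing `beta` setter keeps the [0, 1] check:

from torchjd.aggregation import GradVacWeighting

weighting = GradVacWeighting(beta=0.5)  # valid: beta is in [0, 1]
try:
    GradVacWeighting(beta=1.5)  # __init__ assigns self.beta, so the existing setter rejects this
except ValueError as error:
    print(error)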