@@ -33,8 +33,7 @@ def __init__(
3333 solver : SUPPORTED_SOLVER = "quadprog" ,
3434 ) -> None :
3535 super ().__init__ ()
36- self ._pref_vector = pref_vector
37- self .weighting = pref_vector_to_weighting (pref_vector , default = MeanWeighting ())
36+ self .pref_vector = pref_vector
3837 self .norm_eps = norm_eps
3938 self .reg_eps = reg_eps
4039 self .solver : SUPPORTED_SOLVER = solver
@@ -45,6 +44,37 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
4544 w = project_weights (u , G , self .solver )
4645 return w
4746
47+ @property
48+ def pref_vector (self ) -> Tensor | None :
49+ return self ._pref_vector
50+
51+ @pref_vector .setter
52+ def pref_vector (self , value : Tensor | None ) -> None :
53+ self .weighting = pref_vector_to_weighting (value , default = MeanWeighting ())
54+ self ._pref_vector = value
55+
56+ @property
57+ def norm_eps (self ) -> float :
58+ return self ._norm_eps
59+
60+ @norm_eps .setter
61+ def norm_eps (self , value : float ) -> None :
62+ if value < 0 :
63+ raise ValueError (f"norm_eps must be non-negative, but got { value } ." )
64+
65+ self ._norm_eps = value
66+
67+ @property
68+ def reg_eps (self ) -> float :
69+ return self ._reg_eps
70+
71+ @reg_eps .setter
72+ def reg_eps (self , value : float ) -> None :
73+ if value < 0 :
74+ raise ValueError (f"reg_eps must be non-negative, but got { value } ." )
75+
76+ self ._reg_eps = value
77+
4878
4979class DualProj (GramianWeightedAggregator ):
5080 r"""
@@ -72,9 +102,6 @@ def __init__(
72102 reg_eps : float = 0.0001 ,
73103 solver : SUPPORTED_SOLVER = "quadprog" ,
74104 ) -> None :
75- self ._pref_vector = pref_vector
76- self ._norm_eps = norm_eps
77- self ._reg_eps = reg_eps
78105 self ._solver : SUPPORTED_SOLVER = solver
79106
80107 super ().__init__ (
@@ -84,11 +111,35 @@ def __init__(
84111 # This prevents considering the computed weights as constant w.r.t. the matrix.
85112 self .register_full_backward_pre_hook (raise_non_differentiable_error )
86113
114+ @property
115+ def pref_vector (self ) -> Tensor | None :
116+ return self .gramian_weighting .pref_vector
117+
118+ @pref_vector .setter
119+ def pref_vector (self , value : Tensor | None ) -> None :
120+ self .gramian_weighting .pref_vector = value
121+
122+ @property
123+ def norm_eps (self ) -> float :
124+ return self .gramian_weighting .norm_eps
125+
126+ @norm_eps .setter
127+ def norm_eps (self , value : float ) -> None :
128+ self .gramian_weighting .norm_eps = value
129+
130+ @property
131+ def reg_eps (self ) -> float :
132+ return self .gramian_weighting .reg_eps
133+
134+ @reg_eps .setter
135+ def reg_eps (self , value : float ) -> None :
136+ self .gramian_weighting .reg_eps = value
137+
87138 def __repr__ (self ) -> str :
88139 return (
89- f"{ self .__class__ .__name__ } (pref_vector={ repr (self ._pref_vector )} , norm_eps="
90- f"{ self ._norm_eps } , reg_eps={ self ._reg_eps } , solver={ repr (self ._solver )} )"
140+ f"{ self .__class__ .__name__ } (pref_vector={ repr (self .pref_vector )} , norm_eps="
141+ f"{ self .norm_eps } , reg_eps={ self .reg_eps } , solver={ repr (self ._solver )} )"
91142 )
92143
93144 def __str__ (self ) -> str :
94- return f"DualProj{ pref_vector_to_str_suffix (self ._pref_vector )} "
145+ return f"DualProj{ pref_vector_to_str_suffix (self .pref_vector )} "
0 commit comments