44# you may not use this file except in compliance with the License.
55# You may obtain a copy of the License at
66#
7- # http://www.apache.org/licenses/LICENSE-2.0
7+ # http://www.apache.org/licenses/LICENSE-2.0
88#
99# Unless required by applicable law or agreed to in writing, software
1010# distributed under the License is distributed on an "AS IS" BASIS,
@@ -53,9 +53,9 @@ def __init__(self,
5353 self .aggregation_size = aggregation_size
5454 self .constraint_weight = constraint_weight
5555
56- def expand (
57- self , Xu : Union [np .ndarray ,
58- tf . Tensor ] ) -> Union [np .ndarray , tf .Tensor ]:
56+ def expand (self ,
57+ Xu : Union [np .ndarray , tf . Tensor ] ,
58+ ** kwargs : Any ) -> Union [np .ndarray , tf .Tensor ]:
5959 """
6060 Applies an expansion transform to the inducing points.
6161 In this base class, it simply returns the input inducing points unchanged.
@@ -65,10 +65,12 @@ def expand(
6565 Xu (Union[np.ndarray, tf.Tensor]): The input inducing points.
6666 Shape: (m, d) where `m` is the number of inducing points
6767 and `d` is their dimensionality.
68+ **kwargs (Any): Additional keyword arguments that specific subclasses may interpret.
6869
6970 Returns:
7071 Union[np.ndarray, tf.Tensor]: The expanded inducing points.
7172 """
73+ # Base implementation ignores any kwargs and returns Xu unchanged.
7274 return Xu
7375
7476 def aggregate (self , k : tf .Tensor ) -> tf .Tensor :
@@ -178,7 +180,7 @@ def __init__(self,
178180 consecutive inducing points. `sampling_rate=2` implies
179181 only the two endpoints are used (point sensing).
180182 `sampling_rate > 2` implies continuous sensing via interpolation.
181- Must be $\ge 2$ . Defaults to 2.
183+ Must be >= 2 . Defaults to 2.
182184 distance_budget (Optional[float]): The maximum allowable total path length for each robot.
183185 If None, no distance constraint is applied. Defaults to None.
184186 num_robots (int): The number of robots or agents involved in the IPP problem. Defaults to 1.
@@ -199,18 +201,6 @@ def __init__(self,
199201
200202 Raises:
201203 ValueError: If `sampling_rate` is less than 2.
202-
203- Usage:
204- ```python
205- # Single robot, point sensing
206- transform_point = IPPTransform(num_robots=1, num_dim=2, sampling_rate=2)
207-
208- # Single robot, continuous sensing
209- transform_continuous = IPPTransform(num_robots=1, num_dim=2, sampling_rate=10)
210-
211- # Multi-robot, continuous sensing with distance budget
212- transform_multi_budget = IPPTransform(num_robots=2, num_dim=2, sampling_rate=5, distance_budget=50.0, constraint_weight=100.0)
213- ```
214204 """
215205 super ().__init__ (** kwargs )
216206 if sampling_rate < 2 :
@@ -237,7 +227,7 @@ def __init__(self,
237227 # Initialize TensorFlow Variable for fixed waypoints if provided, for online IPP.
238228 if Xu_fixed is not None :
239229 # Store number of fixed waypoints per robot
240- self .num_fixed = Xu_fixed .shape [1 ]
230+ self .num_fixed = Xu_fixed .shape [1 ]
241231 self .Xu_fixed = tf .Variable (
242232 Xu_fixed ,
243233 shape = tf .TensorShape (None ),
@@ -256,7 +246,7 @@ def update_Xu_fixed(self, Xu_fixed: np.ndarray) -> None:
256246 representing the new set of fixed waypoints.
257247 """
258248 # Store number of fixed waypoints per robot
259- self .num_fixed = Xu_fixed .shape [1 ]
249+ self .num_fixed = Xu_fixed .shape [1 ]
260250 if self .Xu_fixed is not None :
261251 self .Xu_fixed .assign (tf .constant (Xu_fixed , dtype = default_float ()))
262252 else :
@@ -267,7 +257,8 @@ def update_Xu_fixed(self, Xu_fixed: np.ndarray) -> None:
267257
268258 def expand (self ,
269259 Xu : tf .Tensor ,
270- expand_sensor_model : bool = True ) -> tf .Tensor :
260+ expand_sensor_model : bool = True ,
261+ ** kwargs : Any ) -> tf .Tensor :
271262 """
272263 Applies the expansion transform to the inducing points based on the IPP settings.
273264 This can involve:
@@ -284,6 +275,7 @@ def expand(self,
284275 only the path interpolation and fixed point handling
285276 are performed, useful for internal calculations like distance.
286277 Defaults to True.
278+ **kwargs (Any): Additional keyword arguments for future extensibility.
287279
288280 Returns:
289281 tf.Tensor: The expanded inducing points, ready for kernel computations.
@@ -403,8 +395,6 @@ def distance(self, Xu: tf.Tensor) -> tf.Tensor:
403395 # For point/continuous sensing without a special FoV model:
404396 # Calculate Euclidean distance between consecutive waypoints.
405397 # Assuming first two dimensions are (x,y) for distance calculation.
406- # `Xu_reshaped[:, 1:, :2]` are points from the second to last.
407- # `Xu_reshaped[:, :-1, :2]` are points from the first to second to last.
408398 segment_distances = tf .norm (Xu_reshaped [:, 1 :, :2 ] -
409399 Xu_reshaped [:, :- 1 , :2 ],
410400 axis = - 1 )
@@ -437,12 +427,6 @@ def __init__(self,
437427 This averages covariances from the FoV points to reduce computational cost.
438428 Defaults to False.
439429 **kwargs (Any): Additional keyword arguments passed to the base `Transform` constructor.
440-
441- Usage:
442- ```python
443- # Create a square FoV of side length 10.0, approximated by a 5x5 grid of points
444- square_fov_transform = SquareTransform(length=10.0, pts_per_side=5, aggregate_fov=True)
445- ```
446430 """
447431 super ().__init__ (** kwargs )
448432 self .side_length = side_length
@@ -471,7 +455,10 @@ def enable_aggregation(self, size: Optional[int] = None) -> None:
471455 else :
472456 self .aggregation_size = size
473457
474- def expand (self , Xu : tf .Tensor ) -> tf .Tensor :
458+ def expand (self ,
459+ Xu : tf .Tensor ,
460+ expand_sensor_model : bool = True ,
461+ ** kwargs : Any ) -> tf .Tensor :
475462 """
476463 Applies the expansion transformation to the inducing points, modeling a square FoV.
477464 Each input inducing point, which includes position (x, y) and orientation (theta),
@@ -481,14 +468,20 @@ def expand(self, Xu: tf.Tensor) -> tf.Tensor:
481468 Xu (tf.Tensor): Inducing points in the position and orientation space.
482469 Shape: (m, 3) where `m` is the number of inducing points,
483470 and `3` corresponds to (x, y, angle in radians).
484-
471+ expand_sensor_model (bool): Controls whether the square FoV expansion is applied.
472+ Defaults to True.
473+ **kwargs (Any): Additional keyword arguments.
474+
485475 Returns:
486476 tf.Tensor: The expanded inducing points in 2D input space (x,y).
487477 Shape: (m * pts_per_side * pts_per_side, 2).
488478 `m` is the number of original inducing points.
489479 `pts_per_side * pts_per_side` is the number of points each inducing
490480 point is mapped to in order to form the FoV.
491481 """
482+ if not expand_sensor_model :
483+ return Xu
484+
492485 # Split Xu into x, y coordinates and orientation (theta)
493486 x_coords , y_coords , angles = tf .split (Xu , num_or_size_splits = 3 , axis = 1 )
494487 x = tf .reshape (x_coords , [
@@ -532,12 +525,7 @@ def expand(self, Xu: tf.Tensor) -> tf.Tensor:
532525 points .append (
533526 tf .linspace (line_starts , line_ends , self .pts_per_side , axis = 1 ))
534527
535- # Concatenate all generated line segments.
536- # `tf.concat` will stack them along a new axis, forming (num_lines, m, pts_per_side, 2)
537- xy = tf .concat (
538- points , axis = 1
539- ) # (m, pts_per_side * pts_per_side, 2) after the transpose in the original code.
540-
528+ xy = tf .concat (points , axis = 1 )
541529 xy = tf .reshape (xy , (- 1 , 2 ))
542530 return xy
543531
@@ -555,9 +543,8 @@ def distance(self, Xu: tf.Tensor) -> tf.Tensor:
555543 Returns:
556544 tf.Tensor: A scalar tensor representing the total path length.
557545 """
558- # Reshape to (number_of_points, 3) and take only the (x,y) coordinates
559- Xu_xy = tf .reshape (
560- Xu , (- 1 , self .num_dim ))[:, :2 ] # Assuming num_dim is 3 (x,y,angle)
546+ # Assuming dimension 3: (x, y, angle)
547+ Xu_xy = tf .reshape (Xu , (- 1 , 3 ))[:, :2 ]
561548
562549 if Xu_xy .shape [0 ] < 2 :
563550 return tf .constant (0.0 , dtype = default_float ())
@@ -585,17 +572,11 @@ def __init__(self,
585572
586573 Args:
587574 pts_per_side (int): The number of points to sample along each side of the square FoV.
588- A `pts_per_side` of 3 will create a 3x3 grid of 9 points to approximate the FoV.
575+ A `pts_per_side` of 3 will create a 3x3 grid of 9 points to approximate the FoV.
589576 aggregate_fov (bool): If True, aggregation will be enabled for the expanded FoV points.
590577 This averages covariances from the FoV points to reduce computational cost.
591578 Defaults to False.
592579 **kwargs (Any): Additional keyword arguments passed to the base `Transform` constructor.
593-
594- Usage:
595- ```python
596- # Create a height-dependent square FoV approximated by a 7x7 grid
597- square_height_fov_transform = SquareHeightTransform(pts_per_side=7, aggregate_fov=True)
598- ```
599580 """
600581 super ().__init__ (** kwargs )
601582 self .pts_per_side = pts_per_side
@@ -619,20 +600,29 @@ def enable_aggregation(self, size: Optional[int] = None) -> None:
619600 else :
620601 self .aggregation_size = size
621602
622- def expand (self , Xu ):
603+ def expand (self ,
604+ Xu : tf .Tensor ,
605+ expand_sensor_model : bool = True ,
606+ ** kwargs : Any ) -> tf .Tensor :
623607 """
624- Applies the expansion transform to the inducing points
608+ Applies the expansion transform to the inducing points.
625609
626610 Args:
627- Xu (ndarray ): (m, 3); Inducing points in the 3D position space.
611+ Xu (tf.Tensor ): (m, 3); Inducing points in the 3D position space.
628612 `m` is the number of inducing points,
629613 `3` is the dimension of the space (x, y, z)
630-
614+ expand_sensor_model (bool): Controls whether the height-dependent FoV expansion is applied.
615+ Defaults to True.
616+ **kwargs (Any): Additional keyword arguments.
617+
631618 Returns:
632- Xu (ndarray) : (mp, 2); Inducing points in input space.
633- `p` is the number of points each inducing point is mapped
619+ tf.Tensor : (mp, 2); Inducing points in input space.
620+ `p` is the number of points each inducing point is mapped
634621 to in order to form the FoV.
635622 """
623+ if not expand_sensor_model :
624+ return Xu
625+
636626 x , y , h = tf .split (Xu , num_or_size_splits = 3 , axis = 1 )
637627 x = tf .reshape (x , [
638628 - 1 ,
0 commit comments