Skip to content

Commit bcc5bd1

Browse files
committed
Standardize expand_sensor_model variable handling
1 parent 8533279 commit bcc5bd1

1 file changed

Lines changed: 42 additions & 52 deletions

File tree

sgptools/core/transformations.py

Lines changed: 42 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# you may not use this file except in compliance with the License.
55
# You may obtain a copy of the License at
66
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
7+
# http://www.apache.org/licenses/LICENSE-2.0
88
#
99
# Unless required by applicable law or agreed to in writing, software
1010
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -53,9 +53,9 @@ def __init__(self,
5353
self.aggregation_size = aggregation_size
5454
self.constraint_weight = constraint_weight
5555

56-
def expand(
57-
self, Xu: Union[np.ndarray,
58-
tf.Tensor]) -> Union[np.ndarray, tf.Tensor]:
56+
def expand(self,
57+
Xu: Union[np.ndarray, tf.Tensor],
58+
**kwargs: Any) -> Union[np.ndarray, tf.Tensor]:
5959
"""
6060
Applies an expansion transform to the inducing points.
6161
In this base class, it simply returns the input inducing points unchanged.
@@ -65,10 +65,12 @@ def expand(
6565
Xu (Union[np.ndarray, tf.Tensor]): The input inducing points.
6666
Shape: (m, d) where `m` is the number of inducing points
6767
and `d` is their dimensionality.
68+
**kwargs (Any): Additional keyword arguments that specific subclasses may interpret.
6869
6970
Returns:
7071
Union[np.ndarray, tf.Tensor]: The expanded inducing points.
7172
"""
73+
# Base implementation ignores any kwargs and returns Xu unchanged.
7274
return Xu
7375

7476
def aggregate(self, k: tf.Tensor) -> tf.Tensor:
@@ -178,7 +180,7 @@ def __init__(self,
178180
consecutive inducing points. `sampling_rate=2` implies
179181
only the two endpoints are used (point sensing).
180182
`sampling_rate > 2` implies continuous sensing via interpolation.
181-
Must be $\ge 2$. Defaults to 2.
183+
Must be >= 2. Defaults to 2.
182184
distance_budget (Optional[float]): The maximum allowable total path length for each robot.
183185
If None, no distance constraint is applied. Defaults to None.
184186
num_robots (int): The number of robots or agents involved in the IPP problem. Defaults to 1.
@@ -199,18 +201,6 @@ def __init__(self,
199201
200202
Raises:
201203
ValueError: If `sampling_rate` is less than 2.
202-
203-
Usage:
204-
```python
205-
# Single robot, point sensing
206-
transform_point = IPPTransform(num_robots=1, num_dim=2, sampling_rate=2)
207-
208-
# Single robot, continuous sensing
209-
transform_continuous = IPPTransform(num_robots=1, num_dim=2, sampling_rate=10)
210-
211-
# Multi-robot, continuous sensing with distance budget
212-
transform_multi_budget = IPPTransform(num_robots=2, num_dim=2, sampling_rate=5, distance_budget=50.0, constraint_weight=100.0)
213-
```
214204
"""
215205
super().__init__(**kwargs)
216206
if sampling_rate < 2:
@@ -237,7 +227,7 @@ def __init__(self,
237227
# Initialize TensorFlow Variable for fixed waypoints if provided, for online IPP.
238228
if Xu_fixed is not None:
239229
# Store number of fixed waypoints per robot
240-
self.num_fixed = Xu_fixed.shape[1]
230+
self.num_fixed = Xu_fixed.shape[1]
241231
self.Xu_fixed = tf.Variable(
242232
Xu_fixed,
243233
shape=tf.TensorShape(None),
@@ -256,7 +246,7 @@ def update_Xu_fixed(self, Xu_fixed: np.ndarray) -> None:
256246
representing the new set of fixed waypoints.
257247
"""
258248
# Store number of fixed waypoints per robot
259-
self.num_fixed = Xu_fixed.shape[1]
249+
self.num_fixed = Xu_fixed.shape[1]
260250
if self.Xu_fixed is not None:
261251
self.Xu_fixed.assign(tf.constant(Xu_fixed, dtype=default_float()))
262252
else:
@@ -267,7 +257,8 @@ def update_Xu_fixed(self, Xu_fixed: np.ndarray) -> None:
267257

268258
def expand(self,
269259
Xu: tf.Tensor,
270-
expand_sensor_model: bool = True) -> tf.Tensor:
260+
expand_sensor_model: bool = True,
261+
**kwargs: Any) -> tf.Tensor:
271262
"""
272263
Applies the expansion transform to the inducing points based on the IPP settings.
273264
This can involve:
@@ -284,6 +275,7 @@ def expand(self,
284275
only the path interpolation and fixed point handling
285276
are performed, useful for internal calculations like distance.
286277
Defaults to True.
278+
**kwargs (Any): Additional keyword arguments for future extensibility.
287279
288280
Returns:
289281
tf.Tensor: The expanded inducing points, ready for kernel computations.
@@ -403,8 +395,6 @@ def distance(self, Xu: tf.Tensor) -> tf.Tensor:
403395
# For point/continuous sensing without a special FoV model:
404396
# Calculate Euclidean distance between consecutive waypoints.
405397
# Assuming first two dimensions are (x,y) for distance calculation.
406-
# `Xu_reshaped[:, 1:, :2]` are points from the second to last.
407-
# `Xu_reshaped[:, :-1, :2]` are points from the first to second to last.
408398
segment_distances = tf.norm(Xu_reshaped[:, 1:, :2] -
409399
Xu_reshaped[:, :-1, :2],
410400
axis=-1)
@@ -437,12 +427,6 @@ def __init__(self,
437427
This averages covariances from the FoV points to reduce computational cost.
438428
Defaults to False.
439429
**kwargs (Any): Additional keyword arguments passed to the base `Transform` constructor.
440-
441-
Usage:
442-
```python
443-
# Create a square FoV of side length 10.0, approximated by a 5x5 grid of points
444-
square_fov_transform = SquareTransform(length=10.0, pts_per_side=5, aggregate_fov=True)
445-
```
446430
"""
447431
super().__init__(**kwargs)
448432
self.side_length = side_length
@@ -471,7 +455,10 @@ def enable_aggregation(self, size: Optional[int] = None) -> None:
471455
else:
472456
self.aggregation_size = size
473457

474-
def expand(self, Xu: tf.Tensor) -> tf.Tensor:
458+
def expand(self,
459+
Xu: tf.Tensor,
460+
expand_sensor_model: bool = True,
461+
**kwargs: Any) -> tf.Tensor:
475462
"""
476463
Applies the expansion transformation to the inducing points, modeling a square FoV.
477464
Each input inducing point, which includes position (x, y) and orientation (theta),
@@ -481,14 +468,20 @@ def expand(self, Xu: tf.Tensor) -> tf.Tensor:
481468
Xu (tf.Tensor): Inducing points in the position and orientation space.
482469
Shape: (m, 3) where `m` is the number of inducing points,
483470
and `3` corresponds to (x, y, angle in radians).
484-
471+
expand_sensor_model (bool): Controls whether the square FoV expansion is applied.
472+
Defaults to True.
473+
**kwargs (Any): Additional keyword arguments.
474+
485475
Returns:
486476
tf.Tensor: The expanded inducing points in 2D input space (x,y).
487477
Shape: (m * pts_per_side * pts_per_side, 2).
488478
`m` is the number of original inducing points.
489479
`pts_per_side * pts_per_side` is the number of points each inducing
490480
point is mapped to in order to form the FoV.
491481
"""
482+
if not expand_sensor_model:
483+
return Xu
484+
492485
# Split Xu into x, y coordinates and orientation (theta)
493486
x_coords, y_coords, angles = tf.split(Xu, num_or_size_splits=3, axis=1)
494487
x = tf.reshape(x_coords, [
@@ -532,12 +525,7 @@ def expand(self, Xu: tf.Tensor) -> tf.Tensor:
532525
points.append(
533526
tf.linspace(line_starts, line_ends, self.pts_per_side, axis=1))
534527

535-
# Concatenate all generated line segments.
536-
# `tf.concat` will stack them along a new axis, forming (num_lines, m, pts_per_side, 2)
537-
xy = tf.concat(
538-
points, axis=1
539-
) # (m, pts_per_side * pts_per_side, 2) after the transpose in the original code.
540-
528+
xy = tf.concat(points, axis=1)
541529
xy = tf.reshape(xy, (-1, 2))
542530
return xy
543531

@@ -555,9 +543,8 @@ def distance(self, Xu: tf.Tensor) -> tf.Tensor:
555543
Returns:
556544
tf.Tensor: A scalar tensor representing the total path length.
557545
"""
558-
# Reshape to (number_of_points, 3) and take only the (x,y) coordinates
559-
Xu_xy = tf.reshape(
560-
Xu, (-1, self.num_dim))[:, :2] # Assuming num_dim is 3 (x,y,angle)
546+
# Assuming dimension 3: (x, y, angle)
547+
Xu_xy = tf.reshape(Xu, (-1, 3))[:, :2]
561548

562549
if Xu_xy.shape[0] < 2:
563550
return tf.constant(0.0, dtype=default_float())
@@ -585,17 +572,11 @@ def __init__(self,
585572
586573
Args:
587574
pts_per_side (int): The number of points to sample along each side of the square FoV.
588-
A `pts_per_side` of 3 will create a 3x3 grid of 9 points to approximate the FoV.
575+
A `pts_per_side` of 3 will create a 3x3 grid of 9 points to approximate the FoV.
589576
aggregate_fov (bool): If True, aggregation will be enabled for the expanded FoV points.
590577
This averages covariances from the FoV points to reduce computational cost.
591578
Defaults to False.
592579
**kwargs (Any): Additional keyword arguments passed to the base `Transform` constructor.
593-
594-
Usage:
595-
```python
596-
# Create a height-dependent square FoV approximated by a 7x7 grid
597-
square_height_fov_transform = SquareHeightTransform(pts_per_side=7, aggregate_fov=True)
598-
```
599580
"""
600581
super().__init__(**kwargs)
601582
self.pts_per_side = pts_per_side
@@ -619,20 +600,29 @@ def enable_aggregation(self, size: Optional[int] = None) -> None:
619600
else:
620601
self.aggregation_size = size
621602

622-
def expand(self, Xu):
603+
def expand(self,
604+
Xu: tf.Tensor,
605+
expand_sensor_model: bool = True,
606+
**kwargs: Any) -> tf.Tensor:
623607
"""
624-
Applies the expansion transform to the inducing points
608+
Applies the expansion transform to the inducing points.
625609
626610
Args:
627-
Xu (ndarray): (m, 3); Inducing points in the 3D position space.
611+
Xu (tf.Tensor): (m, 3); Inducing points in the 3D position space.
628612
`m` is the number of inducing points,
629613
`3` is the dimension of the space (x, y, z)
630-
614+
expand_sensor_model (bool): Controls whether the height-dependent FoV expansion is applied.
615+
Defaults to True.
616+
**kwargs (Any): Additional keyword arguments.
617+
631618
Returns:
632-
Xu (ndarray): (mp, 2); Inducing points in input space.
633-
`p` is the number of points each inducing point is mapped
619+
tf.Tensor: (mp, 2); Inducing points in input space.
620+
`p` is the number of points each inducing point is mapped
634621
to in order to form the FoV.
635622
"""
623+
if not expand_sensor_model:
624+
return Xu
625+
636626
x, y, h = tf.split(Xu, num_or_size_splits=3, axis=1)
637627
x = tf.reshape(x, [
638628
-1,

0 commit comments

Comments
 (0)