"""Compute collective behavior metrics for multi-individual tracking data."""

from collections.abc import Hashable

import numpy as np
import xarray as xr

from movement.kinematics.kinematics import compute_velocity
from movement.utils.logging import logger
from movement.utils.vector import compute_norm, convert_to_unit
from movement.validators.arrays import validate_dims_coords


def compute_polarization(
    data: xr.DataArray,
    heading_keypoints: tuple[Hashable, Hashable] | None = None,
) -> xr.DataArray:
    r"""Compute the polarization (group alignment) of multiple individuals.

    Polarization measures how aligned the heading directions of individuals
    are. A value of 1 indicates all individuals are heading in the same
    direction, while a value near 0 indicates random orientations.

    The polarization is computed as:

    .. math:: \Phi = \frac{1}{N} \left\| \sum_{i=1}^{N} \hat{v}_i \right\|

    where :math:`\hat{v}_i` is the unit heading vector for individual
    :math:`i`, and :math:`N` is the number of individuals.

    Parameters
    ----------
    data : xarray.DataArray
        The input data representing position. Must contain ``time``,
        ``space``, and ``individuals`` as dimensions. The ``keypoints``
        dimension is required only if ``heading_keypoints`` is provided.
    heading_keypoints : tuple of Hashable, optional
        A tuple of two keypoint names ``(origin, target)`` used to
        compute the heading direction as the vector from origin to
        target (e.g., ``("neck", "nose")`` or ``("tail", "head")``).
        If None, heading is inferred from the velocity of the first
        available keypoint.

    Returns
    -------
    xarray.DataArray
        An xarray DataArray containing the polarization value at each
        time point, with dimensions ``(time,)``. Values range from 0
        (random orientations) to 1 (perfectly aligned).

    Notes
    -----
    If ``heading_keypoints`` is provided, the heading for each individual
    is computed as the unit vector from the origin to the target
    keypoint. If not provided, heading is inferred from the instantaneous
    velocity direction.

    Frames where an individual has missing data (NaN) are handled by
    excluding that individual from the polarization calculation for that
    frame.

    Examples
    --------
    Compute polarization using two keypoints to define heading:

    >>> polarization = compute_polarization(
    ...     ds.position,
    ...     heading_keypoints=("neck", "nose"),
    ... )

    Compute polarization using velocity-inferred heading:

    >>> polarization = compute_polarization(ds.position)

    See Also
    --------
    movement.kinematics.compute_velocity : Compute velocity from position.

    """
    # Validate input data: must be a DataArray with the required dims.
    _validate_type_data_array(data)
    validate_dims_coords(
        data,
        {
            "time": [],
            "space": [],
            "individuals": [],
        },
    )

    # Compute heading vectors for all individuals, either from a pair of
    # keypoints or from the instantaneous velocity direction.
    if heading_keypoints is not None:
        heading_vectors = _compute_heading_from_keypoints(
            data, heading_keypoints
        )
    else:
        heading_vectors = _compute_heading_from_velocity(data)

    # Normalise headings to unit length. Zero-length headings (e.g. a
    # stationary individual in velocity mode) become NaN here and are
    # treated as missing below.
    unit_headings = convert_to_unit(heading_vectors)

    # Sum unit vectors across individuals, skipping NaNs (skipna=True)
    # so missing individuals do not poison the whole frame.
    vector_sum = unit_headings.sum(dim="individuals", skipna=True)

    # Count valid (non-NaN) individuals per time point.
    # A heading is valid only if neither spatial component is NaN.
    valid_mask = ~unit_headings.isnull().any(dim="space")
    n_valid = valid_mask.sum(dim="individuals")

    # Compute magnitude of the resultant vector.
    sum_magnitude = compute_norm(vector_sum)

    # Normalize by number of valid individuals; frames with no valid
    # individual yield NaN instead of dividing by zero.
    polarization = xr.where(n_valid > 0, sum_magnitude / n_valid, np.nan)

    # Guard against float round-off pushing the value slightly above the
    # documented upper bound of 1.0 (NaNs are preserved by clip).
    polarization = polarization.clip(max=1.0)

    polarization.name = "polarization"
    return polarization


def _compute_heading_from_keypoints(
    data: xr.DataArray,
    heading_keypoints: tuple[Hashable, Hashable],
) -> xr.DataArray:
    """Compute heading vectors from two keypoints (origin to target).

    Parameters
    ----------
    data : xarray.DataArray
        Position data with ``keypoints`` dimension.
    heading_keypoints : tuple of Hashable
        A tuple of ``(origin, target)`` keypoint names. The heading
        vector points from origin toward target.

    Returns
    -------
    xarray.DataArray
        Heading vectors with dimensions ``(time, space, individuals)``.

    """
    origin, target = heading_keypoints

    # The origin and target must name two distinct keypoints.
    if origin == target:
        raise logger.error(
            ValueError("The origin and target keypoints may not be identical.")
        )

    # Validate that both keypoints exist in the data.
    validate_dims_coords(
        data,
        {"keypoints": [origin, target]},
    )

    # Compute heading as the vector from origin to target.
    heading = data.sel(keypoints=target, drop=True) - data.sel(
        keypoints=origin, drop=True
    )

    return heading


def _compute_heading_from_velocity(data: xr.DataArray) -> xr.DataArray:
    """Compute heading vectors from velocity (displacement direction).

    Uses the first available keypoint if multiple are present.

    Parameters
    ----------
    data : xarray.DataArray
        Position data with ``time`` dimension.

    Returns
    -------
    xarray.DataArray
        Heading vectors based on velocity direction.

    """
    # If a keypoints dimension exists, use the first keypoint only.
    if "keypoints" in data.dims:
        first_keypoint = data.keypoints.values[0]
        position = data.sel(keypoints=first_keypoint, drop=True)
        logger.info(
            f"Using keypoint '{first_keypoint}' for velocity-based heading."
        )
    else:
        position = data

    # The velocity vector doubles as the (unnormalised) heading direction.
    velocity = compute_velocity(position)

    return velocity


def _validate_type_data_array(data: xr.DataArray) -> None:
    """Validate the input data is an xarray DataArray.

    Parameters
    ----------
    data : xarray.DataArray
        The input data to validate.

    Raises
    ------
    TypeError
        If the input data is not an xarray DataArray.

    """
    if not isinstance(data, xr.DataArray):
        raise logger.error(
            TypeError(
                "Input data must be an xarray.DataArray, "
                f"but got {type(data)}."
            )
        )
"""Tests for the collective behavior metrics module."""

import numpy as np
import pytest
import xarray as xr

from movement import kinematics


@pytest.fixture
def position_data_aligned_individuals():
    """Return position data for 3 individuals all moving in the same direction.

    All individuals move along the positive x-axis at every time step,
    so polarization should be 1.0.
    """
    time = [0, 1, 2, 3]
    individuals = ["id_0", "id_1", "id_2"]
    keypoints = ["centroid"]
    space = ["x", "y"]

    # All individuals move in +x direction.
    # Shape: (time=4, space=2, keypoints=1, individuals=3)
    # x-coords: all increase by 1 each time step; y-coords: all stay at 0.
    data = np.array(
        [
            # time 0: x=[0,1,2], y=[0,0,0]
            [[[0, 1, 2]], [[0, 0, 0]]],
            # time 1: x=[1,2,3], y=[0,0,0]
            [[[1, 2, 3]], [[0, 0, 0]]],
            # time 2: x=[2,3,4], y=[0,0,0]
            [[[2, 3, 4]], [[0, 0, 0]]],
            # time 3: x=[3,4,5], y=[0,0,0]
            [[[3, 4, 5]], [[0, 0, 0]]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


@pytest.fixture
def position_data_opposite_individuals():
    """Return position data for 2 individuals moving in opposite directions.

    One moves in +x, the other in -x, so polarization should be 0.0.
    """
    time = [0, 1, 2, 3]
    individuals = ["id_0", "id_1"]
    keypoints = ["centroid"]
    space = ["x", "y"]

    # id_0 moves in +x, id_1 moves in -x.
    # Shape: (time=4, space=2, keypoints=1, individuals=2)
    data = np.array(
        [
            # time 0: x=[0,10], y=[0,0]
            [[[0, 10]], [[0, 0]]],
            # time 1: x=[1,9], y=[0,0]
            [[[1, 9]], [[0, 0]]],
            # time 2: x=[2,8], y=[0,0]
            [[[2, 8]], [[0, 0]]],
            # time 3: x=[3,7], y=[0,0]
            [[[3, 7]], [[0, 0]]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


@pytest.fixture
def position_data_with_keypoints():
    """Return position data with origin/target keypoints for heading.

    Two individuals, both facing the same direction (positive x).
    Heading is computed as tail -> nose (origin -> target).
    """
    time = [0, 1, 2]
    individuals = ["id_0", "id_1"]
    keypoints = ["nose", "tail"]
    space = ["x", "y"]

    # Both individuals facing +x (nose ahead of tail in x).
    # Shape: (time=3, space=2, keypoints=2, individuals=2)
    # For each individual: nose is at higher x than tail.
    data = np.array(
        [
            # time 0: nose_x=[2,5], nose_y=[0,1], tail_x=[0,3], tail_y=[0,1]
            [[[2, 5], [0, 3]], [[0, 1], [0, 1]]],
            # time 1
            [[[3, 6], [1, 4]], [[0, 1], [0, 1]]],
            # time 2
            [[[4, 7], [2, 5]], [[0, 1], [0, 1]]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


@pytest.fixture
def position_data_with_nan():
    """Return position data with NaN values for one individual at one time."""
    time = [0, 1, 2]
    individuals = ["id_0", "id_1", "id_2"]
    keypoints = ["centroid"]
    space = ["x", "y"]

    # Shape: (time=3, space=2, keypoints=1, individuals=3)
    # id_1 has NaN at time 1.
    data = np.array(
        [
            # time 0: x=[0,1,2], y=[0,0,0] - all valid
            [[[0, 1, 2]], [[0, 0, 0]]],
            # time 1: x=[1,nan,3], y=[0,nan,0] - id_1 is NaN
            [[[1, np.nan, 3]], [[0, np.nan, 0]]],
            # time 2: x=[2,3,4], y=[0,0,0] - all valid
            [[[2, 3, 4]], [[0, 0, 0]]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


class TestComputePolarization:
    """Test suite for the compute_polarization function."""

    def test_polarization_aligned(self, position_data_aligned_individuals):
        """Test polarization is 1.0 when all move same direction."""
        polarization = kinematics.compute_polarization(
            position_data_aligned_individuals
        )

        assert isinstance(polarization, xr.DataArray)
        assert polarization.name == "polarization"
        assert "time" in polarization.dims
        assert "individuals" not in polarization.dims
        assert "space" not in polarization.dims

        # All moving in same direction -> polarization should be ~1.0
        # (Skip first time point since velocity is computed via diff.)
        assert np.allclose(polarization.values[1:], 1.0, atol=1e-10)

    def test_polarization_opposite(self, position_data_opposite_individuals):
        """Test polarization is 0.0 when individuals move opposite."""
        polarization = kinematics.compute_polarization(
            position_data_opposite_individuals
        )

        assert isinstance(polarization, xr.DataArray)
        assert polarization.name == "polarization"

        # Opposite directions -> polarization should be ~0.0
        # (Skip first time point since velocity is computed via diff.)
        assert np.allclose(polarization.values[1:], 0.0, atol=1e-10)

    def test_polarization_with_keypoints(self, position_data_with_keypoints):
        """Test polarization using keypoint-based heading."""
        polarization = kinematics.compute_polarization(
            position_data_with_keypoints,
            heading_keypoints=("tail", "nose"),  # origin -> target
        )

        assert isinstance(polarization, xr.DataArray)
        assert polarization.name == "polarization"

        # Both facing same direction -> polarization should be 1.0
        assert np.allclose(polarization.values, 1.0, atol=1e-10)

    def test_polarization_handles_nan(self, position_data_with_nan):
        """Test that NaN values are handled correctly."""
        polarization = kinematics.compute_polarization(position_data_with_nan)

        assert isinstance(polarization, xr.DataArray)
        # Should compute polarization even with missing data.
        # The frame with NaN should exclude that individual from calculation.
        assert not np.all(np.isnan(polarization.values))

    def test_polarization_range(self, position_data_aligned_individuals):
        """Test that polarization values are in [0, 1] range."""
        polarization = kinematics.compute_polarization(
            position_data_aligned_individuals
        )

        # Exclude NaN values from range check.
        valid_values = polarization.values[~np.isnan(polarization.values)]
        assert np.all(valid_values >= 0.0)
        assert np.all(valid_values <= 1.0)

    def test_invalid_input_type(self, position_data_aligned_individuals):
        """Test that non-DataArray input raises TypeError."""
        with pytest.raises(TypeError, match="must be an xarray.DataArray"):
            kinematics.compute_polarization(
                position_data_aligned_individuals.values
            )

    def test_missing_dimensions(self, position_data_aligned_individuals):
        """Test that missing required dimensions raises ValueError."""
        # Drop individuals dimension.
        data_no_individuals = position_data_aligned_individuals.sel(
            individuals="id_0", drop=True
        )
        with pytest.raises(ValueError, match="individuals"):
            kinematics.compute_polarization(data_no_individuals)

    def test_invalid_keypoints(self, position_data_with_keypoints):
        """Test that invalid keypoint names raise ValueError."""
        with pytest.raises(ValueError, match="nonexistent"):
            kinematics.compute_polarization(
                position_data_with_keypoints,
                heading_keypoints=("nose", "nonexistent"),
            )

    def test_identical_keypoints(self, position_data_with_keypoints):
        """Test that identical origin and target keypoints raise ValueError."""
        with pytest.raises(ValueError, match="may not be identical"):
            kinematics.compute_polarization(
                position_data_with_keypoints,
                heading_keypoints=("nose", "nose"),
            )
# NOTE(review): the fixtures and test methods below are the patch-02
# additions to tests/test_unit/test_kinematics/test_collective.py. The
# test methods extend the TestComputePolarization class defined earlier
# in that file.


@pytest.fixture
def position_data_perpendicular():
    """Return position data for 4 individuals moving in perpendicular directions.

    Each individual moves in one of the 4 cardinal directions (+x, -x, +y, -y).
    The sum of unit vectors is zero, so polarization should be 0.0.
    """
    time = [0, 1, 2, 3]
    individuals = ["id_0", "id_1", "id_2", "id_3"]
    keypoints = ["centroid"]
    space = ["x", "y"]

    # id_0: +x, id_1: -x, id_2: +y, id_3: -y.
    # Shape: (time=4, space=2, keypoints=1, individuals=4)
    data = np.array(
        [
            # time 0
            [[[0, 10, 0, 0]], [[0, 0, 0, 10]]],
            # time 1: +x moves right, -x moves left, +y moves up, -y moves down
            [[[1, 9, 0, 0]], [[0, 0, 1, 9]]],
            # time 2
            [[[2, 8, 0, 0]], [[0, 0, 2, 8]]],
            # time 3
            [[[3, 7, 0, 0]], [[0, 0, 3, 7]]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


@pytest.fixture
def position_data_partial_alignment():
    """Return position data for 3 individuals with partial alignment.

    Two individuals move in +x, one moves in +y.
    Expected polarization: |[2,1]|/3 = sqrt(5)/3 ≈ 0.745
    """
    time = [0, 1, 2, 3]
    individuals = ["id_0", "id_1", "id_2"]
    keypoints = ["centroid"]
    space = ["x", "y"]

    # id_0 and id_1 move in +x, id_2 moves in +y.
    # Shape: (time=4, space=2, keypoints=1, individuals=3)
    data = np.array(
        [
            # time 0
            [[[0, 5, 0]], [[0, 0, 0]]],
            # time 1
            [[[1, 6, 0]], [[0, 0, 1]]],
            # time 2
            [[[2, 7, 0]], [[0, 0, 2]]],
            # time 3
            [[[3, 8, 0]], [[0, 0, 3]]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


@pytest.fixture
def position_data_single_individual():
    """Return position data for a single individual.

    In this synthetic dataset, polarization is 1.0 whenever a valid heading
    can be computed. First-frame behavior in velocity mode depends on boundary
    differencing.
    """
    time = [0, 1, 2, 3]
    individuals = ["id_0"]
    keypoints = ["centroid"]
    space = ["x", "y"]

    # Single individual moving in +x direction.
    data = np.array(
        [
            [[[0]], [[0]]],
            [[[1]], [[0]]],
            [[[2]], [[0]]],
            [[[3]], [[0]]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


@pytest.fixture
def position_data_all_nan_frame():
    """Return position data with one frame where all individuals are NaN."""
    time = [0, 1, 2, 3]
    individuals = ["id_0", "id_1"]
    keypoints = ["centroid"]
    space = ["x", "y"]

    # All individuals have NaN at time 2.
    data = np.array(
        [
            [[[0, 5]], [[0, 0]]],
            [[[1, 6]], [[0, 0]]],
            [[[np.nan, np.nan]], [[np.nan, np.nan]]],  # all NaN
            [[[3, 8]], [[0, 0]]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


@pytest.fixture
def position_data_stationary():
    """Return position data where individuals are stationary (zero velocity)."""
    time = [0, 1, 2, 3]
    individuals = ["id_0", "id_1"]
    keypoints = ["centroid"]
    space = ["x", "y"]

    # Both individuals stay at the same position.
    data = np.array(
        [
            [[[0, 5]], [[0, 0]]],
            [[[0, 5]], [[0, 0]]],
            [[[0, 5]], [[0, 0]]],
            [[[0, 5]], [[0, 0]]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


@pytest.fixture
def position_data_large_n():
    """Return position data with many individuals (N=50) all aligned."""
    time = [0, 1, 2]
    n_individuals = 50
    individuals = [f"id_{i}" for i in range(n_individuals)]
    keypoints = ["centroid"]
    space = ["x", "y"]

    # All individuals move in +x direction.
    # Shape: (time=3, space=2, keypoints=1, individuals=50)
    x_coords = np.arange(n_individuals, dtype=float)
    data = np.array(
        [
            [[x_coords], [np.zeros(n_individuals)]],
            [[x_coords + 1], [np.zeros(n_individuals)]],
            [[x_coords + 2], [np.zeros(n_individuals)]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


@pytest.fixture
def position_data_no_keypoints():
    """Return position data without keypoints dimension."""
    time = [0, 1, 2, 3]
    individuals = ["id_0", "id_1", "id_2"]
    space = ["x", "y"]

    # Shape: (time=4, space=2, individuals=3)
    # All individuals move in +x direction.
    data = np.array(
        [
            [[0, 1, 2], [0, 0, 0]],
            [[1, 2, 3], [0, 0, 0]],
            [[2, 3, 4], [0, 0, 0]],
            [[3, 4, 5], [0, 0, 0]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "individuals"],
        coords={
            "time": time,
            "space": space,
            "individuals": individuals,
        },
    )


@pytest.fixture
def position_data_multiple_keypoints():
    """Return position data with multiple keypoints for velocity mode test.

    Tests that velocity mode uses the first keypoint.
    """
    time = [0, 1, 2, 3]
    individuals = ["id_0", "id_1"]
    keypoints = ["nose", "tail", "center"]  # nose is first
    space = ["x", "y"]

    # nose moves in +x (should be used); tail moves in -x; center moves in +y.
    # Shape: (time=4, space=2, keypoints=3, individuals=2)
    data = np.array(
        [
            # time 0
            [
                [[0, 0], [10, 10], [0, 0]],  # x: nose, tail, center
                [[0, 0], [0, 0], [0, 0]],  # y: nose, tail, center
            ],
            # time 1
            [
                [[1, 1], [9, 9], [0, 0]],
                [[0, 0], [0, 0], [1, 1]],
            ],
            # time 2
            [
                [[2, 2], [8, 8], [0, 0]],
                [[0, 0], [0, 0], [2, 2]],
            ],
            # time 3
            [
                [[3, 3], [7, 7], [0, 0]],
                [[0, 0], [0, 0], [3, 3]],
            ],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


@pytest.fixture
def position_data_diagonal_movement():
    """Return position data with diagonal movement at 45 degrees."""
    time = [0, 1, 2, 3]
    individuals = ["id_0", "id_1"]
    keypoints = ["centroid"]
    space = ["x", "y"]

    # Both individuals move diagonally (45 degrees, +x +y).
    data = np.array(
        [
            [[[0, 5]], [[0, 5]]],
            [[[1, 6]], [[1, 6]]],
            [[[2, 7]], [[2, 7]]],
            [[[3, 8]], [[3, 8]]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


@pytest.fixture
def position_data_keypoints_opposite():
    """Return position data where keypoint-based headings are opposite."""
    time = [0, 1, 2]
    individuals = ["id_0", "id_1"]
    keypoints = ["nose", "tail"]
    space = ["x", "y"]

    # id_0 faces +x (nose ahead of tail); id_1 faces -x (nose behind tail).
    data = np.array(
        [
            # time 0: id_0 nose at (2,0), tail at (0,0) -> faces +x
            #         id_1 nose at (3,0), tail at (5,0) -> faces -x
            [[[2, 3], [0, 5]], [[0, 0], [0, 0]]],
            [[[3, 4], [1, 6]], [[0, 0], [0, 0]]],
            [[[4, 5], [2, 7]], [[0, 0], [0, 0]]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


@pytest.fixture
def position_data_non_uniform_time():
    """Return position data with non-uniform time spacing."""
    time = [0.0, 0.5, 2.0, 5.0]  # Non-uniform intervals
    individuals = ["id_0", "id_1"]
    keypoints = ["centroid"]
    space = ["x", "y"]

    # Both move in +x direction.
    data = np.array(
        [
            [[[0, 5]], [[0, 0]]],
            [[[1, 6]], [[0, 0]]],
            [[[2, 7]], [[0, 0]]],
            [[[3, 8]], [[0, 0]]],
        ],
        dtype=float,
    )

    return xr.DataArray(
        data,
        dims=["time", "space", "keypoints", "individuals"],
        coords={
            "time": time,
            "space": space,
            "keypoints": keypoints,
            "individuals": individuals,
        },
    )


class TestComputePolarization:
    """Patch-02 additions to the test suite for compute_polarization."""

    # ==================== Intermediate Polarization Values ====================

    def test_polarization_perpendicular_four_directions(
        self, position_data_perpendicular
    ):
        """Test polarization is 0.0 when 4 individuals move in 4 cardinal dirs."""
        polarization = kinematics.compute_polarization(
            position_data_perpendicular
        )

        # 4 perpendicular directions cancel out -> polarization = 0.0
        # Compare frames 1: to avoid dependence on boundary differencing at t=0.
        assert np.allclose(polarization.values[1:], 0.0, atol=1e-10)

    def test_polarization_partial_alignment(self, position_data_partial_alignment):
        """Test intermediate polarization with partial alignment.

        Two individuals move +x, one moves +y.
        Unit vectors: [1,0], [1,0], [0,1]
        Sum: [2, 1]
        Magnitude: sqrt(5)
        Polarization: sqrt(5)/3 ≈ 0.745
        """
        polarization = kinematics.compute_polarization(
            position_data_partial_alignment
        )

        expected = np.sqrt(5) / 3  # ≈ 0.745
        # Compare frames 1: to avoid dependence on boundary differencing at t=0.
        assert np.allclose(polarization.values[1:], expected, atol=1e-10)

    def test_polarization_diagonal_movement(self, position_data_diagonal_movement):
        """Test polarization with diagonal movement remains 1.0."""
        polarization = kinematics.compute_polarization(
            position_data_diagonal_movement
        )

        # Both moving in same diagonal direction -> polarization = 1.0
        # Compare frames 1: to avoid dependence on boundary differencing at t=0.
        assert np.allclose(polarization.values[1:], 1.0, atol=1e-10)

    # ==================== Edge Cases ====================

    def test_polarization_single_individual(self, position_data_single_individual):
        """Test polarization is 1.0 for a single individual."""
        polarization = kinematics.compute_polarization(
            position_data_single_individual
        )

        # Single individual always has polarization = 1.0
        # (the unit vector divided by 1).
        # Compare frames 1: to avoid dependence on boundary differencing at t=0.
        assert np.allclose(polarization.values[1:], 1.0, atol=1e-10)

    def test_polarization_all_nan_frame(self, position_data_all_nan_frame):
        """Test that all-NaN frames result in NaN polarization."""
        polarization = kinematics.compute_polarization(
            position_data_all_nan_frame
        )

        # Frame at index 2 has all NaN positions.
        # Velocity at frame 2 is computed from frames 1->2 and 2->3.
        # Due to NaN at frame 2, velocity at frames 2 and 3 will be affected.
        # The exact behavior depends on compute_velocity's edge handling.
        assert isinstance(polarization, xr.DataArray)
        # At minimum, verify we get a result with correct length.
        assert len(polarization) == len(position_data_all_nan_frame.time)

    def test_polarization_stationary(self, position_data_stationary):
        """Test that stationary individuals (zero velocity) produce NaN."""
        polarization = kinematics.compute_polarization(position_data_stationary)

        # Zero velocity means zero-length vector -> unit vector is NaN.
        # All frames after first should be NaN (zero displacement).
        assert np.all(np.isnan(polarization.values[1:]))

    def test_polarization_large_n(self, position_data_large_n):
        """Test polarization with many individuals (N=50) all aligned."""
        polarization = kinematics.compute_polarization(position_data_large_n)

        # All 50 individuals moving same direction -> polarization = 1.0
        # Compare frames 1: to avoid dependence on boundary differencing at t=0.
        assert np.allclose(polarization.values[1:], 1.0, atol=1e-10)

    # ==================== Data Structure Variations ====================

    def test_polarization_no_keypoints_dimension(self, position_data_no_keypoints):
        """Test polarization works without keypoints dimension."""
        polarization = kinematics.compute_polarization(position_data_no_keypoints)

        assert isinstance(polarization, xr.DataArray)
        assert polarization.name == "polarization"
        # All moving same direction -> polarization = 1.0
        assert np.allclose(polarization.values[1:], 1.0, atol=1e-10)

    def test_polarization_velocity_mode_uses_first_keypoint(
        self, position_data_multiple_keypoints
    ):
        """Test that velocity mode uses the first keypoint when multiple exist."""
        polarization = kinematics.compute_polarization(
            position_data_multiple_keypoints
        )

        # First keypoint (nose) moves in +x for both individuals,
        # so polarization should be 1.0 (not affected by tail or center).
        assert np.allclose(polarization.values[1:], 1.0, atol=1e-10)

    def test_polarization_keypoints_opposite_directions(
        self, position_data_keypoints_opposite
    ):
        """Test keypoint-based polarization with opposite facing individuals."""
        polarization = kinematics.compute_polarization(
            position_data_keypoints_opposite,
            heading_keypoints=("tail", "nose"),
        )

        # id_0 faces +x, id_1 faces -x -> polarization = 0.0
        assert np.allclose(polarization.values, 0.0, atol=1e-10)

    # ==================== Time Coordinate Handling ====================

    def test_polarization_preserves_time_coords(
        self, position_data_aligned_individuals
    ):
        """Test that output preserves time coordinates from input."""
        polarization = kinematics.compute_polarization(
            position_data_aligned_individuals
        )

        np.testing.assert_array_equal(
            polarization.time.values,
            position_data_aligned_individuals.time.values,
        )

    def test_polarization_non_uniform_time(self, position_data_non_uniform_time):
        """Test polarization with non-uniform time spacing."""
        polarization = kinematics.compute_polarization(
            position_data_non_uniform_time
        )

        # Should still work and preserve time coords.
        expected_times = [0.0, 0.5, 2.0, 5.0]
        np.testing.assert_array_equal(polarization.time.values, expected_times)
        # Both moving same direction.
        assert np.allclose(polarization.values[1:], 1.0, atol=1e-10)

    # ==================== Output Properties ====================

    def test_polarization_output_shape(self, position_data_aligned_individuals):
        """Test that output has correct shape (time only)."""
        polarization = kinematics.compute_polarization(
            position_data_aligned_individuals
        )

        assert polarization.dims == ("time",)
        assert len(polarization) == len(
            position_data_aligned_individuals.time
        )

    def test_polarization_output_no_extra_dims(
        self, position_data_aligned_individuals
    ):
        """Test that output doesn't have keypoints or space dims."""
        polarization = kinematics.compute_polarization(
            position_data_aligned_individuals
        )

        assert "keypoints" not in polarization.dims
        assert "space" not in polarization.dims
        assert "individuals" not in polarization.dims

    def test_polarization_first_frame_velocity_mode(
        self, position_data_aligned_individuals
    ):
        """Test first frame behavior when using velocity-based heading.

        The compute_velocity function uses xarray's differentiate which
        uses edge_order=1 by default, providing valid values at boundaries.
        """
        polarization = kinematics.compute_polarization(
            position_data_aligned_individuals
        )

        # First frame should have a valid value due to forward differencing.
        assert isinstance(polarization, xr.DataArray)
        assert len(polarization) == len(position_data_aligned_individuals.time)

    def test_polarization_first_frame_valid_keypoint_mode(
        self, position_data_with_keypoints
    ):
        """Test that first frame is valid when using keypoint-based heading."""
        polarization = kinematics.compute_polarization(
            position_data_with_keypoints,
            heading_keypoints=("tail", "nose"),
        )

        # First frame should be valid (keypoint positions are always known).
        assert not np.isnan(polarization.values[0])

    # ==================== Mathematical Properties ====================

    def test_polarization_symmetry(self):
        """Test that polarization is symmetric (order of individuals irrelevant)."""
        time = [0, 1, 2]
        keypoints = ["centroid"]
        space = ["x", "y"]

        # Create two datasets with same individuals in different order.
        data1 = np.array(
            [
                [[[0, 5]], [[0, 0]]],
                [[[1, 4]], [[0, 0]]],  # id_0: +x, id_1: -x
                [[[2, 3]], [[0, 0]]],
            ],
            dtype=float,
        )
        data2 = np.array(
            [
                [[[5, 0]], [[0, 0]]],
                [[[4, 1]], [[0, 0]]],  # id_0: -x, id_1: +x (swapped)
                [[[3, 2]], [[0, 0]]],
            ],
            dtype=float,
        )

        da1 = xr.DataArray(
            data1,
            dims=["time", "space", "keypoints", "individuals"],
            coords={
                "time": time,
                "space": space,
                "keypoints": keypoints,
                "individuals": ["id_0", "id_1"],
            },
        )
        da2 = xr.DataArray(
            data2,
            dims=["time", "space", "keypoints", "individuals"],
            coords={
                "time": time,
                "space": space,
                "keypoints": keypoints,
                "individuals": ["id_0", "id_1"],
            },
        )

        pol1 = kinematics.compute_polarization(da1)
        pol2 = kinematics.compute_polarization(da2)

        np.testing.assert_array_almost_equal(pol1.values, pol2.values)

    def test_polarization_bounds_random_directions(self):
        """Test polarization stays in [0, 1] with random-ish directions."""
        time = [0, 1, 2, 3, 4]
        individuals = [f"id_{i}" for i in range(10)]
        keypoints = ["centroid"]
        space = ["x", "y"]

        # Create semi-random movement patterns.
        np.random.seed(42)
        n_ind = len(individuals)
        n_time = len(time)

        # Random starting positions.
        x_start = np.random.rand(n_ind) * 100
        y_start = np.random.rand(n_ind) * 100

        # Random velocities.
        vx = np.random.randn(n_ind) * 2
        vy = np.random.randn(n_ind) * 2

        data = np.zeros((n_time, 2, 1, n_ind))
        for t in range(n_time):
            data[t, 0, 0, :] = x_start + vx * t
            data[t, 1, 0, :] = y_start + vy * t

        da = xr.DataArray(
            data,
            dims=["time", "space", "keypoints", "individuals"],
            coords={
                "time": time,
                "space": space,
                "keypoints": keypoints,
                "individuals": individuals,
            },
        )

        polarization = kinematics.compute_polarization(da)

        valid_values = polarization.values[~np.isnan(polarization.values)]
        assert np.all(valid_values >= 0.0)
        assert np.all(valid_values <= 1.0)

    # ==================== NaN Handling Edge Cases ====================

    def test_polarization_nan_one_coordinate_only(self):
        """Test handling when only one spatial coordinate is NaN."""
        time = [0, 1, 2]
        individuals = ["id_0", "id_1"]
        keypoints = ["centroid"]
        space = ["x", "y"]

        # id_1 has NaN only in x at time 1.
        data = np.array(
            [
                [[[0, 5]], [[0, 0]]],
                [[[1, np.nan]], [[0, 0]]],  # x is NaN, y is valid
                [[[2, 7]], [[0, 0]]],
            ],
            dtype=float,
        )

        da = xr.DataArray(
            data,
            dims=["time", "space", "keypoints", "individuals"],
            coords={
                "time": time,
                "space": space,
                "keypoints": keypoints,
                "individuals": individuals,
            },
        )

        polarization = kinematics.compute_polarization(da)

        # Should still compute - NaN individual is excluded.
        assert isinstance(polarization, xr.DataArray)

    def test_polarization_nan_in_keypoints(self):
        """Test handling NaN in keypoint-based heading calculation."""
        time = [0, 1, 2]
        individuals = ["id_0", "id_1"]
        keypoints = ["nose", "tail"]
        space = ["x", "y"]

        # id_1's nose is NaN at time 1.
        data = np.array(
            [
                [[[2, 5], [0, 3]], [[0, 0], [0, 0]]],
                [[[3, np.nan], [1, 4]], [[0, np.nan], [0, 0]]],
                [[[4, 7], [2, 5]], [[0, 0], [0, 0]]],
            ],
            dtype=float,
        )

        da = xr.DataArray(
            data,
            dims=["time", "space", "keypoints", "individuals"],
            coords={
                "time": time,
                "space": space,
                "keypoints": keypoints,
                "individuals": individuals,
            },
        )

        polarization = kinematics.compute_polarization(
            da, heading_keypoints=("tail", "nose")
        )

        # Time 0 and 2 should be valid (both face +x -> 1.0).
        assert np.allclose(polarization.values[0], 1.0, atol=1e-10)
        assert np.allclose(polarization.values[2], 1.0, atol=1e-10)
        # Time 1 should be 1.0 (only id_0 is valid, single individual).
        assert np.allclose(polarization.values[1], 1.0, atol=1e-10)

    # ==================== Error Handling ====================

    def test_missing_space_dimension(self):
        """Test that missing space dimension raises ValueError."""
        data = xr.DataArray(
            np.random.rand(4, 3),
            dims=["time", "individuals"],
            coords={"time": [0, 1, 2, 3], "individuals": ["a", "b", "c"]},
        )
        with pytest.raises(ValueError, match="space"):
            kinematics.compute_polarization(data)

    def test_missing_time_dimension(self):
        """Test that missing time dimension raises ValueError."""
        data = xr.DataArray(
            np.random.rand(2, 3),
            dims=["space", "individuals"],
            coords={"space": ["x", "y"], "individuals": ["a", "b", "c"]},
        )
        with pytest.raises(ValueError, match="time"):
            kinematics.compute_polarization(data)

    # NOTE(review): a further test (test_empty_dataarray) is truncated at
    # the end of the visible source and is therefore not reproduced here.
+ """ + data = xr.DataArray( + np.array([]).reshape(0, 2, 0), + dims=["time", "space", "individuals"], + coords={"time": [], "space": ["x", "y"], "individuals": []}, + ) + with pytest.raises((IndexError, ValueError)): + kinematics.compute_polarization(data) + + def test_polarization_mixed_stationary_moving(self): + """Test polarization with some stationary and some moving individuals. + + When some individuals are stationary (zero velocity), they should be + excluded from the polarization calculation (NaN unit vector). + """ + time = [0, 1, 2, 3] + individuals = ["id_0", "id_1", "id_2"] + keypoints = ["centroid"] + space = ["x", "y"] + + # id_0: stationary at (0, 0) + # id_1: moves in +x direction + # id_2: moves in +x direction + data = np.array( + [ + [[[0, 5, 10]], [[0, 0, 0]]], + [[[0, 6, 11]], [[0, 0, 0]]], + [[[0, 7, 12]], [[0, 0, 0]]], + [[[0, 8, 13]], [[0, 0, 0]]], + ], + dtype=float, + ) + + da = xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + polarization = kinematics.compute_polarization(da) + + # id_0 is stationary -> NaN heading -> excluded + # id_1 and id_2 both move +x -> polarization = 1.0 + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + + def test_polarization_two_time_points_minimum(self): + """Test polarization with minimum time points (2).""" + time = [0, 1] + individuals = ["id_0", "id_1"] + keypoints = ["centroid"] + space = ["x", "y"] + + # Both move in +x direction + data = np.array( + [ + [[[0, 5]], [[0, 0]]], + [[[1, 6]], [[0, 0]]], + ], + dtype=float, + ) + + da = xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + polarization = kinematics.compute_polarization(da) + + assert isinstance(polarization, xr.DataArray) + assert len(polarization) == 2 + 
# Both moving same direction + assert np.allclose(polarization.values, 1.0, atol=1e-10) + + # ==================== 3D Data (Limitation) ==================== + + def test_3d_data_uses_only_xy(self): + """Test that 3D spatial data only uses x,y coordinates. + + LIMITATION: The current implementation uses validate_dims_coords + with exact_coords=False, so 3D data passes validation but only + x and y coordinates are used for norm/unit vector computation. + The z coordinate is silently ignored. + + This test documents this limitation. + """ + time = [0, 1, 2] + individuals = ["id_0", "id_1"] + keypoints = ["centroid"] + space = ["x", "y", "z"] + + # 3D movement data - both individuals move in +x direction + # z movement is present but will be ignored + data = np.array( + [ + [[[0, 5]], [[0, 0]], [[0, 0]]], + [[[1, 6]], [[0, 0]], [[1, 1]]], + [[[2, 7]], [[0, 0]], [[2, 2]]], + ], + dtype=float, + ) + + da = xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + # 3D data silently produces results using only x,y + # This is a limitation - z is ignored + polarization = kinematics.compute_polarization(da) + assert isinstance(polarization, xr.DataArray) + # Both moving in same x direction -> polarization = 1.0 (ignoring z) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) From a1896bc14f2c2d14594749f28c6541843dca7ea1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 05:29:00 +0000 Subject: [PATCH 03/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../test_kinematics/test_collective.py | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py index 
470f5c79d..7f71313a8 100644 --- a/tests/test_unit/test_kinematics/test_collective.py +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -244,7 +244,7 @@ def position_data_partial_alignment(): @pytest.fixture def position_data_single_individual(): """Return position data for a single individual. - + In this synthetic dataset, polarization is 1.0 whenever a valid heading can be computed. First-frame behavior in velocity mode depends on boundary differencing. @@ -661,7 +661,9 @@ def test_polarization_perpendicular_four_directions( # Compare frames 1: to avoid dependence on boundary differencing at t=0. assert np.allclose(polarization.values[1:], 0.0, atol=1e-10) - def test_polarization_partial_alignment(self, position_data_partial_alignment): + def test_polarization_partial_alignment( + self, position_data_partial_alignment + ): """Test intermediate polarization with partial alignment. Two individuals move +x, one moves +y. @@ -678,7 +680,9 @@ def test_polarization_partial_alignment(self, position_data_partial_alignment): # Compare frames 1: to avoid dependence on boundary differencing at t=0. 
assert np.allclose(polarization.values[1:], expected, atol=1e-10) - def test_polarization_diagonal_movement(self, position_data_diagonal_movement): + def test_polarization_diagonal_movement( + self, position_data_diagonal_movement + ): """Test polarization with diagonal movement remains 1.0.""" polarization = kinematics.compute_polarization( position_data_diagonal_movement @@ -690,7 +694,9 @@ def test_polarization_diagonal_movement(self, position_data_diagonal_movement): # ==================== Edge Cases ==================== - def test_polarization_single_individual(self, position_data_single_individual): + def test_polarization_single_individual( + self, position_data_single_individual + ): """Test polarization is 1.0 for a single individual.""" polarization = kinematics.compute_polarization( position_data_single_individual @@ -717,7 +723,9 @@ def test_polarization_all_nan_frame(self, position_data_all_nan_frame): def test_polarization_stationary(self, position_data_stationary): """Test that stationary individuals (zero velocity) produce NaN.""" - polarization = kinematics.compute_polarization(position_data_stationary) + polarization = kinematics.compute_polarization( + position_data_stationary + ) # Zero velocity means zero-length vector -> unit vector is NaN # All frames after first should be NaN (zero displacement) @@ -733,9 +741,13 @@ def test_polarization_large_n(self, position_data_large_n): # ==================== Data Structure Variations ==================== - def test_polarization_no_keypoints_dimension(self, position_data_no_keypoints): + def test_polarization_no_keypoints_dimension( + self, position_data_no_keypoints + ): """Test polarization works without keypoints dimension.""" - polarization = kinematics.compute_polarization(position_data_no_keypoints) + polarization = kinematics.compute_polarization( + position_data_no_keypoints + ) assert isinstance(polarization, xr.DataArray) assert polarization.name == "polarization" @@ -781,7 +793,9 @@ def 
test_polarization_preserves_time_coords( position_data_aligned_individuals.time.values, ) - def test_polarization_non_uniform_time(self, position_data_non_uniform_time): + def test_polarization_non_uniform_time( + self, position_data_non_uniform_time + ): """Test polarization with non-uniform time spacing.""" polarization = kinematics.compute_polarization( position_data_non_uniform_time @@ -795,16 +809,16 @@ def test_polarization_non_uniform_time(self, position_data_non_uniform_time): # ==================== Output Properties ==================== - def test_polarization_output_shape(self, position_data_aligned_individuals): + def test_polarization_output_shape( + self, position_data_aligned_individuals + ): """Test that output has correct shape (time only).""" polarization = kinematics.compute_polarization( position_data_aligned_individuals ) assert polarization.dims == ("time",) - assert len(polarization) == len( - position_data_aligned_individuals.time - ) + assert len(polarization) == len(position_data_aligned_individuals.time) def test_polarization_output_no_extra_dims( self, position_data_aligned_individuals From 00872086df9ece7cd57677ff27dd33fa9d4b874d Mon Sep 17 00:00:00 2001 From: khan-u Date: Mon, 16 Mar 2026 23:08:08 -0700 Subject: [PATCH 04/21] linting fix --- .../test_kinematics/test_collective.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py index 7f71313a8..dde085ec5 100644 --- a/tests/test_unit/test_kinematics/test_collective.py +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -163,7 +163,7 @@ def position_data_with_nan(): @pytest.fixture def position_data_perpendicular(): - """Return position data for 4 individuals moving in perpendicular directions. + """Return position data for 4 individuals in perpendicular dirs. Each individual moves in one of the 4 cardinal directions (+x, -x, +y, -y). 
The sum of unit vectors is zero, so polarization should be 0.0. @@ -310,7 +310,7 @@ def position_data_all_nan_frame(): @pytest.fixture def position_data_stationary(): - """Return position data where individuals are stationary (zero velocity).""" + """Return position data where individuals are stationary.""" time = [0, 1, 2, 3] individuals = ["id_0", "id_1"] keypoints = ["centroid"] @@ -647,18 +647,18 @@ def test_identical_keypoints(self, position_data_with_keypoints): heading_keypoints=("nose", "nose"), ) - # ==================== Intermediate Polarization Values ==================== + # ================= Intermediate Polarization Values ================= def test_polarization_perpendicular_four_directions( self, position_data_perpendicular ): - """Test polarization is 0.0 when 4 individuals move in 4 cardinal dirs.""" + """Test polarization is 0 when 4 individuals move in cardinal dirs.""" polarization = kinematics.compute_polarization( position_data_perpendicular ) # 4 perpendicular directions cancel out -> polarization = 0.0 - # Compare frames 1: to avoid dependence on boundary differencing at t=0. + # Compare frames 1: avoid boundary differencing dependence at t=0. assert np.allclose(polarization.values[1:], 0.0, atol=1e-10) def test_polarization_partial_alignment( @@ -677,7 +677,7 @@ def test_polarization_partial_alignment( ) expected = np.sqrt(5) / 3 # ≈ 0.745 - # Compare frames 1: to avoid dependence on boundary differencing at t=0. + # Compare frames 1: avoid boundary differencing dependence at t=0. assert np.allclose(polarization.values[1:], expected, atol=1e-10) def test_polarization_diagonal_movement( @@ -689,7 +689,7 @@ def test_polarization_diagonal_movement( ) # Both moving in same diagonal direction -> polarization = 1.0 - # Compare frames 1: to avoid dependence on boundary differencing at t=0. + # Compare frames 1: avoid boundary differencing dependence at t=0. 
assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) # ==================== Edge Cases ==================== @@ -704,7 +704,7 @@ def test_polarization_single_individual( # Single individual always has polarization = 1.0 # (the unit vector divided by 1) - # Compare frames 1: to avoid dependence on boundary differencing at t=0. + # Compare frames 1: avoid boundary differencing dependence at t=0. assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) def test_polarization_all_nan_frame(self, position_data_all_nan_frame): @@ -736,7 +736,7 @@ def test_polarization_large_n(self, position_data_large_n): polarization = kinematics.compute_polarization(position_data_large_n) # All 50 individuals moving same direction -> polarization = 1.0 - # Compare frames 1: to avoid dependence on boundary differencing at t=0. + # Compare frames 1: avoid boundary differencing dependence at t=0. assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) # ==================== Data Structure Variations ==================== @@ -757,7 +757,7 @@ def test_polarization_no_keypoints_dimension( def test_polarization_velocity_mode_uses_first_keypoint( self, position_data_multiple_keypoints ): - """Test that velocity mode uses the first keypoint when multiple exist.""" + """Test velocity mode uses first keypoint when multiple exist.""" polarization = kinematics.compute_polarization( position_data_multiple_keypoints ) @@ -769,7 +769,7 @@ def test_polarization_velocity_mode_uses_first_keypoint( def test_polarization_keypoints_opposite_directions( self, position_data_keypoints_opposite ): - """Test keypoint-based polarization with opposite facing individuals.""" + """Test keypoint polarization with opposite facing individuals.""" polarization = kinematics.compute_polarization( position_data_keypoints_opposite, heading_keypoints=("tail", "nose"), @@ -863,7 +863,7 @@ def test_polarization_first_frame_valid_keypoint_mode( # ==================== Mathematical Properties ==================== 
def test_polarization_symmetry(self): - """Test that polarization is symmetric (order of individuals irrelevant).""" + """Test polarization is symmetric (individual order irrelevant).""" time = [0, 1, 2] keypoints = ["centroid"] space = ["x", "y"] From 0af9dc77d1b34c20c23311c5fa7ac04704823b06 Mon Sep 17 00:00:00 2001 From: khan-u Date: Tue, 17 Mar 2026 18:49:10 -0700 Subject: [PATCH 05/21] test(collective): consolidate related polarization tests to reduce redundancy --- .../test_kinematics/test_collective.py | 92 ++++++++++--------- 1 file changed, 51 insertions(+), 41 deletions(-) diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py index dde085ec5..cb01ec962 100644 --- a/tests/test_unit/test_kinematics/test_collective.py +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -553,8 +553,18 @@ def position_data_non_uniform_time(): class TestComputePolarization: """Test suite for the compute_polarization function.""" - def test_polarization_aligned(self, position_data_aligned_individuals): - """Test polarization is 1.0 when all move same direction.""" + def test_polarization_aligned( + self, + position_data_aligned_individuals, + position_data_diagonal_movement, + ): + """Test polarization is 1.0 when all move same direction. + + Tests both horizontal and diagonal movement to verify that + polarization is rotation-invariant (direction angle doesn't matter, + only alignment between individuals). 
+ """ + # Test horizontal alignment polarization = kinematics.compute_polarization( position_data_aligned_individuals ) @@ -569,6 +579,13 @@ def test_polarization_aligned(self, position_data_aligned_individuals): # (Skip first time point since velocity is computed via diff) assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + # Test diagonal alignment (rotation invariance) + # Both individuals moving at 45 degrees should also yield pol=1.0 + polarization_diag = kinematics.compute_polarization( + position_data_diagonal_movement + ) + assert np.allclose(polarization_diag.values[1:], 1.0, atol=1e-10) + def test_polarization_opposite(self, position_data_opposite_individuals): """Test polarization is 0.0 when individuals move opposite.""" polarization = kinematics.compute_polarization( @@ -604,17 +621,6 @@ def test_polarization_handles_nan(self, position_data_with_nan): # The frame with NaN should exclude that individual from calculation assert not np.all(np.isnan(polarization.values)) - def test_polarization_range(self, position_data_aligned_individuals): - """Test that polarization values are in [0, 1] range.""" - polarization = kinematics.compute_polarization( - position_data_aligned_individuals - ) - - # Exclude NaN values from range check - valid_values = polarization.values[~np.isnan(polarization.values)] - assert np.all(valid_values >= 0.0) - assert np.all(valid_values <= 1.0) - def test_invalid_input_type(self, position_data_aligned_individuals): """Test that non-DataArray input raises TypeError.""" with pytest.raises(TypeError, match="must be an xarray.DataArray"): @@ -680,18 +686,6 @@ def test_polarization_partial_alignment( # Compare frames 1: avoid boundary differencing dependence at t=0. 
assert np.allclose(polarization.values[1:], expected, atol=1e-10) - def test_polarization_diagonal_movement( - self, position_data_diagonal_movement - ): - """Test polarization with diagonal movement remains 1.0.""" - polarization = kinematics.compute_polarization( - position_data_diagonal_movement - ) - - # Both moving in same diagonal direction -> polarization = 1.0 - # Compare frames 1: avoid boundary differencing dependence at t=0. - assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) - # ==================== Edge Cases ==================== def test_polarization_single_individual( @@ -809,25 +803,23 @@ def test_polarization_non_uniform_time( # ==================== Output Properties ==================== - def test_polarization_output_shape( + def test_polarization_output_structure( self, position_data_aligned_individuals ): - """Test that output has correct shape (time only).""" + """Test that output has correct structure (time dimension only). + + Verifies both positive assertion (dims == time) and explicit + absence of input dimensions that should be reduced over. 
+ """ polarization = kinematics.compute_polarization( position_data_aligned_individuals ) + # Positive assertion: output has exactly time dimension assert polarization.dims == ("time",) assert len(polarization) == len(position_data_aligned_individuals.time) - def test_polarization_output_no_extra_dims( - self, position_data_aligned_individuals - ): - """Test that output doesn't have keypoints or space dims.""" - polarization = kinematics.compute_polarization( - position_data_aligned_individuals - ) - + # Explicit absence checks (documents which dims are reduced) assert "keypoints" not in polarization.dims assert "space" not in polarization.dims assert "individuals" not in polarization.dims @@ -912,8 +904,24 @@ def test_polarization_symmetry(self): np.testing.assert_array_almost_equal(pol1.values, pol2.values) - def test_polarization_bounds_random_directions(self): - """Test polarization stays in [0, 1] with random-ish directions.""" + def test_polarization_bounds(self, position_data_aligned_individuals): + """Test polarization values are always in [0, 1] range. + + Verifies bounds with both: + 1. Simple aligned data (deterministic, yields boundary value 1.0) + 2. 
Random directions (stochastic, yields distribution across range) + """ + # Test with simple aligned data (boundary case: all 1.0) + polarization_simple = kinematics.compute_polarization( + position_data_aligned_individuals + ) + valid_simple = polarization_simple.values[ + ~np.isnan(polarization_simple.values) + ] + assert np.all(valid_simple >= 0.0) + assert np.all(valid_simple <= 1.0) + + # Test with random directions (interior values) time = [0, 1, 2, 3, 4] individuals = [f"id_{i}" for i in range(10)] keypoints = ["centroid"] @@ -948,11 +956,13 @@ def test_polarization_bounds_random_directions(self): }, ) - polarization = kinematics.compute_polarization(da) + polarization_random = kinematics.compute_polarization(da) - valid_values = polarization.values[~np.isnan(polarization.values)] - assert np.all(valid_values >= 0.0) - assert np.all(valid_values <= 1.0) + valid_random = polarization_random.values[ + ~np.isnan(polarization_random.values) + ] + assert np.all(valid_random >= 0.0) + assert np.all(valid_random <= 1.0) # ==================== NaN Handling Edge Cases ==================== From 758ea0420764fa2978e62f89a9dfceb22074b2c9 Mon Sep 17 00:00:00 2001 From: khan-u Date: Tue, 17 Mar 2026 21:50:28 -0700 Subject: [PATCH 06/21] feat(kinematics): add displacement_frames and return_angle params to compute_polarization --- movement/kinematics/collective.py | 89 ++- .../test_kinematics/test_collective.py | 565 +++++++++++++++++- 2 files changed, 620 insertions(+), 34 deletions(-) diff --git a/movement/kinematics/collective.py b/movement/kinematics/collective.py index 01f70ed7b..db22ff5b1 100644 --- a/movement/kinematics/collective.py +++ b/movement/kinematics/collective.py @@ -5,7 +5,6 @@ import numpy as np import xarray as xr -from movement.kinematics.kinematics import compute_velocity from movement.utils.logging import logger from movement.utils.vector import compute_norm, convert_to_unit from movement.validators.arrays import validate_dims_coords @@ -14,7 +13,9 @@ 
def compute_polarization( data: xr.DataArray, heading_keypoints: tuple[Hashable, Hashable] | None = None, -) -> xr.DataArray: + displacement_frames: int = 1, + return_angle: bool = False, +) -> xr.DataArray | tuple[xr.DataArray, xr.DataArray]: r"""Compute the polarization (group alignment) of multiple individuals. Polarization measures how aligned the heading directions of individuals @@ -38,22 +39,36 @@ def compute_polarization( A tuple of two keypoint names ``(origin, target)`` used to compute the heading direction as the vector from origin to target (e.g., ``("neck", "nose")`` or ``("tail", "head")``). - If None, heading is inferred from the velocity of the first + If None, heading is inferred from the displacement of the first available keypoint. + displacement_frames : int, optional + Number of frames over which to compute displacement when + ``heading_keypoints`` is None. Default is 1 (frame-to-frame + displacement). Use higher values for smoother heading estimates + (e.g., fps value for 1-second displacement). This parameter is + ignored when ``heading_keypoints`` is provided. + return_angle : bool, optional + If True, also return the mean heading angle (in radians) of the + group at each time point. Default is False. Returns ------- - xarray.DataArray - An xarray DataArray containing the polarization value at each - time point, with dimensions ``(time,)``. Values range from 0 - (random orientations) to 1 (perfectly aligned). + xarray.DataArray or tuple of xarray.DataArray + If ``return_angle`` is False (default), returns an xarray DataArray + containing the polarization value at each time point, with + dimensions ``(time,)``. Values range from 0 (random orientations) + to 1 (perfectly aligned). + + If ``return_angle`` is True, returns a tuple of two DataArrays: + ``(polarization, mean_angle)``, where ``mean_angle`` contains the + mean heading direction in radians at each time point. 
Notes ----- If ``heading_keypoints`` is provided, the heading for each individual is computed as the unit vector from the origin to the target - keypoint. If not provided, heading is inferred from the instantaneous - velocity direction. + keypoint. If not provided, heading is inferred from the displacement + direction over ``displacement_frames`` frames. Frames where an individual has missing data (NaN) are handled by excluding that individual from the polarization calculation for that @@ -68,13 +83,23 @@ def compute_polarization( ... heading_keypoints=("neck", "nose"), ... ) - Compute polarization using velocity-inferred heading: + Compute polarization using displacement-inferred heading: >>> polarization = compute_polarization(ds.position) - See Also - -------- - movement.kinematics.compute_velocity : Compute velocity from position. + Compute polarization with 1-second displacement (at 30 fps): + + >>> polarization = compute_polarization( + ... ds.position, + ... displacement_frames=30, + ... ) + + Also return the mean heading angle: + + >>> polarization, mean_angle = compute_polarization( + ... ds.position, + ... return_angle=True, + ... 
) """ # Validate input data @@ -94,7 +119,9 @@ def compute_polarization( data, heading_keypoints ) else: - heading_vectors = _compute_heading_from_velocity(data) + heading_vectors = _compute_heading_from_velocity( + data, displacement_frames=displacement_frames + ) # Convert to unit vectors unit_headings = convert_to_unit(heading_vectors) @@ -116,6 +143,18 @@ def compute_polarization( polarization = xr.where(n_valid > 0, sum_magnitude / n_valid, np.nan) polarization.name = "polarization" + + if return_angle: + # Compute mean heading angle from the vector sum + # arctan2(y, x) gives angle in radians + mean_angle = np.arctan2( + vector_sum.sel(space="y"), + vector_sum.sel(space="x"), + ) + mean_angle = xr.where(n_valid > 0, mean_angle, np.nan) + mean_angle.name = "mean_angle" + return polarization, mean_angle + return polarization @@ -161,8 +200,11 @@ def _compute_heading_from_keypoints( return heading -def _compute_heading_from_velocity(data: xr.DataArray) -> xr.DataArray: - """Compute heading vectors from velocity (displacement direction). +def _compute_heading_from_velocity( + data: xr.DataArray, + displacement_frames: int = 1, +) -> xr.DataArray: + """Compute heading vectors from displacement direction. Uses the first available keypoint if multiple are present. @@ -170,11 +212,15 @@ def _compute_heading_from_velocity(data: xr.DataArray) -> xr.DataArray: ---------- data : xarray.DataArray Position data with ``time`` dimension. + displacement_frames : int, optional + Number of frames over which to compute displacement. Default is 1 + (frame-to-frame displacement). Use higher values for smoother + heading estimates (e.g., fps for 1-second displacement). Returns ------- xarray.DataArray - Heading vectors based on velocity direction. + Heading vectors based on displacement direction. 
""" # If keypoints dimension exists, use first keypoint @@ -182,15 +228,16 @@ def _compute_heading_from_velocity(data: xr.DataArray) -> xr.DataArray: first_keypoint = data.keypoints.values[0] position = data.sel(keypoints=first_keypoint, drop=True) logger.info( - f"Using keypoint '{first_keypoint}' for velocity-based heading." + f"Using keypoint '{first_keypoint}' for displacement-based heading." ) else: position = data - # Compute velocity as heading direction - velocity = compute_velocity(position) + # Compute displacement over N frames + # displacement[t] = position[t] - position[t - displacement_frames] + displacement = position - position.shift(time=displacement_frames) - return velocity + return displacement def _validate_type_data_array(data: xr.DataArray) -> None: diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py index cb01ec962..2a54e0295 100644 --- a/tests/test_unit/test_kinematics/test_collective.py +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -824,21 +824,25 @@ def test_polarization_output_structure( assert "space" not in polarization.dims assert "individuals" not in polarization.dims - def test_polarization_first_frame_velocity_mode( + def test_polarization_first_frame_displacement_mode( self, position_data_aligned_individuals ): - """Test first frame behavior when using velocity-based heading. + """Test first frame behavior when using displacement-based heading. - The compute_velocity function uses xarray's differentiate which - uses edge_order=1 by default, providing valid values at boundaries. + Displacement-based heading requires position at frame t and t-N + (where N=displacement_frames). Frame 0 has no prior reference, + so it is NaN. This is consistent and predictable behavior. 
""" polarization = kinematics.compute_polarization( position_data_aligned_individuals ) - # First frame should have a valid value due to forward differencing assert isinstance(polarization, xr.DataArray) assert len(polarization) == len(position_data_aligned_individuals.time) + # First frame is NaN (no t-1 reference for displacement) + assert np.isnan(polarization.values[0]) + # Subsequent frames are valid + assert not np.any(np.isnan(polarization.values[1:])) def test_polarization_first_frame_valid_keypoint_mode( self, position_data_with_keypoints @@ -1060,18 +1064,20 @@ def test_missing_time_dimension(self): kinematics.compute_polarization(data) def test_empty_dataarray(self): - """Test handling of empty DataArray raises an error. + """Test handling of empty DataArray returns empty result. - Empty arrays cause issues in numpy's gradient computation - used by compute_velocity. + Empty arrays are handled gracefully by the displacement-based + implementation, returning an empty polarization array. """ data = xr.DataArray( np.array([]).reshape(0, 2, 0), dims=["time", "space", "individuals"], coords={"time": [], "space": ["x", "y"], "individuals": []}, ) - with pytest.raises((IndexError, ValueError)): - kinematics.compute_polarization(data) + polarization = kinematics.compute_polarization(data) + + assert isinstance(polarization, xr.DataArray) + assert len(polarization) == 0 def test_polarization_mixed_stationary_moving(self): """Test polarization with some stationary and some moving individuals. @@ -1115,7 +1121,11 @@ def test_polarization_mixed_stationary_moving(self): assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) def test_polarization_two_time_points_minimum(self): - """Test polarization with minimum time points (2).""" + """Test polarization with minimum time points (2). + + With displacement-based heading, frame 0 is NaN (no prior reference). + Frame 1 has valid displacement from frame 0 to frame 1. 
+ """ time = [0, 1] individuals = ["id_0", "id_1"] keypoints = ["centroid"] @@ -1145,8 +1155,10 @@ def test_polarization_two_time_points_minimum(self): assert isinstance(polarization, xr.DataArray) assert len(polarization) == 2 - # Both moving same direction - assert np.allclose(polarization.values, 1.0, atol=1e-10) + # Frame 0 is NaN (no prior reference for displacement) + assert np.isnan(polarization.values[0]) + # Frame 1 is valid: both moving same direction -> polarization = 1.0 + assert np.allclose(polarization.values[1], 1.0, atol=1e-10) # ==================== 3D Data (Limitation) ==================== @@ -1193,3 +1205,530 @@ def test_3d_data_uses_only_xy(self): assert isinstance(polarization, xr.DataArray) # Both moving in same x direction -> polarization = 1.0 (ignoring z) assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + + # ==================== displacement_frames Parameter ==================== + + def test_displacement_frames_default( + self, position_data_aligned_individuals + ): + """Test displacement_frames=1 (default) matches original behavior.""" + # Explicit default should match implicit default + pol_implicit = kinematics.compute_polarization( + position_data_aligned_individuals + ) + pol_explicit = kinematics.compute_polarization( + position_data_aligned_individuals, + displacement_frames=1, + ) + + np.testing.assert_array_equal(pol_implicit.values, pol_explicit.values) + + def test_displacement_frames_multi_frame(self): + """Test multi-frame displacement computes heading over N frames. + + With displacement_frames=2, heading is computed from position[t-2] + to position[t], giving smoother estimates over longer intervals. 
+ """ + time = [0, 1, 2, 3, 4, 5] + individuals = ["id_0", "id_1"] + keypoints = ["centroid"] + space = ["x", "y"] + + # Both move in +x direction consistently + data = np.array( + [ + [[[0, 10]], [[0, 0]]], + [[[1, 11]], [[0, 0]]], + [[[2, 12]], [[0, 0]]], + [[[3, 13]], [[0, 0]]], + [[[4, 14]], [[0, 0]]], + [[[5, 15]], [[0, 0]]], + ], + dtype=float, + ) + + da = xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + polarization = kinematics.compute_polarization( + da, displacement_frames=2 + ) + + # First 2 frames are NaN (no t-2 reference) + assert np.isnan(polarization.values[0]) + assert np.isnan(polarization.values[1]) + # Frames 2+ should have valid polarization = 1.0 + assert np.allclose(polarization.values[2:], 1.0, atol=1e-10) + + def test_displacement_frames_larger_than_timeseries(self): + """Test displacement_frames larger than available frames yields NaN.""" + time = [0, 1, 2] + individuals = ["id_0", "id_1"] + keypoints = ["centroid"] + space = ["x", "y"] + + data = np.array( + [ + [[[0, 5]], [[0, 0]]], + [[[1, 6]], [[0, 0]]], + [[[2, 7]], [[0, 0]]], + ], + dtype=float, + ) + + da = xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + # displacement_frames=10 but only 3 time points + polarization = kinematics.compute_polarization( + da, displacement_frames=10 + ) + + # All frames should be NaN (no valid t-10 reference) + assert np.all(np.isnan(polarization.values)) + + def test_displacement_frames_ignored_with_keypoints( + self, position_data_with_keypoints + ): + """Test displacement_frames is ignored when heading_keypoints given.""" + pol_default = kinematics.compute_polarization( + position_data_with_keypoints, + heading_keypoints=("tail", "nose"), + displacement_frames=1, + ) + 
pol_large = kinematics.compute_polarization( + position_data_with_keypoints, + heading_keypoints=("tail", "nose"), + displacement_frames=100, # Should be ignored + ) + + # Both should be identical since keypoints override displacement + np.testing.assert_array_equal(pol_default.values, pol_large.values) + # All frames valid (keypoint heading doesn't need displacement) + assert np.allclose(pol_default.values, 1.0, atol=1e-10) + + def test_displacement_frames_with_nan_at_reference(self): + """Test NaN at reference frame (t - displacement_frames) propagates.""" + time = [0, 1, 2, 3, 4] + individuals = ["id_0", "id_1"] + keypoints = ["centroid"] + space = ["x", "y"] + + # NaN at frame 1 + data = np.array( + [ + [[[0, 5]], [[0, 0]]], + [[[np.nan, np.nan]], [[np.nan, np.nan]]], # NaN frame + [[[2, 7]], [[0, 0]]], + [[[3, 8]], [[0, 0]]], + [[[4, 9]], [[0, 0]]], + ], + dtype=float, + ) + + da = xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + polarization = kinematics.compute_polarization( + da, displacement_frames=2 + ) + + # Frame 3 references frame 1 (NaN) -> should be NaN + assert np.isnan(polarization.values[3]) + # Frame 4 references frame 2 (valid) -> should be valid + assert not np.isnan(polarization.values[4]) + + def test_displacement_frames_smooths_noisy_trajectory(self): + """Test larger displacement_frames smooths jittery movement.""" + time = list(range(10)) + individuals = ["id_0", "id_1"] + keypoints = ["centroid"] + space = ["x", "y"] + + # Noisy +x movement with small y jitter + np.random.seed(42) + jitter = np.random.randn(10, 2) * 0.1 + + data = np.zeros((10, 2, 1, 2)) + for t in range(10): + # Both individuals move +x with small y noise + data[t, 0, 0, :] = [t + jitter[t, 0], t + jitter[t, 1]] + data[t, 1, 0, :] = [jitter[t, 0], jitter[t, 1]] + + da = xr.DataArray( + data, + dims=["time", "space", "keypoints", 
"individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + pol_1frame = kinematics.compute_polarization(da, displacement_frames=1) + pol_5frame = kinematics.compute_polarization(da, displacement_frames=5) + + # 5-frame displacement should yield higher polarization (less noise) + valid_1 = pol_1frame.values[~np.isnan(pol_1frame.values)] + valid_5 = pol_5frame.values[~np.isnan(pol_5frame.values)] + + # Mean polarization should be higher with smoothing + assert np.mean(valid_5) >= np.mean(valid_1) - 0.1 # Allow small margin + + # ==================== return_angle Parameter ==================== + + def test_return_angle_false_returns_dataarray( + self, position_data_aligned_individuals + ): + """Test return_angle=False returns single DataArray (default).""" + result = kinematics.compute_polarization( + position_data_aligned_individuals, + return_angle=False, + ) + + assert isinstance(result, xr.DataArray) + assert result.name == "polarization" + + def test_return_angle_true_returns_tuple( + self, position_data_aligned_individuals + ): + """Test return_angle=True returns tuple of (polarization, angle).""" + result = kinematics.compute_polarization( + position_data_aligned_individuals, + return_angle=True, + ) + + assert isinstance(result, tuple) + assert len(result) == 2 + + polarization, mean_angle = result + assert isinstance(polarization, xr.DataArray) + assert isinstance(mean_angle, xr.DataArray) + assert polarization.name == "polarization" + assert mean_angle.name == "mean_angle" + + def test_return_angle_positive_x_direction(self): + """Test mean_angle is 0 radians for +x movement.""" + time = [0, 1, 2, 3] + individuals = ["id_0", "id_1"] + keypoints = ["centroid"] + space = ["x", "y"] + + # Both move in +x direction + data = np.array( + [ + [[[0, 5]], [[0, 0]]], + [[[1, 6]], [[0, 0]]], + [[[2, 7]], [[0, 0]]], + [[[3, 8]], [[0, 0]]], + ], + dtype=float, + ) + + da = xr.DataArray( + data, + 
dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + _, mean_angle = kinematics.compute_polarization(da, return_angle=True) + + # +x direction = 0 radians + assert np.allclose(mean_angle.values[1:], 0.0, atol=1e-10) + + def test_return_angle_positive_y_direction(self): + """Test mean_angle is π/2 radians for +y movement.""" + time = [0, 1, 2, 3] + individuals = ["id_0", "id_1"] + keypoints = ["centroid"] + space = ["x", "y"] + + # Both move in +y direction + data = np.array( + [ + [[[0, 0]], [[0, 5]]], + [[[0, 0]], [[1, 6]]], + [[[0, 0]], [[2, 7]]], + [[[0, 0]], [[3, 8]]], + ], + dtype=float, + ) + + da = xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + _, mean_angle = kinematics.compute_polarization(da, return_angle=True) + + # +y direction = π/2 radians + assert np.allclose(mean_angle.values[1:], np.pi / 2, atol=1e-10) + + def test_return_angle_negative_x_direction(self): + """Test mean_angle is ±π radians for -x movement.""" + time = [0, 1, 2, 3] + individuals = ["id_0", "id_1"] + keypoints = ["centroid"] + space = ["x", "y"] + + # Both move in -x direction + data = np.array( + [ + [[[10, 15]], [[0, 0]]], + [[[9, 14]], [[0, 0]]], + [[[8, 13]], [[0, 0]]], + [[[7, 12]], [[0, 0]]], + ], + dtype=float, + ) + + da = xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + _, mean_angle = kinematics.compute_polarization(da, return_angle=True) + + # -x direction = ±π radians (arctan2 returns π or -π) + assert np.allclose(np.abs(mean_angle.values[1:]), np.pi, atol=1e-10) + + def test_return_angle_diagonal_45_degrees(self): + """Test mean_angle is π/4 radians for diagonal (+x, +y) 
movement.""" + time = [0, 1, 2, 3] + individuals = ["id_0", "id_1"] + keypoints = ["centroid"] + space = ["x", "y"] + + # Both move diagonally (+x, +y) at 45 degrees + data = np.array( + [ + [[[0, 5]], [[0, 5]]], + [[[1, 6]], [[1, 6]]], + [[[2, 7]], [[2, 7]]], + [[[3, 8]], [[3, 8]]], + ], + dtype=float, + ) + + da = xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + _, mean_angle = kinematics.compute_polarization(da, return_angle=True) + + # 45 degrees = π/4 radians + assert np.allclose(mean_angle.values[1:], np.pi / 4, atol=1e-10) + + def test_return_angle_nan_when_no_valid_individuals( + self, position_data_all_nan_frame + ): + """Test mean_angle is NaN when no valid individuals.""" + _, mean_angle = kinematics.compute_polarization( + position_data_all_nan_frame, + return_angle=True, + ) + + # Frame 2 has all NaN -> angle should be NaN for affected frames + assert isinstance(mean_angle, xr.DataArray) + # At least some frames should be NaN due to the all-NaN frame + assert np.any(np.isnan(mean_angle.values)) + + def test_return_angle_with_keypoint_heading( + self, position_data_with_keypoints + ): + """Test return_angle works with keypoint-based heading.""" + pol, angle = kinematics.compute_polarization( + position_data_with_keypoints, + heading_keypoints=("tail", "nose"), + return_angle=True, + ) + + # Both face +x direction (nose ahead of tail in x) + assert np.allclose(pol.values, 1.0, atol=1e-10) + assert np.allclose(angle.values, 0.0, atol=1e-10) + + def test_return_angle_partial_alignment_mean_direction(self): + """Test mean_angle reflects weighted mean of individual headings. + + Two individuals move +x (angle=0), one moves +y (angle=π/2). + Mean unit vector: [2, 1] / 3, but angle is arctan2(1, 2) ≈ 0.464 rad. 
+ """ + time = [0, 1, 2, 3] + individuals = ["id_0", "id_1", "id_2"] + keypoints = ["centroid"] + space = ["x", "y"] + + # id_0 and id_1 move +x, id_2 moves +y + data = np.array( + [ + [[[0, 5, 0]], [[0, 0, 0]]], + [[[1, 6, 0]], [[0, 0, 1]]], + [[[2, 7, 0]], [[0, 0, 2]]], + [[[3, 8, 0]], [[0, 0, 3]]], + ], + dtype=float, + ) + + da = xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + pol, angle = kinematics.compute_polarization(da, return_angle=True) + + # Sum of unit vectors: [1,0] + [1,0] + [0,1] = [2, 1] + # Mean angle: arctan2(1, 2) ≈ 0.4636 rad ≈ 26.57 degrees + expected_angle = np.arctan2(1, 2) + assert np.allclose(angle.values[1:], expected_angle, atol=1e-10) + + def test_return_angle_opposite_directions_undefined(self): + """Test mean_angle when individuals move in opposite directions. + + When vectors cancel out (polarization ≈ 0), angle is still computed + from the sum vector but may be arbitrary for perfectly opposed motion. 
+ """ + time = [0, 1, 2, 3] + individuals = ["id_0", "id_1"] + keypoints = ["centroid"] + space = ["x", "y"] + + # id_0 moves +x, id_1 moves -x (exactly opposite) + data = np.array( + [ + [[[0, 10]], [[0, 0]]], + [[[1, 9]], [[0, 0]]], + [[[2, 8]], [[0, 0]]], + [[[3, 7]], [[0, 0]]], + ], + dtype=float, + ) + + da = xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + pol, angle = kinematics.compute_polarization(da, return_angle=True) + + # Polarization should be 0 (opposite directions cancel) + assert np.allclose(pol.values[1:], 0.0, atol=1e-10) + # Angle is computed from [0, 0] sum vector + # arctan2(0, 0) = 0 in numpy + assert np.allclose(angle.values[1:], 0.0, atol=1e-10) + + def test_return_angle_with_displacement_frames(self): + """Test return_angle works correctly with multi-frame displacement.""" + time = [0, 1, 2, 3, 4, 5] + individuals = ["id_0", "id_1"] + keypoints = ["centroid"] + space = ["x", "y"] + + # Both move in +y direction + data = np.array( + [ + [[[0, 0]], [[0, 5]]], + [[[0, 0]], [[1, 6]]], + [[[0, 0]], [[2, 7]]], + [[[0, 0]], [[3, 8]]], + [[[0, 0]], [[4, 9]]], + [[[0, 0]], [[5, 10]]], + ], + dtype=float, + ) + + da = xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time, + "space": space, + "keypoints": keypoints, + "individuals": individuals, + }, + ) + + pol, angle = kinematics.compute_polarization( + da, + displacement_frames=2, + return_angle=True, + ) + + # First 2 frames are NaN + assert np.isnan(angle.values[0]) + assert np.isnan(angle.values[1]) + # Frames 2+ should be +y direction = π/2 + assert np.allclose(angle.values[2:], np.pi / 2, atol=1e-10) + + def test_return_angle_dimensions_match_polarization( + self, position_data_aligned_individuals + ): + """Test mean_angle has same dimensions and coords as polarization.""" + pol, angle = 
kinematics.compute_polarization( + position_data_aligned_individuals, + return_angle=True, + ) + + assert pol.dims == angle.dims + assert len(pol) == len(angle) + np.testing.assert_array_equal(pol.time.values, angle.time.values) From e271a8c5f72226d882a33f21b3e745f3cbc5a4e9 Mon Sep 17 00:00:00 2001 From: khan-u Date: Wed, 18 Mar 2026 00:03:08 -0700 Subject: [PATCH 07/21] refactor(collective): Rewrite compute_polarization with robust validation, edge case handling, and simplified tests --- movement/kinematics/collective.py | 349 +-- .../test_kinematics/test_collective.py | 2104 +++++------------ 2 files changed, 768 insertions(+), 1685 deletions(-) diff --git a/movement/kinematics/collective.py b/movement/kinematics/collective.py index db22ff5b1..bb3258a92 100644 --- a/movement/kinematics/collective.py +++ b/movement/kinematics/collective.py @@ -1,14 +1,18 @@ +# collective.py """Compute collective behavior metrics for multi-individual tracking data.""" from collections.abc import Hashable +from typing import Any import numpy as np import xarray as xr from movement.utils.logging import logger -from movement.utils.vector import compute_norm, convert_to_unit +from movement.utils.vector import compute_norm from movement.validators.arrays import validate_dims_coords +_ANGLE_EPS = 1e-12 + def compute_polarization( data: xr.DataArray, @@ -16,85 +20,78 @@ def compute_polarization( displacement_frames: int = 1, return_angle: bool = False, ) -> xr.DataArray | tuple[xr.DataArray, xr.DataArray]: - r"""Compute the polarization (group alignment) of multiple individuals. + r"""Compute polarization (group alignment) of individuals. Polarization measures how aligned the heading directions of individuals - are. A value of 1 indicates all individuals are heading in the same - direction, while a value near 0 indicates random orientations. + are. A value of 1 indicates perfect alignment, while a value near 0 + indicates weak or canceling alignment. 
+ + The polarization is computed as - The polarization is computed as: + .. math:: - .. math:: \Phi = \frac{1}{N} \left\| \sum_{i=1}^{N} \hat{v}_i \right\| + \Phi = \frac{1}{N} \left\| \sum_{i=1}^{N} \hat{v}_i \right\| where :math:`\hat{v}_i` is the unit heading vector for individual - :math:`i`, and :math:`N` is the number of individuals. + :math:`i`, and :math:`N` is the number of valid individuals at that time. Parameters ---------- data : xarray.DataArray - The input data representing position. Must contain ``time``, - ``space``, and ``individuals`` as dimensions. The ``keypoints`` - dimension is required only if ``heading_keypoints`` is provided. - heading_keypoints : tuple of Hashable, optional - A tuple of two keypoint names ``(origin, target)`` used to - compute the heading direction as the vector from origin to - target (e.g., ``("neck", "nose")`` or ``("tail", "head")``). - If None, heading is inferred from the displacement of the first - available keypoint. - displacement_frames : int, optional - Number of frames over which to compute displacement when - ``heading_keypoints`` is None. Default is 1 (frame-to-frame - displacement). Use higher values for smoother heading estimates - (e.g., fps value for 1-second displacement). This parameter is - ignored when ``heading_keypoints`` is provided. - return_angle : bool, optional - If True, also return the mean heading angle (in radians) of the - group at each time point. Default is False. + Position data. Must contain ``time``, ``space``, and ``individuals`` as + dimensions. If ``heading_keypoints`` is provided, the array must also + contain a ``keypoints`` dimension. + + Spatial coordinates must include ``"x"`` and ``"y"``. If additional + spatial coordinates are present (e.g., ``"z"``), they are ignored; + polarization is computed in the x/y plane. 
+ heading_keypoints : tuple[Hashable, Hashable], optional + Pair of keypoint names ``(origin, target)`` used to compute heading as + the vector from origin to target. If omitted, heading is inferred from + displacement over ``displacement_frames``. + displacement_frames : int, default=1 + Number of frames used to compute displacement when + ``heading_keypoints`` is not provided. Must be a positive integer. + This parameter is ignored when ``heading_keypoints`` is provided. + return_angle : bool, default=False + If True, also return the mean heading angle in radians. Returns ------- - xarray.DataArray or tuple of xarray.DataArray - If ``return_angle`` is False (default), returns an xarray DataArray - containing the polarization value at each time point, with - dimensions ``(time,)``. Values range from 0 (random orientations) - to 1 (perfectly aligned). + xarray.DataArray or tuple[xarray.DataArray, xarray.DataArray] + If ``return_angle`` is False, returns a DataArray named + ``"polarization"`` with dimension ``("time",)``. - If ``return_angle`` is True, returns a tuple of two DataArrays: - ``(polarization, mean_angle)``, where ``mean_angle`` contains the - mean heading direction in radians at each time point. + If ``return_angle`` is True, returns + ``(polarization, mean_angle)`` where ``mean_angle`` is a DataArray + named ``"mean_angle"`` with dimension ``("time",)``. Notes ----- - If ``heading_keypoints`` is provided, the heading for each individual - is computed as the unit vector from the origin to the target - keypoint. If not provided, heading is inferred from the displacement - direction over ``displacement_frames`` frames. + Missing data are excluded per individual, per frame. + + Zero-length headings are treated as invalid and excluded from the + calculation. - Frames where an individual has missing data (NaN) are handled by - excluding that individual from the polarization calculation for that - frame. 
+ The mean heading angle is defined from the summed unit-heading vector + projected onto the x/y plane. When no valid headings exist, or when the + summed heading vector has zero magnitude (for example exact cancellation), + the returned angle is NaN. Examples -------- - Compute polarization using two keypoints to define heading: - - >>> polarization = compute_polarization( - ... ds.position, - ... heading_keypoints=("neck", "nose"), - ... ) - - Compute polarization using displacement-inferred heading: + Compute polarization from displacement: >>> polarization = compute_polarization(ds.position) - Compute polarization with 1-second displacement (at 30 fps): + Compute polarization from keypoint-defined heading: >>> polarization = compute_polarization( ... ds.position, - ... displacement_frames=30, + ... heading_keypoints=("tail", "nose"), ... ) - Also return the mean heading angle: + Return both polarization and mean angle: >>> polarization, mean_angle = compute_polarization( ... ds.position, @@ -102,162 +99,194 @@ def compute_polarization( ... 
) """ - # Validate input data _validate_type_data_array(data) - validate_dims_coords( - data, - { - "time": [], - "space": [], - "individuals": [], - }, + normalized_keypoints = _validate_position_data( + data=data, + heading_keypoints=heading_keypoints, ) - # Compute heading vectors for all individuals - if heading_keypoints is not None: + if normalized_keypoints is not None: heading_vectors = _compute_heading_from_keypoints( - data, heading_keypoints + data=data, + heading_keypoints=normalized_keypoints, ) else: heading_vectors = _compute_heading_from_velocity( - data, displacement_frames=displacement_frames + data=data, + displacement_frames=displacement_frames, ) - # Convert to unit vectors - unit_headings = convert_to_unit(heading_vectors) + heading_xy = _select_xy(heading_vectors) + norm = compute_norm(heading_xy) + valid_mask = (~heading_xy.isnull().any(dim="space")) & (norm > 0) - # Sum unit vectors across individuals - # Use nansum to handle missing data + unit_headings = (heading_xy / norm).where(valid_mask) vector_sum = unit_headings.sum(dim="individuals", skipna=True) - - # Count valid (non-NaN) individuals per time point - # A heading is valid if both x and y are not NaN - valid_mask = ~unit_headings.isnull().any(dim="space") - n_valid = valid_mask.sum(dim="individuals") - - # Compute magnitude of the sum sum_magnitude = compute_norm(vector_sum) + n_valid = valid_mask.sum(dim="individuals") - # Normalize by number of valid individuals - # Avoid division by zero - polarization = xr.where(n_valid > 0, sum_magnitude / n_valid, np.nan) + polarization = xr.where( + n_valid > 0, + sum_magnitude / n_valid, + np.nan, + ).clip(min=0.0, max=1.0) + polarization = polarization.rename("polarization") - polarization.name = "polarization" + if not return_angle: + return polarization - if return_angle: - # Compute mean heading angle from the vector sum - # arctan2(y, x) gives angle in radians - mean_angle = np.arctan2( + angle_defined = (n_valid > 0) & (sum_magnitude 
> _ANGLE_EPS) + mean_angle = xr.where( + angle_defined, + np.arctan2( vector_sum.sel(space="y"), vector_sum.sel(space="x"), - ) - mean_angle = xr.where(n_valid > 0, mean_angle, np.nan) - mean_angle.name = "mean_angle" - return polarization, mean_angle + ), + np.nan, + ).rename("mean_angle") - return polarization + return polarization, mean_angle def _compute_heading_from_keypoints( data: xr.DataArray, heading_keypoints: tuple[Hashable, Hashable], ) -> xr.DataArray: - """Compute heading vectors from two keypoints (origin to target). + """Compute heading vectors from two keypoints (origin to target).""" + origin, target = heading_keypoints + heading = data.sel(keypoints=target, drop=True) - data.sel( + keypoints=origin, + drop=True, + ) + return heading - Parameters - ---------- - data : xarray.DataArray - Position data with ``keypoints`` dimension. - heading_keypoints : tuple of Hashable - A tuple of ``(origin, target)`` keypoint names. The heading - vector points from origin toward target. - Returns - ------- - xarray.DataArray - Heading vectors with dimensions ``(time, space, individuals)``. +def _compute_heading_from_velocity( + data: xr.DataArray, + displacement_frames: int = 1, +) -> xr.DataArray: + """Compute heading vectors from displacement direction.""" + _validate_displacement_frames(displacement_frames) - """ - origin, target = heading_keypoints + position = data + if "keypoints" in data.dims: + if data.sizes["keypoints"] < 1: + raise ValueError( + "data.keypoints must contain at least one keypoint." 
+ ) + position = data.isel(keypoints=0, drop=True) - # Validate keypoints are different - if origin == target: - raise logger.error( - ValueError("The origin and target keypoints may not be identical.") - ) + if "keypoints" in data.coords and data.coords["keypoints"].size > 0: + logger.info( + "Using keypoint '%s' for displacement-based heading.", + data.coords["keypoints"].values[0], + ) + else: + logger.info( + "Using keypoint index 0 for displacement-based heading." + ) + + displacement = position - position.shift(time=displacement_frames) + return displacement + + +def _select_xy(data: xr.DataArray) -> xr.DataArray: + """Select the planar x/y components and return standard dim order.""" + return data.sel(space=["x", "y"]).transpose("time", "space", "individuals") - # Validate keypoints exist + +def _validate_position_data( + data: xr.DataArray, + heading_keypoints: tuple[Hashable, Hashable] | None, +) -> tuple[Hashable, Hashable] | None: + """Validate the input array and normalize ``heading_keypoints``.""" validate_dims_coords( data, - {"keypoints": [origin, target]}, + { + "time": [], + "space": [], + "individuals": [], + }, ) - # Compute heading as vector from origin to target - heading = data.sel(keypoints=target, drop=True) - data.sel( - keypoints=origin, drop=True - ) + allowed_dims = {"time", "space", "individuals", "keypoints"} + unexpected_dims = set(data.dims) - allowed_dims + if unexpected_dims: + raise ValueError( + f"data contains unsupported dimension(s): " + f"{sorted(str(d) for d in unexpected_dims)}" + ) - return heading + if "space" not in data.coords: + raise ValueError( + "data must have coordinate labels for the 'space' dimension." + ) + space_labels = set(data.coords["space"].values.tolist()) + if not {"x", "y"}.issubset(space_labels): + raise ValueError( + "data.space must include coordinate labels 'x' and 'y'." 
+ ) -def _compute_heading_from_velocity( - data: xr.DataArray, - displacement_frames: int = 1, -) -> xr.DataArray: - """Compute heading vectors from displacement direction. + if heading_keypoints is None: + return None - Uses the first available keypoint if multiple are present. + origin, target = _normalize_heading_keypoints(heading_keypoints) - Parameters - ---------- - data : xarray.DataArray - Position data with ``time`` dimension. - displacement_frames : int, optional - Number of frames over which to compute displacement. Default is 1 - (frame-to-frame displacement). Use higher values for smoother - heading estimates (e.g., fps for 1-second displacement). + if "keypoints" not in data.dims: + raise ValueError( + "heading_keypoints requires data to have a 'keypoints' dimension." + ) - Returns - ------- - xarray.DataArray - Heading vectors based on displacement direction. + validate_dims_coords(data, {"keypoints": [origin, target]}) + return origin, target - """ - # If keypoints dimension exists, use first keypoint - if "keypoints" in data.dims: - first_keypoint = data.keypoints.values[0] - position = data.sel(keypoints=first_keypoint, drop=True) - logger.info( - f"Using keypoint '{first_keypoint}' for displacement-based heading." + +def _normalize_heading_keypoints( + heading_keypoints: tuple[Hashable, Hashable] | Any, +) -> tuple[Hashable, Hashable]: + """Validate and normalize the keypoint pair.""" + if isinstance(heading_keypoints, (str, bytes)): + raise TypeError( + "heading_keypoints must be an iterable of exactly two " + "keypoint names." ) - else: - position = data - # Compute displacement over N frames - # displacement[t] = position[t] - position[t - displacement_frames] - displacement = position - position.shift(time=displacement_frames) + try: + origin, target = heading_keypoints + except (TypeError, ValueError) as exc: + raise TypeError( + "heading_keypoints must be an iterable of exactly two " + "keypoint names." 
+ ) from exc - return displacement + for keypoint in (origin, target): + if not isinstance(keypoint, Hashable): + raise TypeError("Each heading keypoint must be hashable.") + if origin == target: + raise ValueError( + "heading_keypoints must contain two distinct keypoint names." + ) -def _validate_type_data_array(data: xr.DataArray) -> None: - """Validate the input data is an xarray DataArray. + return origin, target - Parameters - ---------- - data : xarray.DataArray - The input data to validate. - Raises - ------ - TypeError - If the input data is not an xarray DataArray. +def _validate_displacement_frames(displacement_frames: int) -> None: + """Validate the displacement window.""" + if isinstance(displacement_frames, (bool, np.bool_)) or not isinstance( + displacement_frames, + (int, np.integer), + ): + raise TypeError("displacement_frames must be a positive integer.") - """ + if displacement_frames < 1: + raise ValueError("displacement_frames must be >= 1.") + + +def _validate_type_data_array(data: xr.DataArray) -> None: + """Validate that the input is an xarray.DataArray.""" if not isinstance(data, xr.DataArray): - raise logger.error( - TypeError( - "Input data must be an xarray.DataArray, " - f"but got {type(data)}." - ) + raise TypeError( + f"Input data must be an xarray.DataArray, but got {type(data)}." ) diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py index 2a54e0295..8575a8b26 100644 --- a/tests/test_unit/test_kinematics/test_collective.py +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -1,3 +1,4 @@ +# test_collective.py """Tests for the collective behavior metrics module.""" import numpy as np @@ -7,1728 +8,781 @@ from movement import kinematics -@pytest.fixture -def position_data_aligned_individuals(): - """Return position data for 3 individuals all moving in the same direction. 
- - All individuals move along the positive x-axis at every time step, - so polarization should be 1.0. - """ - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1", "id_2"] - keypoints = ["centroid"] - space = ["x", "y"] - - # All individuals move in +x direction - # Shape: (time=4, space=2, keypoints=1, individuals=3) - # x-coords: all increase by 1 each time step - # y-coords: all stay at 0 - data = np.array( - [ - # time 0: x=[0,1,2], y=[0,0,0] - [[[0, 1, 2]], [[0, 0, 0]]], - # time 1: x=[1,2,3], y=[0,0,0] - [[[1, 2, 3]], [[0, 0, 0]]], - # time 2: x=[2,3,4], y=[0,0,0] - [[[2, 3, 4]], [[0, 0, 0]]], - # time 3: x=[3,4,5], y=[0,0,0] - [[[3, 4, 5]], [[0, 0, 0]]], - ], - dtype=float, - ) - - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - - -@pytest.fixture -def position_data_opposite_individuals(): - """Return position data for 2 individuals moving in opposite directions. - - One moves in +x, the other in -x, so polarization should be 0.0. - """ - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] - - # id_0 moves in +x, id_1 moves in -x - # Shape: (time=4, space=2, keypoints=1, individuals=2) - data = np.array( - [ - # time 0: x=[0,10], y=[0,0] - [[[0, 10]], [[0, 0]]], - # time 1: x=[1,9], y=[0,0] - [[[1, 9]], [[0, 0]]], - # time 2: x=[2,8], y=[0,0] - [[[2, 8]], [[0, 0]]], - # time 3: x=[3,7], y=[0,0] - [[[3, 7]], [[0, 0]]], - ], - dtype=float, - ) - - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - - -@pytest.fixture -def position_data_with_keypoints(): - """Return position data with origin/target keypoints for heading. - - Two individuals, both facing the same direction (positive x). 
- Heading is computed as tail -> nose (origin -> target). - """ - time = [0, 1, 2] - individuals = ["id_0", "id_1"] - keypoints = ["nose", "tail"] - space = ["x", "y"] - - # Both individuals facing +x (nose ahead of tail in x) - # Shape: (time=3, space=2, keypoints=2, individuals=2) - # For each individual: nose is at higher x than tail - data = np.array( - [ - # time 0: nose_x=[2,5], nose_y=[0,1], tail_x=[0,3], tail_y=[0,1] - [[[2, 5], [0, 3]], [[0, 1], [0, 1]]], - # time 1 - [[[3, 6], [1, 4]], [[0, 1], [0, 1]]], - # time 2 - [[[4, 7], [2, 5]], [[0, 1], [0, 1]]], - ], - dtype=float, - ) - - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - - -@pytest.fixture -def position_data_with_nan(): - """Return position data with NaN values for one individual at one time.""" - time = [0, 1, 2] - individuals = ["id_0", "id_1", "id_2"] - keypoints = ["centroid"] - space = ["x", "y"] - - # Shape: (time=3, space=2, keypoints=1, individuals=3) - # id_1 has NaN at time 1 - data = np.array( - [ - # time 0: x=[0,1,2], y=[0,0,0] - all valid - [[[0, 1, 2]], [[0, 0, 0]]], - # time 1: x=[1,nan,3], y=[0,nan,0] - id_1 is NaN - [[[1, np.nan, 3]], [[0, np.nan, 0]]], - # time 2: x=[2,3,4], y=[0,0,0] - all valid - [[[2, 3, 4]], [[0, 0, 0]]], - ], - dtype=float, - ) - - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - - -@pytest.fixture -def position_data_perpendicular(): - """Return position data for 4 individuals in perpendicular dirs. - - Each individual moves in one of the 4 cardinal directions (+x, -x, +y, -y). - The sum of unit vectors is zero, so polarization should be 0.0. 
- """ - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1", "id_2", "id_3"] - keypoints = ["centroid"] - space = ["x", "y"] - - # id_0: +x, id_1: -x, id_2: +y, id_3: -y - # Shape: (time=4, space=2, keypoints=1, individuals=4) - data = np.array( - [ - # time 0 - [[[0, 10, 0, 0]], [[0, 0, 0, 10]]], - # time 1: +x moves right, -x moves left, +y moves up, -y moves down - [[[1, 9, 0, 0]], [[0, 0, 1, 9]]], - # time 2 - [[[2, 8, 0, 0]], [[0, 0, 2, 8]]], - # time 3 - [[[3, 7, 0, 0]], [[0, 0, 3, 7]]], - ], - dtype=float, - ) - - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - - -@pytest.fixture -def position_data_partial_alignment(): - """Return position data for 3 individuals with partial alignment. - - Two individuals move in +x, one moves in +y. - Expected polarization: |[2,1]|/3 = sqrt(5)/3 ≈ 0.745 - """ - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1", "id_2"] - keypoints = ["centroid"] - space = ["x", "y"] - - # id_0 and id_1 move in +x, id_2 moves in +y - # Shape: (time=4, space=2, keypoints=1, individuals=3) - data = np.array( - [ - # time 0 - [[[0, 5, 0]], [[0, 0, 0]]], - # time 1 - [[[1, 6, 0]], [[0, 0, 1]]], - # time 2 - [[[2, 7, 0]], [[0, 0, 2]]], - # time 3 - [[[3, 8, 0]], [[0, 0, 3]]], - ], - dtype=float, - ) - - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - +def _get_space_labels(n_space: int, space: list[str] | None) -> list[str]: + """Return space labels, defaulting to ['x', 'y'] for 2D.""" + if space is not None: + return space + if n_space == 2: + return ["x", "y"] + raise ValueError("Provide explicit `space` labels for non-2D data.") + + +def _make_position_dataarray( + data: np.ndarray, + *, + time: list | None = None, + individuals: list | None = None, + 
keypoints: list[str] | None = None, + space: list[str] | None = None, +) -> xr.DataArray: + """Create a position DataArray for tests.""" + data = np.asarray(data, dtype=float) + n_time, n_space = data.shape[0], data.shape[1] + + if data.ndim == 3: + n_individuals = data.shape[2] + ind = individuals or [f"id_{i}" for i in range(n_individuals)] + return xr.DataArray( + data, + dims=["time", "space", "individuals"], + coords={ + "time": time if time else list(range(n_time)), + "space": _get_space_labels(n_space, space), + "individuals": ind, + }, + name="position", + ) -@pytest.fixture -def position_data_single_individual(): - """Return position data for a single individual. - - In this synthetic dataset, polarization is 1.0 whenever a valid heading - can be computed. First-frame behavior in velocity mode depends on boundary - differencing. - """ - time = [0, 1, 2, 3] - individuals = ["id_0"] - keypoints = ["centroid"] - space = ["x", "y"] - - # Single individual moving in +x direction - data = np.array( - [ - [[[0]], [[0]]], - [[[1]], [[0]]], - [[[2]], [[0]]], - [[[3]], [[0]]], - ], - dtype=float, - ) + if data.ndim == 4: + n_keypoints, n_individuals = data.shape[2], data.shape[3] + kp = keypoints or [f"kp_{i}" for i in range(n_keypoints)] + ind = individuals or [f"id_{i}" for i in range(n_individuals)] + return xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time if time else list(range(n_time)), + "space": _get_space_labels(n_space, space), + "keypoints": kp, + "individuals": ind, + }, + name="position", + ) - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, + raise ValueError( + "Expected data with shape (time, space, individuals) or " + "(time, space, keypoints, individuals)." 
) @pytest.fixture -def position_data_all_nan_frame(): - """Return position data with one frame where all individuals are NaN.""" - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] - - # All individuals have NaN at time 2 +def aligned_positions() -> xr.DataArray: + """Two individuals moving together in +x direction.""" data = np.array( [ - [[[0, 5]], [[0, 0]]], - [[[1, 6]], [[0, 0]]], - [[[np.nan, np.nan]], [[np.nan, np.nan]]], # all NaN - [[[3, 8]], [[0, 0]]], + [[0, 5], [0, 0]], + [[1, 6], [0, 0]], + [[2, 7], [0, 0]], + [[3, 8], [0, 0]], ], dtype=float, ) - - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) + return _make_position_dataarray(data) @pytest.fixture -def position_data_stationary(): - """Return position data where individuals are stationary.""" - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] - - # Both individuals stay at the same position +def opposite_positions() -> xr.DataArray: + """Two individuals moving in opposite x directions (+x and -x).""" data = np.array( [ - [[[0, 5]], [[0, 0]]], - [[[0, 5]], [[0, 0]]], - [[[0, 5]], [[0, 0]]], - [[[0, 5]], [[0, 0]]], + [[0, 5], [0, 0]], + [[1, 4], [0, 0]], + [[2, 3], [0, 0]], + [[3, 2], [0, 0]], ], dtype=float, ) - - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) + return _make_position_dataarray(data) @pytest.fixture -def position_data_large_n(): - """Return position data with many individuals (N=50) all aligned.""" - time = [0, 1, 2] - n_individuals = 50 - individuals = [f"id_{i}" for i in range(n_individuals)] - keypoints = ["centroid"] - space = ["x", "y"] - - # All individuals move in +x direction - # Shape: 
(time=3, space=2, keypoints=1, individuals=50) - x_coords = np.arange(n_individuals, dtype=float) +def partial_alignment_positions() -> xr.DataArray: + """Three individuals: two move +x, one moves +y.""" data = np.array( [ - [[x_coords], [np.zeros(n_individuals)]], - [[x_coords + 1], [np.zeros(n_individuals)]], - [[x_coords + 2], [np.zeros(n_individuals)]], + [[0, 5, 0], [0, 0, 0]], + [[1, 6, 0], [0, 0, 1]], + [[2, 7, 0], [0, 0, 2]], + [[3, 8, 0], [0, 0, 3]], ], dtype=float, ) - - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) + return _make_position_dataarray(data) @pytest.fixture -def position_data_no_keypoints(): - """Return position data without keypoints dimension.""" - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1", "id_2"] - space = ["x", "y"] - - # Shape: (time=4, space=2, individuals=3) - # All individuals move in +x direction +def perpendicular_positions() -> xr.DataArray: + """Four individuals moving in cardinal directions (+x, -x, +y, -y).""" data = np.array( [ - [[0, 1, 2], [0, 0, 0]], - [[1, 2, 3], [0, 0, 0]], - [[2, 3, 4], [0, 0, 0]], - [[3, 4, 5], [0, 0, 0]], + [[0, 10, 0, 0], [0, 0, 0, 10]], + [[1, 9, 0, 0], [0, 0, 1, 9]], + [[2, 8, 0, 0], [0, 0, 2, 8]], + [[3, 7, 0, 0], [0, 0, 3, 7]], ], dtype=float, ) - - return xr.DataArray( - data, - dims=["time", "space", "individuals"], - coords={ - "time": time, - "space": space, - "individuals": individuals, - }, - ) + return _make_position_dataarray(data) @pytest.fixture -def position_data_multiple_keypoints(): - """Return position data with multiple keypoints for velocity mode test. - - Tests that velocity mode uses the first keypoint. 
- """ - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1"] - keypoints = ["nose", "tail", "center"] # nose is first - space = ["x", "y"] - - # nose moves in +x (should be used) - # tail moves in -x - # center moves in +y - # Shape: (time=4, space=2, keypoints=3, individuals=2) +def keypoint_positions() -> xr.DataArray: + """Two individuals with tail/nose keypoints, both facing +x.""" data = np.array( [ - # time 0 - [ - [[0, 0], [10, 10], [0, 0]], # x: nose, tail, center - [[0, 0], [0, 0], [0, 0]], # y: nose, tail, center - ], - # time 1 [ - [[1, 1], [9, 9], [0, 0]], - [[0, 0], [0, 0], [1, 1]], + [[0.0, 10.0], [1.0, 11.0]], + [[0.0, 0.0], [0.0, 0.0]], ], - # time 2 [ - [[2, 2], [8, 8], [0, 0]], - [[0, 0], [0, 0], [2, 2]], + [[0.5, 10.5], [1.5, 11.5]], + [[0.0, 0.0], [0.0, 0.0]], ], - # time 3 [ - [[3, 3], [7, 7], [0, 0]], - [[0, 0], [0, 0], [3, 3]], + [[1.0, 11.0], [2.0, 12.0]], + [[0.0, 0.0], [0.0, 0.0]], ], ], dtype=float, ) + return _make_position_dataarray(data, keypoints=["tail", "nose"]) - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) +class TestComputePolarizationValidation: + """Tests for input validation in compute_polarization.""" -@pytest.fixture -def position_data_diagonal_movement(): - """Return position data with diagonal movement at 45 degrees.""" - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] - - # Both individuals move diagonally (45 degrees, +x +y) - data = np.array( - [ - [[[0, 5]], [[0, 5]]], - [[[1, 6]], [[1, 6]]], - [[[2, 7]], [[2, 7]]], - [[[3, 8]], [[3, 8]]], - ], - dtype=float, - ) - - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - + def test_requires_dataarray(self): + """Raise TypeError if 
input is not an xarray.DataArray.""" + with pytest.raises(TypeError, match="xarray.DataArray"): + kinematics.compute_polarization(np.zeros((3, 2, 2))) -@pytest.fixture -def position_data_keypoints_opposite(): - """Return position data where keypoint-based headings are opposite.""" - time = [0, 1, 2] - individuals = ["id_0", "id_1"] - keypoints = ["nose", "tail"] - space = ["x", "y"] - - # id_0 faces +x (nose ahead of tail) - # id_1 faces -x (nose behind tail) - data = np.array( + @pytest.mark.parametrize( + "dims", [ - # time 0: id_0 nose at (2,0), tail at (0,0) -> faces +x - # id_1 nose at (3,0), tail at (5,0) -> faces -x - [[[2, 3], [0, 5]], [[0, 0], [0, 0]]], - [[[3, 4], [1, 6]], [[0, 0], [0, 0]]], - [[[4, 5], [2, 7]], [[0, 0], [0, 0]]], + ("space", "individuals"), + ("time", "individuals"), + ("time", "space"), ], - dtype=float, + ids=["missing_time", "missing_space", "missing_individuals"], ) + def test_requires_time_space_individuals(self, dims): + """Raise ValueError if required dimensions are missing.""" + data = xr.DataArray(np.zeros((2, 2)), dims=dims) + with pytest.raises(ValueError, match="time|space|individuals"): + kinematics.compute_polarization(data) - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) + def test_rejects_unexpected_dimensions(self): + """Raise ValueError if data contains unsupported dimensions.""" + data = xr.DataArray( + np.zeros((3, 2, 2, 2)), + dims=["time", "space", "individuals", "batch"], + coords={ + "time": [0, 1, 2], + "space": ["x", "y"], + "individuals": ["a", "b"], + "batch": [0, 1], + }, + ) + with pytest.raises(ValueError, match="unsupported dimension"): + kinematics.compute_polarization(data) + def test_requires_x_and_y_space_labels(self): + """Raise ValueError if space dimension lacks x and y labels.""" + data = xr.DataArray( + np.zeros((3, 2, 2)), + dims=["time", "space", 
"individuals"], + coords={ + "time": [0, 1, 2], + "space": ["lat", "lon"], + "individuals": ["a", "b"], + }, + ) + with pytest.raises( + ValueError, match="include coordinate labels 'x' and 'y'" + ): + kinematics.compute_polarization(data) -@pytest.fixture -def position_data_non_uniform_time(): - """Return position data with non-uniform time spacing.""" - time = [0.0, 0.5, 2.0, 5.0] # Non-uniform intervals - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] - - # Both move in +x direction - data = np.array( + @pytest.mark.parametrize( + "heading_keypoints", [ - [[[0, 5]], [[0, 0]]], - [[[1, 6]], [[0, 0]]], - [[[2, 7]], [[0, 0]]], - [[[3, 8]], [[0, 0]]], + "nose", + ("tail",), + ("tail", "nose", "ear"), + 123, ], - dtype=float, + ids=["string", "length_one", "length_three", "non_iterable"], ) - - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - - -class TestComputePolarization: - """Test suite for the compute_polarization function.""" - - def test_polarization_aligned( + def test_heading_keypoints_must_be_length_two_iterable( self, - position_data_aligned_individuals, - position_data_diagonal_movement, + heading_keypoints, + keypoint_positions, ): - """Test polarization is 1.0 when all move same direction. - - Tests both horizontal and diagonal movement to verify that - polarization is rotation-invariant (direction angle doesn't matter, - only alignment between individuals). 
- """ - # Test horizontal alignment - polarization = kinematics.compute_polarization( - position_data_aligned_individuals - ) - - assert isinstance(polarization, xr.DataArray) - assert polarization.name == "polarization" - assert "time" in polarization.dims - assert "individuals" not in polarization.dims - assert "space" not in polarization.dims - - # All moving in same direction -> polarization should be ~1.0 - # (Skip first time point since velocity is computed via diff) - assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) - - # Test diagonal alignment (rotation invariance) - # Both individuals moving at 45 degrees should also yield pol=1.0 - polarization_diag = kinematics.compute_polarization( - position_data_diagonal_movement - ) - assert np.allclose(polarization_diag.values[1:], 1.0, atol=1e-10) - - def test_polarization_opposite(self, position_data_opposite_individuals): - """Test polarization is 0.0 when individuals move opposite.""" - polarization = kinematics.compute_polarization( - position_data_opposite_individuals - ) - - assert isinstance(polarization, xr.DataArray) - assert polarization.name == "polarization" - - # Opposite directions -> polarization should be ~0.0 - # (Skip first time point since velocity is computed via diff) - assert np.allclose(polarization.values[1:], 0.0, atol=1e-10) - - def test_polarization_with_keypoints(self, position_data_with_keypoints): - """Test polarization using keypoint-based heading.""" - polarization = kinematics.compute_polarization( - position_data_with_keypoints, - heading_keypoints=("tail", "nose"), # origin -> target - ) - - assert isinstance(polarization, xr.DataArray) - assert polarization.name == "polarization" - - # Both facing same direction -> polarization should be 1.0 - assert np.allclose(polarization.values, 1.0, atol=1e-10) - - def test_polarization_handles_nan(self, position_data_with_nan): - """Test that NaN values are handled correctly.""" - polarization = 
kinematics.compute_polarization(position_data_with_nan) - - assert isinstance(polarization, xr.DataArray) - # Should compute polarization even with missing data - # The frame with NaN should exclude that individual from calculation - assert not np.all(np.isnan(polarization.values)) - - def test_invalid_input_type(self, position_data_aligned_individuals): - """Test that non-DataArray input raises TypeError.""" - with pytest.raises(TypeError, match="must be an xarray.DataArray"): + """Raise TypeError if heading_keypoints is not length-two.""" + with pytest.raises(TypeError, match="exactly two keypoint names"): kinematics.compute_polarization( - position_data_aligned_individuals.values + keypoint_positions, + heading_keypoints=heading_keypoints, ) - def test_missing_dimensions(self, position_data_aligned_individuals): - """Test that missing required dimensions raises ValueError.""" - # Drop individuals dimension - data_no_individuals = position_data_aligned_individuals.sel( - individuals="id_0", drop=True - ) - with pytest.raises(ValueError, match="individuals"): - kinematics.compute_polarization(data_no_individuals) - - def test_invalid_keypoints(self, position_data_with_keypoints): - """Test that invalid keypoint names raise ValueError.""" - with pytest.raises(ValueError, match="nonexistent"): + def test_heading_keypoints_must_be_hashable(self, keypoint_positions): + """Raise TypeError if heading keypoints are not hashable.""" + with pytest.raises(TypeError, match="hashable"): kinematics.compute_polarization( - position_data_with_keypoints, - heading_keypoints=("nose", "nonexistent"), + keypoint_positions, + heading_keypoints=(["tail"], "nose"), ) - def test_identical_keypoints(self, position_data_with_keypoints): - """Test that identical origin and target keypoints raise ValueError.""" - with pytest.raises(ValueError, match="may not be identical"): + def test_heading_keypoints_require_keypoints_dimension( + self, aligned_positions + ): + """Raise ValueError if 
heading_keypoints given but no keypoints dim.""" + with pytest.raises( + ValueError, match="requires data to have a 'keypoints' dimension" + ): kinematics.compute_polarization( - position_data_with_keypoints, - heading_keypoints=("nose", "nose"), + aligned_positions, + heading_keypoints=("tail", "nose"), ) - # ================= Intermediate Polarization Values ================= - - def test_polarization_perpendicular_four_directions( - self, position_data_perpendicular - ): - """Test polarization is 0 when 4 individuals move in cardinal dirs.""" - polarization = kinematics.compute_polarization( - position_data_perpendicular - ) - - # 4 perpendicular directions cancel out -> polarization = 0.0 - # Compare frames 1: avoid boundary differencing dependence at t=0. - assert np.allclose(polarization.values[1:], 0.0, atol=1e-10) - - def test_polarization_partial_alignment( - self, position_data_partial_alignment - ): - """Test intermediate polarization with partial alignment. - - Two individuals move +x, one moves +y. - Unit vectors: [1,0], [1,0], [0,1] - Sum: [2, 1] - Magnitude: sqrt(5) - Polarization: sqrt(5)/3 ≈ 0.745 - """ - polarization = kinematics.compute_polarization( - position_data_partial_alignment - ) - - expected = np.sqrt(5) / 3 # ≈ 0.745 - # Compare frames 1: avoid boundary differencing dependence at t=0. - assert np.allclose(polarization.values[1:], expected, atol=1e-10) - - # ==================== Edge Cases ==================== - - def test_polarization_single_individual( - self, position_data_single_individual - ): - """Test polarization is 1.0 for a single individual.""" - polarization = kinematics.compute_polarization( - position_data_single_individual - ) - - # Single individual always has polarization = 1.0 - # (the unit vector divided by 1) - # Compare frames 1: avoid boundary differencing dependence at t=0. 
- assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) - - def test_polarization_all_nan_frame(self, position_data_all_nan_frame): - """Test that all-NaN frames result in NaN polarization.""" - polarization = kinematics.compute_polarization( - position_data_all_nan_frame - ) - - # Frame at index 2 has all NaN positions - # Velocity at frame 2 is computed from frames 1->2 and 2->3 - # Due to NaN at frame 2, velocity at frames 2 and 3 will be affected - # The exact behavior depends on compute_velocity's edge handling - assert isinstance(polarization, xr.DataArray) - # At minimum, verify we get a result with correct length - assert len(polarization) == len(position_data_all_nan_frame.time) - - def test_polarization_stationary(self, position_data_stationary): - """Test that stationary individuals (zero velocity) produce NaN.""" - polarization = kinematics.compute_polarization( - position_data_stationary - ) - - # Zero velocity means zero-length vector -> unit vector is NaN - # All frames after first should be NaN (zero displacement) - assert np.all(np.isnan(polarization.values[1:])) - - def test_polarization_large_n(self, position_data_large_n): - """Test polarization with many individuals (N=50) all aligned.""" - polarization = kinematics.compute_polarization(position_data_large_n) - - # All 50 individuals moving same direction -> polarization = 1.0 - # Compare frames 1: avoid boundary differencing dependence at t=0. 
- assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) - - # ==================== Data Structure Variations ==================== - - def test_polarization_no_keypoints_dimension( - self, position_data_no_keypoints - ): - """Test polarization works without keypoints dimension.""" - polarization = kinematics.compute_polarization( - position_data_no_keypoints - ) - - assert isinstance(polarization, xr.DataArray) - assert polarization.name == "polarization" - # All moving same direction -> polarization = 1.0 - assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) - - def test_polarization_velocity_mode_uses_first_keypoint( - self, position_data_multiple_keypoints - ): - """Test velocity mode uses first keypoint when multiple exist.""" - polarization = kinematics.compute_polarization( - position_data_multiple_keypoints - ) - - # First keypoint (nose) moves in +x for both individuals - # So polarization should be 1.0 (not affected by tail or center) - assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) - - def test_polarization_keypoints_opposite_directions( - self, position_data_keypoints_opposite - ): - """Test keypoint polarization with opposite facing individuals.""" - polarization = kinematics.compute_polarization( - position_data_keypoints_opposite, - heading_keypoints=("tail", "nose"), - ) - - # id_0 faces +x, id_1 faces -x -> polarization = 0.0 - assert np.allclose(polarization.values, 0.0, atol=1e-10) - - # ==================== Time Coordinate Handling ==================== - - def test_polarization_preserves_time_coords( - self, position_data_aligned_individuals - ): - """Test that output preserves time coordinates from input.""" - polarization = kinematics.compute_polarization( - position_data_aligned_individuals - ) - - np.testing.assert_array_equal( - polarization.time.values, - position_data_aligned_individuals.time.values, - ) - - def test_polarization_non_uniform_time( - self, position_data_non_uniform_time - ): - """Test 
polarization with non-uniform time spacing.""" - polarization = kinematics.compute_polarization( - position_data_non_uniform_time - ) - - # Should still work and preserve time coords - expected_times = [0.0, 0.5, 2.0, 5.0] - np.testing.assert_array_equal(polarization.time.values, expected_times) - # Both moving same direction - assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) - - # ==================== Output Properties ==================== - - def test_polarization_output_structure( - self, position_data_aligned_individuals - ): - """Test that output has correct structure (time dimension only). - - Verifies both positive assertion (dims == time) and explicit - absence of input dimensions that should be reduced over. - """ - polarization = kinematics.compute_polarization( - position_data_aligned_individuals - ) - - # Positive assertion: output has exactly time dimension - assert polarization.dims == ("time",) - assert len(polarization) == len(position_data_aligned_individuals.time) - - # Explicit absence checks (documents which dims are reduced) - assert "keypoints" not in polarization.dims - assert "space" not in polarization.dims - assert "individuals" not in polarization.dims - - def test_polarization_first_frame_displacement_mode( - self, position_data_aligned_individuals - ): - """Test first frame behavior when using displacement-based heading. - - Displacement-based heading requires position at frame t and t-N - (where N=displacement_frames). Frame 0 has no prior reference, - so it is NaN. This is consistent and predictable behavior. 
- """ - polarization = kinematics.compute_polarization( - position_data_aligned_individuals - ) - - assert isinstance(polarization, xr.DataArray) - assert len(polarization) == len(position_data_aligned_individuals.time) - # First frame is NaN (no t-1 reference for displacement) - assert np.isnan(polarization.values[0]) - # Subsequent frames are valid - assert not np.any(np.isnan(polarization.values[1:])) - - def test_polarization_first_frame_valid_keypoint_mode( - self, position_data_with_keypoints - ): - """Test that first frame is valid when using keypoint-based heading.""" - polarization = kinematics.compute_polarization( - position_data_with_keypoints, - heading_keypoints=("tail", "nose"), - ) - - # First frame should be valid (keypoint positions are always known) - assert not np.isnan(polarization.values[0]) - - # ==================== Mathematical Properties ==================== - - def test_polarization_symmetry(self): - """Test polarization is symmetric (individual order irrelevant).""" - time = [0, 1, 2] - keypoints = ["centroid"] - space = ["x", "y"] - - # Create two datasets with same individuals in different order - data1 = np.array( - [ - [[[0, 5]], [[0, 0]]], - [[[1, 4]], [[0, 0]]], # id_0: +x, id_1: -x - [[[2, 3]], [[0, 0]]], - ], - dtype=float, - ) - data2 = np.array( - [ - [[[5, 0]], [[0, 0]]], - [[[4, 1]], [[0, 0]]], # id_0: -x, id_1: +x (swapped) - [[[3, 2]], [[0, 0]]], - ], - dtype=float, - ) - - da1 = xr.DataArray( - data1, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": ["id_0", "id_1"], - }, - ) - da2 = xr.DataArray( - data2, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": ["id_0", "id_1"], - }, - ) - - pol1 = kinematics.compute_polarization(da1) - pol2 = kinematics.compute_polarization(da2) - - np.testing.assert_array_almost_equal(pol1.values, pol2.values) - 
- def test_polarization_bounds(self, position_data_aligned_individuals): - """Test polarization values are always in [0, 1] range. - - Verifies bounds with both: - 1. Simple aligned data (deterministic, yields boundary value 1.0) - 2. Random directions (stochastic, yields distribution across range) - """ - # Test with simple aligned data (boundary case: all 1.0) - polarization_simple = kinematics.compute_polarization( - position_data_aligned_individuals - ) - valid_simple = polarization_simple.values[ - ~np.isnan(polarization_simple.values) - ] - assert np.all(valid_simple >= 0.0) - assert np.all(valid_simple <= 1.0) - - # Test with random directions (interior values) - time = [0, 1, 2, 3, 4] - individuals = [f"id_{i}" for i in range(10)] - keypoints = ["centroid"] - space = ["x", "y"] - - # Create semi-random movement patterns - np.random.seed(42) - n_ind = len(individuals) - n_time = len(time) - - # Random starting positions - x_start = np.random.rand(n_ind) * 100 - y_start = np.random.rand(n_ind) * 100 - - # Random velocities - vx = np.random.randn(n_ind) * 2 - vy = np.random.randn(n_ind) * 2 - - data = np.zeros((n_time, 2, 1, n_ind)) - for t in range(n_time): - data[t, 0, 0, :] = x_start + vx * t - data[t, 1, 0, :] = y_start + vy * t - - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - - polarization_random = kinematics.compute_polarization(da) - - valid_random = polarization_random.values[ - ~np.isnan(polarization_random.values) - ] - assert np.all(valid_random >= 0.0) - assert np.all(valid_random <= 1.0) - - # ==================== NaN Handling Edge Cases ==================== - - def test_polarization_nan_one_coordinate_only(self): - """Test handling when only one spatial coordinate is NaN.""" - time = [0, 1, 2] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] - - # id_1 has NaN only in 
x at time 1 - data = np.array( - [ - [[[0, 5]], [[0, 0]]], - [[[1, np.nan]], [[0, 0]]], # x is NaN, y is valid - [[[2, 7]], [[0, 0]]], - ], - dtype=float, - ) - - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - - polarization = kinematics.compute_polarization(da) - - # Should still compute - NaN individual is excluded - assert isinstance(polarization, xr.DataArray) - - def test_polarization_nan_in_keypoints(self): - """Test handling NaN in keypoint-based heading calculation.""" - time = [0, 1, 2] - individuals = ["id_0", "id_1"] - keypoints = ["nose", "tail"] - space = ["x", "y"] - - # id_1's nose is NaN at time 1 - data = np.array( - [ - [[[2, 5], [0, 3]], [[0, 0], [0, 0]]], - [[[3, np.nan], [1, 4]], [[0, np.nan], [0, 0]]], - [[[4, 7], [2, 5]], [[0, 0], [0, 0]]], - ], - dtype=float, - ) - - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - - polarization = kinematics.compute_polarization( - da, heading_keypoints=("tail", "nose") - ) - - # Time 0 and 2 should be valid (both face +x -> 1.0) - assert np.allclose(polarization.values[0], 1.0, atol=1e-10) - assert np.allclose(polarization.values[2], 1.0, atol=1e-10) - # Time 1 should be 1.0 (only id_0 is valid, single individual) - assert np.allclose(polarization.values[1], 1.0, atol=1e-10) - - # ==================== Error Handling ==================== - - def test_missing_space_dimension(self): - """Test that missing space dimension raises ValueError.""" - data = xr.DataArray( - np.random.rand(4, 3), - dims=["time", "individuals"], - coords={"time": [0, 1, 2, 3], "individuals": ["a", "b", "c"]}, - ) - with pytest.raises(ValueError, match="space"): - kinematics.compute_polarization(data) - - def test_missing_time_dimension(self): - 
"""Test that missing time dimension raises ValueError.""" - data = xr.DataArray( - np.random.rand(2, 3), - dims=["space", "individuals"], - coords={"space": ["x", "y"], "individuals": ["a", "b", "c"]}, - ) - with pytest.raises(ValueError, match="time"): - kinematics.compute_polarization(data) - - def test_empty_dataarray(self): - """Test handling of empty DataArray returns empty result. - - Empty arrays are handled gracefully by the displacement-based - implementation, returning an empty polarization array. - """ - data = xr.DataArray( - np.array([]).reshape(0, 2, 0), - dims=["time", "space", "individuals"], - coords={"time": [], "space": ["x", "y"], "individuals": []}, - ) - polarization = kinematics.compute_polarization(data) - - assert isinstance(polarization, xr.DataArray) - assert len(polarization) == 0 - - def test_polarization_mixed_stationary_moving(self): - """Test polarization with some stationary and some moving individuals. - - When some individuals are stationary (zero velocity), they should be - excluded from the polarization calculation (NaN unit vector). 
- """ - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1", "id_2"] - keypoints = ["centroid"] - space = ["x", "y"] - - # id_0: stationary at (0, 0) - # id_1: moves in +x direction - # id_2: moves in +x direction - data = np.array( - [ - [[[0, 5, 10]], [[0, 0, 0]]], - [[[0, 6, 11]], [[0, 0, 0]]], - [[[0, 7, 12]], [[0, 0, 0]]], - [[[0, 8, 13]], [[0, 0, 0]]], - ], - dtype=float, - ) - - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - - polarization = kinematics.compute_polarization(da) - - # id_0 is stationary -> NaN heading -> excluded - # id_1 and id_2 both move +x -> polarization = 1.0 - assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + def test_heading_keypoints_must_exist(self, keypoint_positions): + """Raise ValueError if specified keypoints do not exist in data.""" + with pytest.raises(ValueError, match="snout|keypoints"): + kinematics.compute_polarization( + keypoint_positions, + heading_keypoints=("tail", "snout"), + ) - def test_polarization_two_time_points_minimum(self): - """Test polarization with minimum time points (2). + def test_heading_keypoints_must_be_distinct(self, keypoint_positions): + """Raise ValueError if origin and target keypoints are identical.""" + with pytest.raises(ValueError, match="two distinct keypoint names"): + kinematics.compute_polarization( + keypoint_positions, + heading_keypoints=("tail", "tail"), + ) - With displacement-based heading, frame 0 is NaN (no prior reference). - Frame 1 has valid displacement from frame 0 to frame 1. 
- """ - time = [0, 1] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] + @pytest.mark.parametrize( + "displacement_frames,expected_exception", + [ + (0, ValueError), + (-1, ValueError), + (1.5, TypeError), + (True, TypeError), + ], + ids=["zero", "negative", "float", "bool"], + ) + def test_displacement_frames_must_be_positive_integer( + self, + aligned_positions, + displacement_frames, + expected_exception, + ): + """Raise error if displacement_frames is not a positive integer.""" + with pytest.raises(expected_exception, match="positive integer|>= 1"): + kinematics.compute_polarization( + aligned_positions, + displacement_frames=displacement_frames, + ) - # Both move in +x direction - data = np.array( - [ - [[[0, 5]], [[0, 0]]], - [[[1, 6]], [[0, 0]]], - ], - dtype=float, + def test_invalid_displacement_frames_is_ignored_in_keypoint_mode( + self, + keypoint_positions, + ): + """Invalid displacement_frames is ignored when keypoints are used.""" + polarization = kinematics.compute_polarization( + keypoint_positions, + heading_keypoints=("tail", "nose"), + displacement_frames=0, ) + assert np.allclose(polarization.values, 1.0, atol=1e-10) - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - polarization = kinematics.compute_polarization(da) +class TestComputePolarizationBehavior: + """Tests for polarization computation behavior.""" - assert isinstance(polarization, xr.DataArray) - assert len(polarization) == 2 - # Frame 0 is NaN (no prior reference for displacement) + def test_aligned_motion_gives_one(self, aligned_positions): + """Polarization is 1.0 when all individuals move in same direction.""" + polarization = kinematics.compute_polarization(aligned_positions) assert np.isnan(polarization.values[0]) - # Frame 1 is valid: both moving same direction -> polarization = 1.0 - assert 
np.allclose(polarization.values[1], 1.0, atol=1e-10) - - # ==================== 3D Data (Limitation) ==================== + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) - def test_3d_data_uses_only_xy(self): - """Test that 3D spatial data only uses x,y coordinates. + def test_opposite_motion_gives_zero(self, opposite_positions): + """Polarization is 0.0 when individuals move in opposite directions.""" + polarization = kinematics.compute_polarization(opposite_positions) + assert np.allclose(polarization.values[1:], 0.0, atol=1e-10) - LIMITATION: The current implementation uses validate_dims_coords - with exact_coords=False, so 3D data passes validation but only - x and y coordinates are used for norm/unit vector computation. - The z coordinate is silently ignored. + def test_perpendicular_cardinal_directions_give_zero( + self, perpendicular_positions + ): + """Polarization is 0.0 when four individuals move in cardinal dirs.""" + polarization = kinematics.compute_polarization(perpendicular_positions) + assert np.allclose(polarization.values[1:], 0.0, atol=1e-10) - This test documents this limitation. 
- """ - time = [0, 1, 2] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y", "z"] + def test_partial_alignment_matches_expected_magnitude( + self, + partial_alignment_positions, + ): + """Polarization matches expected value for partial alignment.""" + polarization = kinematics.compute_polarization( + partial_alignment_positions + ) + expected = np.sqrt(5) / 3 + assert np.allclose(polarization.values[1:], expected, atol=1e-10) - # 3D movement data - both individuals move in +x direction - # z movement is present but will be ignored + def test_single_individual_gives_one(self): + """Polarization is 1.0 for a single moving individual.""" data = np.array( [ - [[[0, 5]], [[0, 0]], [[0, 0]]], - [[[1, 6]], [[0, 0]], [[1, 1]]], - [[[2, 7]], [[0, 0]], [[2, 2]]], + [[0], [0]], + [[1], [0]], + [[2], [0]], + [[3], [0]], ], dtype=float, ) - - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, + polarization = kinematics.compute_polarization( + _make_position_dataarray(data) ) - - # 3D data silently produces results using only x,y - # This is a limitation - z is ignored - polarization = kinematics.compute_polarization(da) - assert isinstance(polarization, xr.DataArray) - # Both moving in same x direction -> polarization = 1.0 (ignoring z) assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) - # ==================== displacement_frames Parameter ==================== - - def test_displacement_frames_default( - self, position_data_aligned_individuals - ): - """Test displacement_frames=1 (default) matches original behavior.""" - # Explicit default should match implicit default - pol_implicit = kinematics.compute_polarization( - position_data_aligned_individuals - ) - pol_explicit = kinematics.compute_polarization( - position_data_aligned_individuals, - displacement_frames=1, - ) - - 
np.testing.assert_array_equal(pol_implicit.values, pol_explicit.values) - - def test_displacement_frames_multi_frame(self): - """Test multi-frame displacement computes heading over N frames. - - With displacement_frames=2, heading is computed from position[t-2] - to position[t], giving smoother estimates over longer intervals. - """ - time = [0, 1, 2, 3, 4, 5] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] - - # Both move in +x direction consistently + def test_large_n_aligned_gives_one(self): + """Polarization is 1.0 for 50 aligned individuals.""" + n_individuals = 50 + x_coords = np.arange(n_individuals, dtype=float) data = np.array( [ - [[[0, 10]], [[0, 0]]], - [[[1, 11]], [[0, 0]]], - [[[2, 12]], [[0, 0]]], - [[[3, 13]], [[0, 0]]], - [[[4, 14]], [[0, 0]]], - [[[5, 15]], [[0, 0]]], + [x_coords, np.zeros(n_individuals)], + [x_coords + 1, np.zeros(n_individuals)], + [x_coords + 2, np.zeros(n_individuals)], ], dtype=float, ) - - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - polarization = kinematics.compute_polarization( - da, displacement_frames=2 + _make_position_dataarray(data) ) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) - # First 2 frames are NaN (no t-2 reference) - assert np.isnan(polarization.values[0]) - assert np.isnan(polarization.values[1]) - # Frames 2+ should have valid polarization = 1.0 - assert np.allclose(polarization.values[2:], 1.0, atol=1e-10) - - def test_displacement_frames_larger_than_timeseries(self): - """Test displacement_frames larger than available frames yields NaN.""" - time = [0, 1, 2] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] - + def test_stationary_individuals_are_excluded(self): + """Stationary individuals produce NaN polarization and angle.""" data = np.array( [ - [[[0, 5]], [[0, 0]]], - [[[1, 
6]], [[0, 0]]], - [[[2, 7]], [[0, 0]]], + [[0, 10], [0, 0]], + [[0, 10], [0, 0]], + [[0, 10], [0, 0]], ], dtype=float, ) - - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - - # displacement_frames=10 but only 3 time points - polarization = kinematics.compute_polarization( - da, displacement_frames=10 + polarization, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, ) - - # All frames should be NaN (no valid t-10 reference) assert np.all(np.isnan(polarization.values)) + assert np.all(np.isnan(mean_angle.values)) - def test_displacement_frames_ignored_with_keypoints( - self, position_data_with_keypoints - ): - """Test displacement_frames is ignored when heading_keypoints given.""" - pol_default = kinematics.compute_polarization( - position_data_with_keypoints, - heading_keypoints=("tail", "nose"), - displacement_frames=1, + def test_stationary_and_moving_individuals_uses_only_valid_headings(self): + """Only moving individuals contribute to polarization.""" + data = np.array( + [ + [[0, 10], [0, 0]], + [[1, 10], [0, 0]], + [[2, 10], [0, 0]], + [[3, 10], [0, 0]], + ], + dtype=float, ) - pol_large = kinematics.compute_polarization( - position_data_with_keypoints, - heading_keypoints=("tail", "nose"), - displacement_frames=100, # Should be ignored + polarization, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, ) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + assert np.allclose(mean_angle.values[1:], 0.0, atol=1e-10) - # Both should be identical since keypoints override displacement - np.testing.assert_array_equal(pol_default.values, pol_large.values) - # All frames valid (keypoint heading doesn't need displacement) - assert np.allclose(pol_default.values, 1.0, atol=1e-10) - - def 
test_displacement_frames_with_nan_at_reference(self): - """Test NaN at reference frame (t - displacement_frames) propagates.""" - time = [0, 1, 2, 3, 4] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] - - # NaN at frame 1 + def test_one_coordinate_nan_excludes_that_individual(self): + """NaN in one coordinate excludes that individual from calculation.""" data = np.array( [ - [[[0, 5]], [[0, 0]]], - [[[np.nan, np.nan]], [[np.nan, np.nan]]], # NaN frame - [[[2, 7]], [[0, 0]]], - [[[3, 8]], [[0, 0]]], - [[[4, 9]], [[0, 0]]], + [[0, 10], [0, 0]], + [[1, np.nan], [0, 0]], + [[2, 12], [0, 0]], ], dtype=float, ) - - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, + polarization, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, ) + assert np.isnan(polarization.values[0]) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + assert np.allclose(mean_angle.values[1:], 0.0, atol=1e-10) + def test_nan_in_keypoint_heading_excludes_that_individual(self): + """NaN in keypoint position excludes that individual.""" + data = np.array( + [ + [ + [[0.0, 10.0], [1.0, 11.0]], + [[0.0, 0.0], [0.0, 0.0]], + ], + [ + [[1.0, 10.0], [2.0, np.nan]], + [[0.0, 0.0], [0.0, np.nan]], + ], + [ + [[2.0, 12.0], [3.0, 13.0]], + [[0.0, 0.0], [0.0, 0.0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["tail", "nose"]) polarization = kinematics.compute_polarization( - da, displacement_frames=2 + da, + heading_keypoints=("tail", "nose"), ) + assert np.allclose(polarization.values[[0, 2]], 1.0, atol=1e-10) + assert np.allclose(polarization.values[1], 1.0, atol=1e-10) - # Frame 3 references frame 1 (NaN) -> should be NaN - assert np.isnan(polarization.values[3]) - # Frame 4 references frame 2 (valid) -> should be valid - assert not 
np.isnan(polarization.values[4]) - - def test_displacement_frames_smooths_noisy_trajectory(self): - """Test larger displacement_frames smooths jittery movement.""" - time = list(range(10)) - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] - - # Noisy +x movement with small y jitter - np.random.seed(42) - jitter = np.random.randn(10, 2) * 0.1 - - data = np.zeros((10, 2, 1, 2)) - for t in range(10): - # Both individuals move +x with small y noise - data[t, 0, 0, :] = [t + jitter[t, 0], t + jitter[t, 1]] - data[t, 1, 0, :] = [jitter[t, 0], jitter[t, 1]] - - da = xr.DataArray( + def test_empty_individual_axis_returns_all_nan(self): + """Empty individuals axis returns all NaN values.""" + data = _make_position_dataarray( + np.empty((3, 2, 0)), + individuals=[], + space=["x", "y"], + ) + polarization, mean_angle = kinematics.compute_polarization( data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, + return_angle=True, ) + assert np.all(np.isnan(polarization.values)) + assert np.all(np.isnan(mean_angle.values)) - pol_1frame = kinematics.compute_polarization(da, displacement_frames=1) - pol_5frame = kinematics.compute_polarization(da, displacement_frames=5) - - # 5-frame displacement should yield higher polarization (less noise) - valid_1 = pol_1frame.values[~np.isnan(pol_1frame.values)] - valid_5 = pol_5frame.values[~np.isnan(pol_5frame.values)] - - # Mean polarization should be higher with smoothing - assert np.mean(valid_5) >= np.mean(valid_1) - 0.1 # Allow small margin - - # ==================== return_angle Parameter ==================== - - def test_return_angle_false_returns_dataarray( - self, position_data_aligned_individuals - ): - """Test return_angle=False returns single DataArray (default).""" - result = kinematics.compute_polarization( - position_data_aligned_individuals, - return_angle=False, + def 
test_empty_time_axis_returns_empty_outputs(self): + """Empty time axis returns empty output arrays.""" + data = xr.DataArray( + np.empty((0, 2, 0)), + dims=["time", "space", "individuals"], + coords={"time": [], "space": ["x", "y"], "individuals": []}, + name="position", ) - - assert isinstance(result, xr.DataArray) - assert result.name == "polarization" - - def test_return_angle_true_returns_tuple( - self, position_data_aligned_individuals - ): - """Test return_angle=True returns tuple of (polarization, angle).""" - result = kinematics.compute_polarization( - position_data_aligned_individuals, + polarization, mean_angle = kinematics.compute_polarization( + data, return_angle=True, ) - - assert isinstance(result, tuple) - assert len(result) == 2 - - polarization, mean_angle = result - assert isinstance(polarization, xr.DataArray) - assert isinstance(mean_angle, xr.DataArray) + assert polarization.shape == (0,) + assert mean_angle.shape == (0,) assert polarization.name == "polarization" assert mean_angle.name == "mean_angle" - def test_return_angle_positive_x_direction(self): - """Test mean_angle is 0 radians for +x movement.""" - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] + def test_preserves_non_uniform_time_coordinates(self, aligned_positions): + """Non-uniform time coordinates are preserved in output.""" + time = [0.0, 0.25, 0.75, 1.5] + data = aligned_positions.assign_coords(time=time) + polarization, mean_angle = kinematics.compute_polarization( + data, + return_angle=True, + ) + np.testing.assert_array_equal(polarization.time.values, time) + np.testing.assert_array_equal(mean_angle.time.values, time) - # Both move in +x direction + def test_polarization_is_invariant_to_individual_order(self): + """Polarization is independent of individual ordering.""" data = np.array( [ - [[[0, 5]], [[0, 0]]], - [[[1, 6]], [[0, 0]]], - [[[2, 7]], [[0, 0]]], - [[[3, 8]], [[0, 0]]], + [[0, 5, 0], [0, 0, 0]], + [[1, 6, 
0], [0, 0, 1]], + [[2, 7, 0], [0, 0, 2]], + [[3, 8, 0], [0, 0, 3]], ], dtype=float, ) + da = _make_position_dataarray(data) + da_permuted = da.isel(individuals=[2, 0, 1]) - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, + pol_original = kinematics.compute_polarization(da) + pol_permuted = kinematics.compute_polarization(da_permuted) + + np.testing.assert_allclose( + pol_original.values, pol_permuted.values, atol=1e-10 ) - _, mean_angle = kinematics.compute_polarization(da, return_angle=True) - # +x direction = 0 radians - assert np.allclose(mean_angle.values[1:], 0.0, atol=1e-10) +class TestHeadingSourceSelection: + """Tests for heading computation mode selection.""" - def test_return_angle_positive_y_direction(self): - """Test mean_angle is π/2 radians for +y movement.""" - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] + def test_keypoint_heading_is_valid_on_first_frame( + self, keypoint_positions + ): + """Keypoint-based heading produces valid values on first frame.""" + polarization, mean_angle = kinematics.compute_polarization( + keypoint_positions, + heading_keypoints=("tail", "nose"), + return_angle=True, + ) + assert np.allclose(polarization.values, 1.0, atol=1e-10) + assert np.allclose(mean_angle.values, 0.0, atol=1e-10) - # Both move in +y direction + def test_displacement_mode_with_keypoints_uses_first_keypoint(self): + """Displacement mode uses first keypoint when multiple exist.""" data = np.array( [ - [[[0, 0]], [[0, 5]]], - [[[0, 0]], [[1, 6]]], - [[[0, 0]], [[2, 7]]], - [[[0, 0]], [[3, 8]]], + [ + [[0, 10], [0, 10]], + [[0, 0], [0, 0]], + ], + [ + [[1, 11], [1, 9]], + [[0, 0], [0, 0]], + ], + [ + [[2, 12], [2, 8]], + [[0, 0], [0, 0]], + ], ], dtype=float, ) + da = _make_position_dataarray(data, keypoints=["centroid", "nose"]) + polarization = 
kinematics.compute_polarization(da) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - - _, mean_angle = kinematics.compute_polarization(da, return_angle=True) - - # +y direction = π/2 radians - assert np.allclose(mean_angle.values[1:], np.pi / 2, atol=1e-10) - - def test_return_angle_negative_x_direction(self): - """Test mean_angle is ±π radians for -x movement.""" - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] - - # Both move in -x direction + def test_keypoint_heading_overrides_displacement_behavior(self): + """Keypoint-based heading overrides displacement computation.""" data = np.array( [ - [[[10, 15]], [[0, 0]]], - [[[9, 14]], [[0, 0]]], - [[[8, 13]], [[0, 0]]], - [[[7, 12]], [[0, 0]]], + [ + [[0.0, 0.0], [1.0, 1.0]], + [[0.0, 2.0], [0.0, 2.0]], + ], + [ + [[0.0, 0.0], [1.0, 1.0]], + [[1.0, 3.0], [1.0, 3.0]], + ], ], dtype=float, ) - - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, + da = _make_position_dataarray(data, keypoints=["tail", "nose"]) + polarization = kinematics.compute_polarization( + da, + heading_keypoints=("tail", "nose"), + displacement_frames=1000, ) + assert np.allclose(polarization.values, 1.0, atol=1e-10) - _, mean_angle = kinematics.compute_polarization(da, return_angle=True) - - # -x direction = ±π radians (arctan2 returns π or -π) - assert np.allclose(np.abs(mean_angle.values[1:]), np.pi, atol=1e-10) - - def test_return_angle_diagonal_45_degrees(self): - """Test mean_angle is π/4 radians for diagonal (+x, +y) movement.""" - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] - - # Both move 
diagonally (+x, +y) at 45 degrees + def test_extra_spatial_dimensions_are_ignored_for_planar_metrics(self): + """Extra spatial dimensions (z) are ignored; only x/y used.""" data = np.array( [ - [[[0, 5]], [[0, 5]]], - [[[1, 6]], [[1, 6]]], - [[[2, 7]], [[2, 7]]], - [[[3, 8]], [[3, 8]]], + [[0, 5], [0, 0], [0, 100]], + [[1, 6], [0, 0], [10, -100]], + [[2, 7], [0, 0], [-10, 50]], + [[3, 8], [0, 0], [999, -999]], ], dtype=float, ) - - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, + da = _make_position_dataarray(data, space=["x", "y", "z"]) + polarization, mean_angle = kinematics.compute_polarization( + da, + return_angle=True, ) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + assert np.allclose(mean_angle.values[1:], 0.0, atol=1e-10) - _, mean_angle = kinematics.compute_polarization(da, return_angle=True) - # 45 degrees = π/4 radians - assert np.allclose(mean_angle.values[1:], np.pi / 4, atol=1e-10) +class TestDisplacementFrames: + """Tests for displacement_frames parameter behavior.""" - def test_return_angle_nan_when_no_valid_individuals( - self, position_data_all_nan_frame - ): - """Test mean_angle is NaN when no valid individuals.""" - _, mean_angle = kinematics.compute_polarization( - position_data_all_nan_frame, + def test_first_n_frames_are_nan(self, aligned_positions): + """First N frames are NaN when displacement_frames=N.""" + polarization, mean_angle = kinematics.compute_polarization( + aligned_positions, + displacement_frames=2, return_angle=True, ) + assert np.isnan(polarization.values[0]) + assert np.isnan(polarization.values[1]) + assert np.isnan(mean_angle.values[0]) + assert np.isnan(mean_angle.values[1]) + assert np.allclose(polarization.values[2:], 1.0, atol=1e-10) + assert np.allclose(mean_angle.values[2:], 0.0, atol=1e-10) - # Frame 2 has all NaN -> angle should be NaN for affected frames - 
assert isinstance(mean_angle, xr.DataArray) - # At least some frames should be NaN due to the all-NaN frame - assert np.any(np.isnan(mean_angle.values)) - - def test_return_angle_with_keypoint_heading( - self, position_data_with_keypoints + def test_nan_in_reference_frame_propagates_to_that_displacement_window( + self, ): - """Test return_angle works with keypoint-based heading.""" - pol, angle = kinematics.compute_polarization( - position_data_with_keypoints, - heading_keypoints=("tail", "nose"), - return_angle=True, - ) - - # Both face +x direction (nose ahead of tail in x) - assert np.allclose(pol.values, 1.0, atol=1e-10) - assert np.allclose(angle.values, 0.0, atol=1e-10) - - def test_return_angle_partial_alignment_mean_direction(self): - """Test mean_angle reflects weighted mean of individual headings. - - Two individuals move +x (angle=0), one moves +y (angle=π/2). - Mean unit vector: [2, 1] / 3, but angle is arctan2(1, 2) ≈ 0.464 rad. - """ - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1", "id_2"] - keypoints = ["centroid"] - space = ["x", "y"] - - # id_0 and id_1 move +x, id_2 moves +y + """NaN in reference frame propagates through displacement window.""" data = np.array( [ - [[[0, 5, 0]], [[0, 0, 0]]], - [[[1, 6, 0]], [[0, 0, 1]]], - [[[2, 7, 0]], [[0, 0, 2]]], - [[[3, 8, 0]], [[0, 0, 3]]], + [[0, 5], [0, 0]], + [[np.nan, np.nan], [np.nan, np.nan]], + [[2, 7], [0, 0]], + [[3, 8], [0, 0]], + [[4, 9], [0, 0]], ], dtype=float, ) - - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, + polarization = kinematics.compute_polarization( + _make_position_dataarray(data), + displacement_frames=2, ) + assert np.isnan(polarization.values[0]) + assert np.isnan(polarization.values[1]) + assert np.allclose(polarization.values[2], 1.0, atol=1e-10) + assert np.isnan(polarization.values[3]) + assert np.allclose(polarization.values[4], 
1.0, atol=1e-10) - pol, angle = kinematics.compute_polarization(da, return_angle=True) - - # Sum of unit vectors: [1,0] + [1,0] + [0,1] = [2, 1] - # Mean angle: arctan2(1, 2) ≈ 0.4636 rad ≈ 26.57 degrees - expected_angle = np.arctan2(1, 2) - assert np.allclose(angle.values[1:], expected_angle, atol=1e-10) - - def test_return_angle_opposite_directions_undefined(self): - """Test mean_angle when individuals move in opposite directions. - - When vectors cancel out (polarization ≈ 0), angle is still computed - from the sum vector but may be arbitrary for perfectly opposed motion. - """ - time = [0, 1, 2, 3] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] - - # id_0 moves +x, id_1 moves -x (exactly opposite) + def test_larger_displacement_window_can_change_alignment_estimate(self): + """Larger displacement window smooths jittery movement.""" data = np.array( [ - [[[0, 10]], [[0, 0]]], - [[[1, 9]], [[0, 0]]], - [[[2, 8]], [[0, 0]]], - [[[3, 7]], [[0, 0]]], + [[0, 10], [0, 0]], + [[2, 9], [0, 0]], + [[1, 11], [0, 0]], + [[3, 10], [0, 0]], + [[2, 12], [0, 0]], + [[4, 11], [0, 0]], ], dtype=float, ) + da = _make_position_dataarray(data) + + pol_1frame = kinematics.compute_polarization(da, displacement_frames=1) + pol_2frame = kinematics.compute_polarization(da, displacement_frames=2) + + assert np.allclose(pol_1frame.values[1:], 0.0, atol=1e-10) + assert np.allclose(pol_2frame.values[2:], 1.0, atol=1e-10) - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, - ) - pol, angle = kinematics.compute_polarization(da, return_angle=True) +class TestReturnAngle: + """Tests for return_angle parameter behavior.""" - # Polarization should be 0 (opposite directions cancel) - assert np.allclose(pol.values[1:], 0.0, atol=1e-10) - # Angle is computed from [0, 0] sum vector - # arctan2(0, 0) = 0 in numpy - assert 
np.allclose(angle.values[1:], 0.0, atol=1e-10) + def test_default_returns_only_polarization(self, aligned_positions): + """Default return is a single polarization DataArray.""" + result = kinematics.compute_polarization(aligned_positions) + assert isinstance(result, xr.DataArray) + assert result.name == "polarization" + assert result.dims == ("time",) + + def test_return_angle_true_returns_named_pair(self, aligned_positions): + """return_angle=True returns (polarization, mean_angle) tuple.""" + polarization, mean_angle = kinematics.compute_polarization( + aligned_positions, + return_angle=True, + ) + assert isinstance(polarization, xr.DataArray) + assert isinstance(mean_angle, xr.DataArray) + assert polarization.name == "polarization" + assert mean_angle.name == "mean_angle" + assert polarization.dims == ("time",) + assert mean_angle.dims == ("time",) - def test_return_angle_with_displacement_frames(self): - """Test return_angle works correctly with multi-frame displacement.""" - time = [0, 1, 2, 3, 4, 5] - individuals = ["id_0", "id_1"] - keypoints = ["centroid"] - space = ["x", "y"] + @pytest.mark.parametrize( + "data,expected_angle,use_abs", + [ + ( + np.array( + [ + [[0, 5], [0, 0]], + [[1, 6], [0, 0]], + [[2, 7], [0, 0]], + ], + dtype=float, + ), + 0.0, + False, + ), + ( + np.array( + [ + [[0, 0], [0, 5]], + [[0, 0], [1, 6]], + [[0, 0], [2, 7]], + ], + dtype=float, + ), + np.pi / 2, + False, + ), + ( + np.array( + [ + [[10, 15], [0, 0]], + [[9, 14], [0, 0]], + [[8, 13], [0, 0]], + ], + dtype=float, + ), + np.pi, + True, + ), + ], + ids=["positive_x", "positive_y", "negative_x"], + ) + def test_mean_angle_matches_cardinal_directions( + self, + data, + expected_angle, + use_abs, + ): + """Mean angle matches expected value for cardinal directions.""" + _, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + ) + values = mean_angle.values[1:] + if use_abs: + values = np.abs(values) + assert np.allclose(values, 
expected_angle, atol=1e-10) - # Both move in +y direction + def test_mean_angle_diagonal_motion_is_pi_over_four(self): + """Mean angle is pi/4 for diagonal (+x, +y) motion.""" data = np.array( [ - [[[0, 0]], [[0, 5]]], - [[[0, 0]], [[1, 6]]], - [[[0, 0]], [[2, 7]]], - [[[0, 0]], [[3, 8]]], - [[[0, 0]], [[4, 9]]], - [[[0, 0]], [[5, 10]]], + [[0, 5], [0, 5]], + [[1, 6], [1, 6]], + [[2, 7], [2, 7]], ], dtype=float, ) - - da = xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time, - "space": space, - "keypoints": keypoints, - "individuals": individuals, - }, + _, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, ) + assert np.allclose(mean_angle.values[1:], np.pi / 4, atol=1e-10) - pol, angle = kinematics.compute_polarization( - da, - displacement_frames=2, + def test_mean_angle_partial_alignment_matches_vector_average( + self, + partial_alignment_positions, + ): + """Mean angle matches vector average for partial alignment.""" + _, mean_angle = kinematics.compute_polarization( + partial_alignment_positions, return_angle=True, ) + expected = np.arctan2(1, 2) + assert np.allclose(mean_angle.values[1:], expected, atol=1e-10) - # First 2 frames are NaN - assert np.isnan(angle.values[0]) - assert np.isnan(angle.values[1]) - # Frames 2+ should be +y direction = π/2 - assert np.allclose(angle.values[2:], np.pi / 2, atol=1e-10) - - def test_return_angle_dimensions_match_polarization( - self, position_data_aligned_individuals + def test_mean_angle_is_nan_when_net_vector_cancels( + self, + opposite_positions, + perpendicular_positions, ): - """Test mean_angle has same dimensions and coords as polarization.""" - pol, angle = kinematics.compute_polarization( - position_data_aligned_individuals, + """Mean angle is NaN when heading vectors cancel out.""" + pol_opposite, angle_opposite = kinematics.compute_polarization( + opposite_positions, + return_angle=True, + ) + pol_perp, 
angle_perp = kinematics.compute_polarization( + perpendicular_positions, return_angle=True, ) + assert np.allclose(pol_opposite.values[1:], 0.0, atol=1e-10) + assert np.allclose(pol_perp.values[1:], 0.0, atol=1e-10) + assert np.all(np.isnan(angle_opposite.values[1:])) + assert np.all(np.isnan(angle_perp.values[1:])) - assert pol.dims == angle.dims - assert len(pol) == len(angle) - np.testing.assert_array_equal(pol.time.values, angle.time.values) + def test_mean_angle_with_keypoint_heading(self, keypoint_positions): + """Mean angle works correctly with keypoint-based heading.""" + polarization, mean_angle = kinematics.compute_polarization( + keypoint_positions, + heading_keypoints=("tail", "nose"), + return_angle=True, + ) + assert np.allclose(polarization.values, 1.0, atol=1e-10) + assert np.allclose(mean_angle.values, 0.0, atol=1e-10) From 2bd4386850d208bd9e5cb0aa37d165ef4d5e4cf2 Mon Sep 17 00:00:00 2001 From: khan-u Date: Wed, 18 Mar 2026 15:33:48 -0700 Subject: [PATCH 08/21] test(collective): add .sel() keypoint selection test, clarify docs --- movement/kinematics/collective.py | 18 ++++++--- .../test_kinematics/test_collective.py | 40 +++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/movement/kinematics/collective.py b/movement/kinematics/collective.py index bb3258a92..efa6aec36 100644 --- a/movement/kinematics/collective.py +++ b/movement/kinematics/collective.py @@ -40,7 +40,9 @@ def compute_polarization( data : xarray.DataArray Position data. Must contain ``time``, ``space``, and ``individuals`` as dimensions. If ``heading_keypoints`` is provided, the array must also - contain a ``keypoints`` dimension. + contain a ``keypoints`` dimension. For displacement-based heading, + pre-select a keypoint (e.g., ``data.sel(keypoints="thorax")``) or the + first keypoint (index 0) will be used. Spatial coordinates must include ``"x"`` and ``"y"``. 
If additional spatial coordinates are present (e.g., ``"z"``), they are ignored; @@ -80,10 +82,6 @@ def compute_polarization( Examples -------- - Compute polarization from displacement: - - >>> polarization = compute_polarization(ds.position) - Compute polarization from keypoint-defined heading: >>> polarization = compute_polarization( @@ -91,6 +89,16 @@ def compute_polarization( ... heading_keypoints=("tail", "nose"), ... ) + Compute polarization from displacement (select keypoint for tracking): + + >>> polarization = compute_polarization( + ... ds.position.sel(keypoints="thorax") + ... ) + + If multiple keypoints exist and none is selected, the first is used: + + >>> polarization = compute_polarization(ds.position) + Return both polarization and mean angle: >>> polarization, mean_angle = compute_polarization( diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py index 8575a8b26..96c062836 100644 --- a/tests/test_unit/test_kinematics/test_collective.py +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -543,6 +543,46 @@ def test_displacement_mode_with_keypoints_uses_first_keypoint(self): polarization = kinematics.compute_polarization(da) assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + def test_explicit_keypoint_selection_with_sel(self): + """Pre-selecting keypoint with .sel() uses that keypoint. + + Data shape: (time, space, keypoints, individuals). 
+ + X-coordinates across frames: + + Keypoint | Individual | Frame 0 | Frame 1 | Displacement + ---------|------------|---------|---------|------------- + thorax | ind0 | 0 | 1 | +1 (right) + thorax | ind1 | 10 | 11 | +1 (right) + head | ind0 | 0 | 1 | +1 (right) + head | ind1 | 10 | 9 | -1 (left) + + Thorax: both individuals move right -> polarization = 1.0 + Head: ind0 moves right, ind1 moves left -> polarization = 0.0 + """ + data = np.array( + [ + [ + [[0, 10], [0, 10]], + [[0, 0], [0, 0]], + ], + [ + [[1, 11], [1, 9]], + [[0, 0], [0, 0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["thorax", "head"]) + + # Without .sel(): uses thorax -> both move right -> polarization = 1.0 + pol_default = kinematics.compute_polarization(da) + assert np.allclose(pol_default.values[1], 1.0, atol=1e-10) + + # With .sel(): head selected -> ind0 right, ind1 left -> 0.0 + pol_head = kinematics.compute_polarization(da.sel(keypoints="head")) + assert np.allclose(pol_head.values[1], 0.0, atol=1e-10) + def test_keypoint_heading_overrides_displacement_behavior(self): """Keypoint-based heading overrides displacement computation.""" data = np.array( From aef4ee99bbcfd46acf0018b7f23d5d9f2248a9f6 Mon Sep 17 00:00:00 2001 From: khan-u Date: Wed, 18 Mar 2026 16:51:44 -0700 Subject: [PATCH 09/21] refactor(polarization): rename heading_keypoints to body_axis_keypoints, clarify orientation vs heading terminology --- movement/kinematics/collective.py | 81 ++++++++++++------- .../test_kinematics/test_collective.py | 70 ++++++++-------- 2 files changed, 85 insertions(+), 66 deletions(-) diff --git a/movement/kinematics/collective.py b/movement/kinematics/collective.py index efa6aec36..c8b9c06dd 100644 --- a/movement/kinematics/collective.py +++ b/movement/kinematics/collective.py @@ -16,7 +16,7 @@ def compute_polarization( data: xr.DataArray, - heading_keypoints: tuple[Hashable, Hashable] | None = None, + body_axis_keypoints: tuple[Hashable, Hashable] | None = 
None, displacement_frames: int = 1, return_angle: bool = False, ) -> xr.DataArray | tuple[xr.DataArray, xr.DataArray]: @@ -39,7 +39,7 @@ def compute_polarization( ---------- data : xarray.DataArray Position data. Must contain ``time``, ``space``, and ``individuals`` as - dimensions. If ``heading_keypoints`` is provided, the array must also + dimensions. If ``body_axis_keypoints`` is provided, the array must also contain a ``keypoints`` dimension. For displacement-based heading, pre-select a keypoint (e.g., ``data.sel(keypoints="thorax")``) or the first keypoint (index 0) will be used. @@ -47,16 +47,18 @@ def compute_polarization( Spatial coordinates must include ``"x"`` and ``"y"``. If additional spatial coordinates are present (e.g., ``"z"``), they are ignored; polarization is computed in the x/y plane. - heading_keypoints : tuple[Hashable, Hashable], optional + body_axis_keypoints : tuple[Hashable, Hashable], optional Pair of keypoint names ``(origin, target)`` used to compute heading as the vector from origin to target. If omitted, heading is inferred from displacement over ``displacement_frames``. displacement_frames : int, default=1 Number of frames used to compute displacement when - ``heading_keypoints`` is not provided. Must be a positive integer. - This parameter is ignored when ``heading_keypoints`` is provided. + ``body_axis_keypoints`` is not provided. Must be a positive integer. + This parameter is ignored when ``body_axis_keypoints`` is provided. return_angle : bool, default=False - If True, also return the mean heading angle in radians. + If True, also return the mean angle in radians. Returns the mean + body orientation angle when using ``body_axis_keypoints``, or the + mean heading angle when using displacement-based polarization. Returns ------- @@ -75,21 +77,23 @@ def compute_polarization( Zero-length headings are treated as invalid and excluded from the calculation. 
- The mean heading angle is defined from the summed unit-heading vector - projected onto the x/y plane. When no valid headings exist, or when the - summed heading vector has zero magnitude (for example exact cancellation), - the returned angle is NaN. + The mean angle is defined from the summed unit-heading vector projected + onto the x/y plane. When using ``body_axis_keypoints``, this represents + the mean body orientation; when using displacement, it represents the + mean movement direction. When no valid headings exist, or when the summed + heading vector has zero magnitude (for example exact cancellation), the + returned angle is NaN. Examples -------- - Compute polarization from keypoint-defined heading: + Compute orientation polarization from body-axis keypoints: >>> polarization = compute_polarization( ... ds.position, - ... heading_keypoints=("tail", "nose"), + ... body_axis_keypoints=("tail_base", "neck"), ... ) - Compute polarization from displacement (select keypoint for tracking): + Compute heading polarization from displacement (select keypoint for tracking): >>> polarization = compute_polarization( ... ds.position.sel(keypoints="thorax") @@ -99,7 +103,22 @@ def compute_polarization( >>> polarization = compute_polarization(ds.position) - Return both polarization and mean angle: + Return orientation polarization with mean body angle: + + >>> polarization, mean_angle = compute_polarization( + ... ds.position, + ... body_axis_keypoints=("tail_base", "neck"), + ... return_angle=True, + ... ) + + Return heading polarization with mean movement angle: + + >>> polarization, mean_angle = compute_polarization( + ... ds.position.sel(keypoints="thorax"), + ... return_angle=True, + ... ) + + If multiple keypoints exist, first is used; also return mean angle: >>> polarization, mean_angle = compute_polarization( ... 
ds.position, @@ -110,13 +129,13 @@ def compute_polarization( _validate_type_data_array(data) normalized_keypoints = _validate_position_data( data=data, - heading_keypoints=heading_keypoints, + body_axis_keypoints=body_axis_keypoints, ) if normalized_keypoints is not None: heading_vectors = _compute_heading_from_keypoints( data=data, - heading_keypoints=normalized_keypoints, + body_axis_keypoints=normalized_keypoints, ) else: heading_vectors = _compute_heading_from_velocity( @@ -158,10 +177,10 @@ def compute_polarization( def _compute_heading_from_keypoints( data: xr.DataArray, - heading_keypoints: tuple[Hashable, Hashable], + body_axis_keypoints: tuple[Hashable, Hashable], ) -> xr.DataArray: """Compute heading vectors from two keypoints (origin to target).""" - origin, target = heading_keypoints + origin, target = body_axis_keypoints heading = data.sel(keypoints=target, drop=True) - data.sel( keypoints=origin, drop=True, @@ -205,9 +224,9 @@ def _select_xy(data: xr.DataArray) -> xr.DataArray: def _validate_position_data( data: xr.DataArray, - heading_keypoints: tuple[Hashable, Hashable] | None, + body_axis_keypoints: tuple[Hashable, Hashable] | None, ) -> tuple[Hashable, Hashable] | None: - """Validate the input array and normalize ``heading_keypoints``.""" + """Validate the input array and normalize ``body_axis_keypoints``.""" validate_dims_coords( data, { @@ -236,45 +255,45 @@ def _validate_position_data( "data.space must include coordinate labels 'x' and 'y'." ) - if heading_keypoints is None: + if body_axis_keypoints is None: return None - origin, target = _normalize_heading_keypoints(heading_keypoints) + origin, target = _normalize_body_axis_keypoints(body_axis_keypoints) if "keypoints" not in data.dims: raise ValueError( - "heading_keypoints requires data to have a 'keypoints' dimension." + "body_axis_keypoints requires data to have a 'keypoints' dimension." 
) validate_dims_coords(data, {"keypoints": [origin, target]}) return origin, target -def _normalize_heading_keypoints( - heading_keypoints: tuple[Hashable, Hashable] | Any, +def _normalize_body_axis_keypoints( + body_axis_keypoints: tuple[Hashable, Hashable] | Any, ) -> tuple[Hashable, Hashable]: """Validate and normalize the keypoint pair.""" - if isinstance(heading_keypoints, (str, bytes)): + if isinstance(body_axis_keypoints, (str, bytes)): raise TypeError( - "heading_keypoints must be an iterable of exactly two " + "body_axis_keypoints must be an iterable of exactly two " "keypoint names." ) try: - origin, target = heading_keypoints + origin, target = body_axis_keypoints except (TypeError, ValueError) as exc: raise TypeError( - "heading_keypoints must be an iterable of exactly two " + "body_axis_keypoints must be an iterable of exactly two " "keypoint names." ) from exc for keypoint in (origin, target): if not isinstance(keypoint, Hashable): - raise TypeError("Each heading keypoint must be hashable.") + raise TypeError("Each body axis keypoint must be hashable.") if origin == target: raise ValueError( - "heading_keypoints must contain two distinct keypoint names." + "body_axis_keypoints must contain two distinct keypoint names." 
) return origin, target diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py index 96c062836..eefe28b83 100644 --- a/tests/test_unit/test_kinematics/test_collective.py +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -127,7 +127,7 @@ def perpendicular_positions() -> xr.DataArray: @pytest.fixture def keypoint_positions() -> xr.DataArray: - """Two individuals with tail/nose keypoints, both facing +x.""" + """Two individuals with tail_base/neck keypoints, both facing +x.""" data = np.array( [ [ @@ -145,7 +145,7 @@ def keypoint_positions() -> xr.DataArray: ], dtype=float, ) - return _make_position_dataarray(data, keypoints=["tail", "nose"]) + return _make_position_dataarray(data, keypoints=["tail_base", "neck"]) class TestComputePolarizationValidation: @@ -203,61 +203,61 @@ def test_requires_x_and_y_space_labels(self): kinematics.compute_polarization(data) @pytest.mark.parametrize( - "heading_keypoints", + "body_axis_keypoints", [ - "nose", - ("tail",), - ("tail", "nose", "ear"), + "neck", + ("tail_base",), + ("tail_base", "neck", "ear"), 123, ], ids=["string", "length_one", "length_three", "non_iterable"], ) - def test_heading_keypoints_must_be_length_two_iterable( + def test_body_axis_keypoints_must_be_length_two_iterable( self, - heading_keypoints, + body_axis_keypoints, keypoint_positions, ): - """Raise TypeError if heading_keypoints is not length-two.""" + """Raise TypeError if body_axis_keypoints is not length-two.""" with pytest.raises(TypeError, match="exactly two keypoint names"): kinematics.compute_polarization( keypoint_positions, - heading_keypoints=heading_keypoints, + body_axis_keypoints=body_axis_keypoints, ) - def test_heading_keypoints_must_be_hashable(self, keypoint_positions): - """Raise TypeError if heading keypoints are not hashable.""" + def test_body_axis_keypoints_must_be_hashable(self, keypoint_positions): + """Raise TypeError if body axis keypoints are not hashable.""" 
with pytest.raises(TypeError, match="hashable"): kinematics.compute_polarization( keypoint_positions, - heading_keypoints=(["tail"], "nose"), + body_axis_keypoints=(["tail_base"], "neck"), ) - def test_heading_keypoints_require_keypoints_dimension( + def test_body_axis_keypoints_require_keypoints_dimension( self, aligned_positions ): - """Raise ValueError if heading_keypoints given but no keypoints dim.""" + """Raise ValueError if body_axis_keypoints given but no keypoints dim.""" with pytest.raises( ValueError, match="requires data to have a 'keypoints' dimension" ): kinematics.compute_polarization( aligned_positions, - heading_keypoints=("tail", "nose"), + body_axis_keypoints=("tail_base", "neck"), ) - def test_heading_keypoints_must_exist(self, keypoint_positions): + def test_body_axis_keypoints_must_exist(self, keypoint_positions): """Raise ValueError if specified keypoints do not exist in data.""" with pytest.raises(ValueError, match="snout|keypoints"): kinematics.compute_polarization( keypoint_positions, - heading_keypoints=("tail", "snout"), + body_axis_keypoints=("tail_base", "snout"), ) - def test_heading_keypoints_must_be_distinct(self, keypoint_positions): + def test_body_axis_keypoints_must_be_distinct(self, keypoint_positions): """Raise ValueError if origin and target keypoints are identical.""" with pytest.raises(ValueError, match="two distinct keypoint names"): kinematics.compute_polarization( keypoint_positions, - heading_keypoints=("tail", "tail"), + body_axis_keypoints=("tail_base", "tail_base"), ) @pytest.mark.parametrize( @@ -290,7 +290,7 @@ def test_invalid_displacement_frames_is_ignored_in_keypoint_mode( """Invalid displacement_frames is ignored when keypoints are used.""" polarization = kinematics.compute_polarization( keypoint_positions, - heading_keypoints=("tail", "nose"), + body_axis_keypoints=("tail_base", "neck"), displacement_frames=0, ) assert np.allclose(polarization.values, 1.0, atol=1e-10) @@ -414,7 +414,7 @@ def 
test_one_coordinate_nan_excludes_that_individual(self): assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) assert np.allclose(mean_angle.values[1:], 0.0, atol=1e-10) - def test_nan_in_keypoint_heading_excludes_that_individual(self): + def test_nan_in_body_axis_heading_excludes_that_individual(self): """NaN in keypoint position excludes that individual.""" data = np.array( [ @@ -433,10 +433,10 @@ def test_nan_in_keypoint_heading_excludes_that_individual(self): ], dtype=float, ) - da = _make_position_dataarray(data, keypoints=["tail", "nose"]) + da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) polarization = kinematics.compute_polarization( da, - heading_keypoints=("tail", "nose"), + body_axis_keypoints=("tail_base", "neck"), ) assert np.allclose(polarization.values[[0, 2]], 1.0, atol=1e-10) assert np.allclose(polarization.values[1], 1.0, atol=1e-10) @@ -508,13 +508,13 @@ def test_polarization_is_invariant_to_individual_order(self): class TestHeadingSourceSelection: """Tests for heading computation mode selection.""" - def test_keypoint_heading_is_valid_on_first_frame( + def test_body_axis_heading_is_valid_on_first_frame( self, keypoint_positions ): - """Keypoint-based heading produces valid values on first frame.""" + """Body-axis heading produces valid values on first frame.""" polarization, mean_angle = kinematics.compute_polarization( keypoint_positions, - heading_keypoints=("tail", "nose"), + body_axis_keypoints=("tail_base", "neck"), return_angle=True, ) assert np.allclose(polarization.values, 1.0, atol=1e-10) @@ -539,7 +539,7 @@ def test_displacement_mode_with_keypoints_uses_first_keypoint(self): ], dtype=float, ) - da = _make_position_dataarray(data, keypoints=["centroid", "nose"]) + da = _make_position_dataarray(data, keypoints=["thorax", "head"]) polarization = kinematics.compute_polarization(da) assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) @@ -583,8 +583,8 @@ def test_explicit_keypoint_selection_with_sel(self): 
pol_head = kinematics.compute_polarization(da.sel(keypoints="head")) assert np.allclose(pol_head.values[1], 0.0, atol=1e-10) - def test_keypoint_heading_overrides_displacement_behavior(self): - """Keypoint-based heading overrides displacement computation.""" + def test_body_axis_heading_overrides_displacement_behavior(self): + """Body-axis heading overrides displacement computation.""" data = np.array( [ [ @@ -598,10 +598,10 @@ def test_keypoint_heading_overrides_displacement_behavior(self): ], dtype=float, ) - da = _make_position_dataarray(data, keypoints=["tail", "nose"]) + da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) polarization = kinematics.compute_polarization( da, - heading_keypoints=("tail", "nose"), + body_axis_keypoints=("tail_base", "neck"), displacement_frames=1000, ) assert np.allclose(polarization.values, 1.0, atol=1e-10) @@ -817,11 +817,11 @@ def test_mean_angle_is_nan_when_net_vector_cancels( assert np.all(np.isnan(angle_opposite.values[1:])) assert np.all(np.isnan(angle_perp.values[1:])) - def test_mean_angle_with_keypoint_heading(self, keypoint_positions): - """Mean angle works correctly with keypoint-based heading.""" + def test_mean_angle_with_body_axis_heading(self, keypoint_positions): + """Mean angle works correctly with body-axis heading.""" polarization, mean_angle = kinematics.compute_polarization( keypoint_positions, - heading_keypoints=("tail", "nose"), + body_axis_keypoints=("tail_base", "neck"), return_angle=True, ) assert np.allclose(polarization.values, 1.0, atol=1e-10) From 6efe00ee960206e8876c6df3c083bda41d2dec75 Mon Sep 17 00:00:00 2001 From: khan-u Date: Wed, 18 Mar 2026 17:00:46 -0700 Subject: [PATCH 10/21] linting fix --- movement/kinematics/collective.py | 4 ++-- tests/test_unit/test_kinematics/test_collective.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/movement/kinematics/collective.py b/movement/kinematics/collective.py index c8b9c06dd..7858eebc9 100644 --- 
a/movement/kinematics/collective.py +++ b/movement/kinematics/collective.py @@ -93,7 +93,7 @@ def compute_polarization( ... body_axis_keypoints=("tail_base", "neck"), ... ) - Compute heading polarization from displacement (select keypoint for tracking): + Compute heading polarization from displacement (pre-select keypoint): >>> polarization = compute_polarization( ... ds.position.sel(keypoints="thorax") @@ -262,7 +262,7 @@ def _validate_position_data( if "keypoints" not in data.dims: raise ValueError( - "body_axis_keypoints requires data to have a 'keypoints' dimension." + "body_axis_keypoints requires a 'keypoints' dimension in data." ) validate_dims_coords(data, {"keypoints": [origin, target]}) diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py index eefe28b83..52192fb77 100644 --- a/tests/test_unit/test_kinematics/test_collective.py +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -235,9 +235,9 @@ def test_body_axis_keypoints_must_be_hashable(self, keypoint_positions): def test_body_axis_keypoints_require_keypoints_dimension( self, aligned_positions ): - """Raise ValueError if body_axis_keypoints given but no keypoints dim.""" + """Raise ValueError if body_axis_keypoints given without keypoints.""" with pytest.raises( - ValueError, match="requires data to have a 'keypoints' dimension" + ValueError, match="requires a 'keypoints' dimension" ): kinematics.compute_polarization( aligned_positions, From 8182d0538ad779c5051a64347b27bb000c033112 Mon Sep 17 00:00:00 2001 From: khan-u Date: Wed, 18 Mar 2026 17:40:32 -0700 Subject: [PATCH 11/21] docs(polarization): use neutral u_hat notation for unit direction vector --- movement/kinematics/collective.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/movement/kinematics/collective.py b/movement/kinematics/collective.py index 7858eebc9..8f8e74469 100644 --- a/movement/kinematics/collective.py +++ 
b/movement/kinematics/collective.py @@ -30,9 +30,9 @@ def compute_polarization( .. math:: - \Phi = \frac{1}{N} \left\| \sum_{i=1}^{N} \hat{v}_i \right\| + \Phi = \frac{1}{N} \left\| \sum_{i=1}^{N} \hat{u}_i \right\| - where :math:`\hat{v}_i` is the unit heading vector for individual + where :math:`\hat{u}_i` is the unit direction vector for individual :math:`i`, and :math:`N` is the number of valid individuals at that time. Parameters From 96ef674f8ae4a138d1ae900b942b89ae91488972 Mon Sep 17 00:00:00 2001 From: khan-u Date: Wed, 18 Mar 2026 19:27:17 -0700 Subject: [PATCH 12/21] test(collective): add mathematical invariance, edge case, & validation for polarization --- .../test_kinematics/test_collective.py | 325 ++++++++++++++++-- 1 file changed, 300 insertions(+), 25 deletions(-) diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py index 52192fb77..d07b4da20 100644 --- a/tests/test_unit/test_kinematics/test_collective.py +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -295,6 +295,39 @@ def test_invalid_displacement_frames_is_ignored_in_keypoint_mode( ) assert np.allclose(polarization.values, 1.0, atol=1e-10) + def test_requires_space_coordinate_labels_to_exist(self): + """Raise ValueError if the space dimension has no coordinate labels.""" + data = xr.DataArray( + np.zeros((3, 2, 2)), + dims=["time", "space", "individuals"], + coords={ + "time": [0, 1, 2], + "individuals": ["a", "b"], + }, + name="position", + ) + with pytest.raises( + ValueError, + match="coordinate labels for the 'space' dimension", + ): + kinematics.compute_polarization(data) + + def test_empty_keypoints_dimension_raises_in_displacement_mode(self): + """Raise if keypoints dimension exists but contains no entries.""" + data = xr.DataArray( + np.empty((3, 2, 0, 2)), + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": [0, 1, 2], + "space": ["x", "y"], + "keypoints": [], + "individuals": ["a", "b"], + 
}, + name="position", + ) + with pytest.raises(ValueError, match="at least one keypoint"): + kinematics.compute_polarization(data) + class TestComputePolarizationBehavior: """Tests for polarization computation behavior.""" @@ -344,23 +377,6 @@ def test_single_individual_gives_one(self): ) assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) - def test_large_n_aligned_gives_one(self): - """Polarization is 1.0 for 50 aligned individuals.""" - n_individuals = 50 - x_coords = np.arange(n_individuals, dtype=float) - data = np.array( - [ - [x_coords, np.zeros(n_individuals)], - [x_coords + 1, np.zeros(n_individuals)], - [x_coords + 2, np.zeros(n_individuals)], - ], - dtype=float, - ) - polarization = kinematics.compute_polarization( - _make_position_dataarray(data) - ) - assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) - def test_stationary_individuals_are_excluded(self): """Stationary individuals produce NaN polarization and angle.""" data = np.array( @@ -504,14 +520,214 @@ def test_polarization_is_invariant_to_individual_order(self): pol_original.values, pol_permuted.values, atol=1e-10 ) + def test_zero_length_body_axis_vectors_are_excluded(self): + """Zero-length body-axis headings are excluded as invalid.""" + # ind0 has coincident tail_base and neck (zero-length heading) + # ind1 has valid +x body axis heading + data = np.array( + [ + [ + [[0.0, 10.0], [0.0, 11.0]], # x: ind0 zero-length, ind1 +1 + [[0.0, 0.0], [0.0, 0.0]], # y + ], + [ + [[0.0, 10.5], [0.0, 11.5]], + [[0.0, 0.0], [0.0, 0.0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) + + polarization, mean_angle = kinematics.compute_polarization( + da, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + + assert np.allclose(polarization.values, 1.0, atol=1e-10) + assert np.allclose(mean_angle.values, 0.0, atol=1e-10) + + def test_polarization_is_invariant_to_translation( + self, + partial_alignment_positions, + ): + 
"""Adding a constant offset does not change polarization.""" + shifted = partial_alignment_positions.copy() + shifted.loc[dict(space="x")] = shifted.sel(space="x") + 1000.0 + shifted.loc[dict(space="y")] = shifted.sel(space="y") - 500.0 + + pol_original = kinematics.compute_polarization( + partial_alignment_positions + ) + pol_shifted = kinematics.compute_polarization(shifted) + + np.testing.assert_allclose( + pol_original.values, + pol_shifted.values, + atol=1e-10, + equal_nan=True, + ) + + def test_polarization_is_invariant_to_positive_scaling( + self, + partial_alignment_positions, + ): + """Positive scalar multiplication preserves polarization.""" + scaled = partial_alignment_positions * 7.5 + + pol_original = kinematics.compute_polarization( + partial_alignment_positions + ) + pol_scaled = kinematics.compute_polarization(scaled) + + np.testing.assert_allclose( + pol_original.values, + pol_scaled.values, + atol=1e-10, + equal_nan=True, + ) + + def test_polarization_is_invariant_to_global_rotation( + self, + partial_alignment_positions, + ): + """A global planar rotation preserves polarization magnitude.""" + x = partial_alignment_positions.sel(space="x") + y = partial_alignment_positions.sel(space="y") + + rotated = partial_alignment_positions.copy() + rotated.loc[dict(space="x")] = -y + rotated.loc[dict(space="y")] = x + + pol_original = kinematics.compute_polarization( + partial_alignment_positions + ) + pol_rotated = kinematics.compute_polarization(rotated) + + np.testing.assert_allclose( + pol_original.values, + pol_rotated.values, + atol=1e-10, + equal_nan=True, + ) + + def test_body_axis_invariance_to_translation_scaling_rotation( + self, + ): + """Body-axis polarization is invariant to translation/scaling/rotation. + + Mean body angle is invariant to translation and positive scaling, and + rotates by the same amount under global planar rotation. + """ + # Three individuals with body axes: +x, +x, +y. 
+ # This gives a nontrivial baseline: + # vector sum = (2, 1) + # polarization = sqrt(5) / 3 + # mean angle = atan2(1, 2) + # + # Absolute positions differ across frames to ensure we are really + # testing body-axis heading (target - origin), not any accidental + # dependence on absolute location. + data = np.array( + [ + [ + [[0.0, 10.0, -2.0], [1.0, 11.0, -2.0]], # x + [[0.0, 5.0, 3.0], [0.0, 5.0, 4.0]], # y + ], + [ + [[100.0, 50.0, 7.0], [101.0, 51.0, 7.0]], # x + [[-1.0, 20.0, -3.0], [-1.0, 20.0, -2.0]], # y + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) + + pol_base, angle_base = kinematics.compute_polarization( + da, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + + expected_pol = np.sqrt(5) / 3 + expected_angle = np.arctan2(1.0, 2.0) + + np.testing.assert_allclose(pol_base.values, expected_pol, atol=1e-10) + np.testing.assert_allclose( + angle_base.values, expected_angle, atol=1e-10 + ) + + # Global translation: should not affect body-axis vectors. + translated = da.copy() + translated.loc[dict(space="x")] = translated.sel(space="x") + 123.4 + translated.loc[dict(space="y")] = translated.sel(space="y") - 56.7 + + pol_translated, angle_translated = kinematics.compute_polarization( + translated, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + + np.testing.assert_allclose( + pol_translated.values, pol_base.values, atol=1e-10 + ) + np.testing.assert_allclose( + angle_translated.values, angle_base.values, atol=1e-10 + ) + + # Positive scaling: should preserve directions and therefore preserve + # polarization and angle. 
+ scaled = da * 4.2 + + pol_scaled, angle_scaled = kinematics.compute_polarization( + scaled, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + + np.testing.assert_allclose( + pol_scaled.values, pol_base.values, atol=1e-10 + ) + np.testing.assert_allclose( + angle_scaled.values, angle_base.values, atol=1e-10 + ) + + # Global 90-degree rotation: polarization magnitude should be + # unchanged, and mean angle should rotate by +pi/2 (with wraparound). + rotated = da.copy() + x = da.sel(space="x") + y = da.sel(space="y") + rotated.loc[dict(space="x")] = -y + rotated.loc[dict(space="y")] = x + + pol_rotated, angle_rotated = kinematics.compute_polarization( + rotated, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + + np.testing.assert_allclose( + pol_rotated.values, pol_base.values, atol=1e-10 + ) + + expected_rotated_angle = angle_base.values + (np.pi / 2) + expected_rotated_angle = ( + (expected_rotated_angle + np.pi) % (2 * np.pi) + ) - np.pi + + np.testing.assert_allclose( + angle_rotated.values, expected_rotated_angle, atol=1e-10 + ) + class TestHeadingSourceSelection: """Tests for heading computation mode selection.""" - def test_body_axis_heading_is_valid_on_first_frame( + def test_body_axis_heading_valid_on_first_frame_returns_expected_angle( self, keypoint_positions ): - """Body-axis heading produces valid values on first frame.""" + """Body-axis heading is valid from frame 0 and returns angle 0.""" polarization, mean_angle = kinematics.compute_polarization( keypoint_positions, body_axis_keypoints=("tail_base", "neck"), @@ -688,6 +904,19 @@ def test_larger_displacement_window_can_change_alignment_estimate(self): assert np.allclose(pol_1frame.values[1:], 0.0, atol=1e-10) assert np.allclose(pol_2frame.values[2:], 1.0, atol=1e-10) + def test_displacement_frames_larger_than_time_axis_returns_all_nan( + self, + aligned_positions, + ): + """Oversized displacement windows produce no valid headings.""" + polarization, 
mean_angle = kinematics.compute_polarization( + aligned_positions, + displacement_frames=10, + return_angle=True, + ) + assert np.all(np.isnan(polarization.values)) + assert np.all(np.isnan(mean_angle.values)) + class TestReturnAngle: """Tests for return_angle parameter behavior.""" @@ -817,12 +1046,58 @@ def test_mean_angle_is_nan_when_net_vector_cancels( assert np.all(np.isnan(angle_opposite.values[1:])) assert np.all(np.isnan(angle_perp.values[1:])) - def test_mean_angle_with_body_axis_heading(self, keypoint_positions): - """Mean angle works correctly with body-axis heading.""" + def test_mean_angle_rotates_with_global_rotation( + self, + partial_alignment_positions, + ): + """Mean angle shifts by the same amount under global rotation.""" + _, angle_original = kinematics.compute_polarization( + partial_alignment_positions, + return_angle=True, + ) + + x = partial_alignment_positions.sel(space="x") + y = partial_alignment_positions.sel(space="y") + + rotated = partial_alignment_positions.copy() + rotated.loc[dict(space="x")] = -y + rotated.loc[dict(space="y")] = x + + _, angle_rotated = kinematics.compute_polarization( + rotated, + return_angle=True, + ) + + expected = angle_original.values[1:] + (np.pi / 2) + expected = (expected + np.pi) % (2 * np.pi) - np.pi + + np.testing.assert_allclose( + angle_rotated.values[1:], + expected, + atol=1e-10, + ) + + def test_mean_angle_wraparound_near_pi_is_handled_correctly(self): + """Headings near +pi and -pi should average leftward, not to zero.""" + # Two individuals moving left with tiny y-offsets in opposite dirs. + # This creates headings very close to +pi and -pi. 
+ data = np.array( + [ + [[0.0, 0.0], [0.0, 0.0]], + [[-1.0, -1.0], [1e-6, -1e-6]], + [[-2.0, -2.0], [2e-6, -2e-6]], + ], + dtype=float, + ) + polarization, mean_angle = kinematics.compute_polarization( - keypoint_positions, - body_axis_keypoints=("tail_base", "neck"), + _make_position_dataarray(data), return_angle=True, ) - assert np.allclose(polarization.values, 1.0, atol=1e-10) - assert np.allclose(mean_angle.values, 0.0, atol=1e-10) + + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + assert np.allclose( + np.abs(mean_angle.values[1:]), + np.pi, + atol=1e-6, + ) From 7d61af5ea1d4da73c09ca2438a22c7abbf219d40 Mon Sep 17 00:00:00 2001 From: khan-u Date: Wed, 18 Mar 2026 21:36:05 -0700 Subject: [PATCH 13/21] feat(polarization): add in_degrees parameter + unit test --- movement/kinematics/collective.py | 36 ++++++++++++++----- .../test_kinematics/test_collective.py | 25 +++++++++++++ 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/movement/kinematics/collective.py b/movement/kinematics/collective.py index 8f8e74469..8b0e06aa5 100644 --- a/movement/kinematics/collective.py +++ b/movement/kinematics/collective.py @@ -19,12 +19,15 @@ def compute_polarization( body_axis_keypoints: tuple[Hashable, Hashable] | None = None, displacement_frames: int = 1, return_angle: bool = False, + in_degrees: bool = False, ) -> xr.DataArray | tuple[xr.DataArray, xr.DataArray]: r"""Compute polarization (group alignment) of individuals. - Polarization measures how aligned the heading directions of individuals - are. A value of 1 indicates perfect alignment, while a value near 0 - indicates weak or canceling alignment. + Polarization measures how aligned individuals' direction vectors are, + supporting two modes: **orientation polarization** (body-axis mode) for + body orientation alignment, and **heading polarization** (displacement + mode) for movement direction alignment. 
A value of 1 indicates perfect + alignment, while a value near 0 indicates weak or canceling alignment. The polarization is computed as @@ -56,9 +59,13 @@ def compute_polarization( ``body_axis_keypoints`` is not provided. Must be a positive integer. This parameter is ignored when ``body_axis_keypoints`` is provided. return_angle : bool, default=False - If True, also return the mean angle in radians. Returns the mean - body orientation angle when using ``body_axis_keypoints``, or the - mean heading angle when using displacement-based polarization. + If True, also return the mean angle. Returns the mean body + orientation angle when using ``body_axis_keypoints``, or the mean + movement direction angle when using displacement-based polarization. + in_degrees : bool, default=False + If True, the mean angle is returned in degrees. Otherwise, the + angle is returned in radians. Only relevant when + ``return_angle=True``. Returns ------- @@ -103,7 +110,7 @@ def compute_polarization( >>> polarization = compute_polarization(ds.position) - Return orientation polarization with mean body angle: + Return orientation polarization with mean body orientation angle: >>> polarization, mean_angle = compute_polarization( ... ds.position, @@ -111,13 +118,21 @@ def compute_polarization( ... return_angle=True, ... ) - Return heading polarization with mean movement angle: + Return heading polarization with mean movement direction angle (radians): >>> polarization, mean_angle = compute_polarization( ... ds.position.sel(keypoints="thorax"), ... return_angle=True, ... ) + Return heading polarization with mean movement direction angle (degrees): + + >>> polarization, mean_angle = compute_polarization( + ... ds.position.sel(keypoints="thorax"), + ... return_angle=True, + ... in_degrees=True, + ... 
) + If multiple keypoints exist, first is used; also return mean angle: >>> polarization, mean_angle = compute_polarization( @@ -170,7 +185,10 @@ def compute_polarization( vector_sum.sel(space="x"), ), np.nan, - ).rename("mean_angle") + ) + if in_degrees: + mean_angle = np.rad2deg(mean_angle) + mean_angle = mean_angle.rename("mean_angle") return polarization, mean_angle diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py index d07b4da20..740f1c45c 100644 --- a/tests/test_unit/test_kinematics/test_collective.py +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -1101,3 +1101,28 @@ def test_mean_angle_wraparound_near_pi_is_handled_correctly(self): np.pi, atol=1e-6, ) + + def test_in_degrees_true_returns_degrees(self): + """in_degrees=True returns angle in degrees.""" + # Two individuals moving in +y direction + data = np.array( + [ + [[0, 0], [0, 0]], + [[0, 0], [1, 1]], + [[0, 0], [2, 2]], + ], + dtype=float, + ) + _, mean_angle_rad = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + in_degrees=False, + ) + _, mean_angle_deg = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + in_degrees=True, + ) + # +y direction = 90 degrees = pi/2 radians + assert np.allclose(mean_angle_rad.values[1:], np.pi / 2, atol=1e-10) + assert np.allclose(mean_angle_deg.values[1:], 90.0, atol=1e-10) From 9313b9da4d1fdae1b792cabda944629b3f130e16 Mon Sep 17 00:00:00 2001 From: khan-u Date: Thu, 19 Mar 2026 18:19:22 -0700 Subject: [PATCH 14/21] SonarCloud warning fixes --- .../test_kinematics/test_collective.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py index 740f1c45c..8a785d115 100644 --- a/tests/test_unit/test_kinematics/test_collective.py +++ 
b/tests/test_unit/test_kinematics/test_collective.py @@ -554,8 +554,8 @@ def test_polarization_is_invariant_to_translation( ): """Adding a constant offset does not change polarization.""" shifted = partial_alignment_positions.copy() - shifted.loc[dict(space="x")] = shifted.sel(space="x") + 1000.0 - shifted.loc[dict(space="y")] = shifted.sel(space="y") - 500.0 + shifted.loc[{"space": "x"}] = shifted.sel(space="x") + 1000.0 + shifted.loc[{"space": "y"}] = shifted.sel(space="y") - 500.0 pol_original = kinematics.compute_polarization( partial_alignment_positions @@ -597,8 +597,8 @@ def test_polarization_is_invariant_to_global_rotation( y = partial_alignment_positions.sel(space="y") rotated = partial_alignment_positions.copy() - rotated.loc[dict(space="x")] = -y - rotated.loc[dict(space="y")] = x + rotated.loc[{"space": "x"}] = -y + rotated.loc[{"space": "y"}] = x pol_original = kinematics.compute_polarization( partial_alignment_positions @@ -660,8 +660,8 @@ def test_body_axis_invariance_to_translation_scaling_rotation( # Global translation: should not affect body-axis vectors. 
translated = da.copy() - translated.loc[dict(space="x")] = translated.sel(space="x") + 123.4 - translated.loc[dict(space="y")] = translated.sel(space="y") - 56.7 + translated.loc[{"space": "x"}] = translated.sel(space="x") + 123.4 + translated.loc[{"space": "y"}] = translated.sel(space="y") - 56.7 pol_translated, angle_translated = kinematics.compute_polarization( translated, @@ -698,8 +698,8 @@ def test_body_axis_invariance_to_translation_scaling_rotation( rotated = da.copy() x = da.sel(space="x") y = da.sel(space="y") - rotated.loc[dict(space="x")] = -y - rotated.loc[dict(space="y")] = x + rotated.loc[{"space": "x"}] = -y + rotated.loc[{"space": "y"}] = x pol_rotated, angle_rotated = kinematics.compute_polarization( rotated, @@ -1060,8 +1060,8 @@ def test_mean_angle_rotates_with_global_rotation( y = partial_alignment_positions.sel(space="y") rotated = partial_alignment_positions.copy() - rotated.loc[dict(space="x")] = -y - rotated.loc[dict(space="y")] = x + rotated.loc[{"space": "x"}] = -y + rotated.loc[{"space": "y"}] = x _, angle_rotated = kinematics.compute_polarization( rotated, From 6dd7fbb028691e681f6e9a8645860e8640cb1897 Mon Sep 17 00:00:00 2001 From: khan-u Date: Fri, 3 Apr 2026 18:21:01 -0700 Subject: [PATCH 15/21] refactor(collective): use more vector.py utilities for polarization computation --- movement/kinematics/collective.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/movement/kinematics/collective.py b/movement/kinematics/collective.py index 8b0e06aa5..1977eeb4d 100644 --- a/movement/kinematics/collective.py +++ b/movement/kinematics/collective.py @@ -8,7 +8,11 @@ import xarray as xr from movement.utils.logging import logger -from movement.utils.vector import compute_norm +from movement.utils.vector import ( + compute_norm, + compute_signed_angle_2d, + convert_to_unit, +) from movement.validators.arrays import validate_dims_coords _ANGLE_EPS = 1e-12 @@ -158,11 +162,10 @@ def 
compute_polarization( displacement_frames=displacement_frames, ) - heading_xy = _select_xy(heading_vectors) - norm = compute_norm(heading_xy) - valid_mask = (~heading_xy.isnull().any(dim="space")) & (norm > 0) + heading = _select_space(heading_vectors) - unit_headings = (heading_xy / norm).where(valid_mask) + unit_headings = convert_to_unit(heading) + valid_mask = ~unit_headings.isnull().any(dim="space") vector_sum = unit_headings.sum(dim="individuals", skipna=True) sum_magnitude = compute_norm(vector_sum) n_valid = valid_mask.sum(dim="individuals") @@ -177,12 +180,16 @@ def compute_polarization( if not return_angle: return polarization + # Normalize vector_sum to unit vector for angle computation + mean_unit_vector = vector_sum / sum_magnitude + + # Compute angle from positive x-axis to mean unit vector + reference = np.array([1, 0]) angle_defined = (n_valid > 0) & (sum_magnitude > _ANGLE_EPS) mean_angle = xr.where( angle_defined, - np.arctan2( - vector_sum.sel(space="y"), - vector_sum.sel(space="x"), + compute_signed_angle_2d( + mean_unit_vector, reference, v_as_left_operand=True ), np.nan, ) @@ -235,9 +242,9 @@ def _compute_heading_from_velocity( return displacement -def _select_xy(data: xr.DataArray) -> xr.DataArray: - """Select the planar x/y components and return standard dim order.""" - return data.sel(space=["x", "y"]).transpose("time", "space", "individuals") +def _select_space(data: xr.DataArray) -> xr.DataArray: + """Return data with standard dim order, preserving all spatial coords.""" + return data.transpose("time", "space", "individuals") def _validate_position_data( From 51866c9b75f3cb52a37ed45f6c9c17d248eecc48 Mon Sep 17 00:00:00 2001 From: khan-u Date: Thu, 2 Apr 2026 01:58:25 -0700 Subject: [PATCH 16/21] feat(collective): add prior-free body-axis inference --- movement/kinematics/collective.py | 2921 ++++++++++++++++- .../test_kinematics/test_collective.py | 249 +- 2 files changed, 3060 insertions(+), 110 deletions(-) diff --git 
a/movement/kinematics/collective.py b/movement/kinematics/collective.py index 1977eeb4d..fa136bedb 100644 --- a/movement/kinematics/collective.py +++ b/movement/kinematics/collective.py @@ -2,6 +2,7 @@ """Compute collective behavior metrics for multi-individual tracking data.""" from collections.abc import Hashable +from dataclasses import dataclass, field from typing import Any import numpy as np @@ -18,12 +19,266 @@ _ANGLE_EPS = 1e-12 +@dataclass +class _ValidateAPConfig: + """Configuration for the _validate_ap function. + + Parameters + ---------- + min_valid_frac : float, default=0.6 + Minimum fraction of keypoints that must be present for a frame + to qualify as tier-1 valid. + window_len : int, default=50 + Number of speed samples per sliding window. + stride : int, default=5 + Step size between consecutive sliding window start positions. + pct_thresh : float, default=85.0 + Percentile threshold applied to valid-window median speeds for + high-motion classification. + min_run_len : int, default=1 + Minimum number of consecutive qualifying windows required to + form a valid run. + postural_var_ratio_thresh : float, default=2.0 + Between-segment to within-segment RMSD variance ratio above which + postural clustering is triggered. + max_clusters : int, default=4 + Upper bound on the number of clusters to evaluate during k-medoids. + confidence_floor : float, default=0.1 + Vote margin below which the anterior inference is flagged as + unreliable. + lateral_thresh : float, default=0.4 + Normalized lateral offset ceiling for the Step 1 lateral alignment + filter. + edge_thresh : float, default=0.3 + Normalized midpoint distance floor for the Step 3 distal/proximal + classification. 
+ + """ + + min_valid_frac: float = 0.6 + window_len: int = 50 + stride: int = 5 + pct_thresh: float = 85.0 + min_run_len: int = 1 + postural_var_ratio_thresh: float = 2.0 + max_clusters: int = 4 + confidence_floor: float = 0.1 + lateral_thresh: float = 0.4 + edge_thresh: float = 0.3 + + def __post_init__(self) -> None: + """Validate configuration parameters.""" + # Validate fraction parameters (must be in [0, 1]) + for name in ( + "min_valid_frac", + "confidence_floor", + "lateral_thresh", + "edge_thresh", + ): + value = getattr(self, name) + if not (0 <= value <= 1): + raise ValueError( + f"{name} must be between 0 and 1, got {value}" + ) + + # Validate positive integer parameters + for name in ("window_len", "stride", "min_run_len", "max_clusters"): + value = getattr(self, name) + if not isinstance(value, int) or value <= 0: + raise ValueError( + f"{name} must be a positive integer, got {value}" + ) + + # Validate pct_thresh (must be in [0, 100]) + if not (0 <= self.pct_thresh <= 100): + raise ValueError( + f"pct_thresh must be between 0 and 100, got {self.pct_thresh}" + ) + + # Validate postural_var_ratio_thresh (must be positive) + if self.postural_var_ratio_thresh <= 0: + raise ValueError( + f"postural_var_ratio_thresh must be positive, " + f"got {self.postural_var_ratio_thresh}" + ) + + +@dataclass +class _FrameSelection: + """Selected frames from high-motion segmentation and tier-2 filtering. + + Bundles the frame indices, segment assignments, and related arrays + produced by the segmentation pipeline for downstream consumption + (skeleton construction, postural clustering, velocity recomputation). + + Attributes + ---------- + frames : np.ndarray + Array of selected frame indices (tier-2 valid, within segments). + seg_ids : np.ndarray + Segment ID (0-indexed) for each selected frame. + segments : np.ndarray + Array of shape (n_segments, 2) with [frame_start, frame_end]. 
+ bbox_centroids : np.ndarray + Array of shape (n_frames, 2) with bounding-box centroids. + count : int + Number of selected frames. + + """ + + frames: np.ndarray + seg_ids: np.ndarray + segments: np.ndarray + bbox_centroids: np.ndarray + count: int + + +@dataclass +class _APNodePairReport: + """Report from the AP node-pair evaluation pipeline. + + This dataclass holds all results from the 3-step filter cascade + used to evaluate a candidate anterior-posterior keypoint pair. + + Attributes + ---------- + success : bool + Whether the evaluation pipeline completed successfully. + failure_step : str + Name of the step at which evaluation failed, if any. + failure_reason : str + Reason for failure, if any. + scenario : int + Scenario number (1-13) from the mutually exclusive outcomes. + outcome : str + Either "accept" or "warn". + warning_message : str + Warning message, if applicable. + sorted_candidate_nodes : np.ndarray + Indices of candidate nodes after Step 1 filtering, sorted by + ascending normalized lateral offset. + valid_pairs : np.ndarray + Array of shape (n_pairs, 2) containing valid node pairs after + Step 2 filtering. + valid_pairs_internode_dist : np.ndarray + Internode separation (AP distance) for each valid pair. + input_pair_in_candidates : bool + Whether the input pair survived Step 1 filtering. + input_pair_opposite_sides : bool + Whether the input pair lies on opposite sides of the midpoint. + input_pair_separation_abs : float + Absolute AP separation of the input pair. + input_pair_is_distal : bool + Whether the input pair is classified as distal in Step 3. + input_pair_rank : int + Rank of the input pair by internode separation (1 = largest). + input_pair_order_matches_inference : bool + Whether from_node has a lower AP coordinate than to_node + (i.e. from_node is more posterior). True means the input pair + ordering is consistent with the inferred AP axis. + pc1_coords : np.ndarray + PC1 coordinates for each keypoint. 
+ ap_coords : np.ndarray + AP (anterior-posterior) coordinates for each keypoint. + lateral_offsets : np.ndarray + Unsigned lateral offset from body axis for each keypoint. + lateral_offsets_norm : np.ndarray + Normalized lateral offsets (0 = nearest to axis, 1 = farthest). + lateral_offset_min : float + Minimum lateral offset among valid keypoints. + lateral_offset_max : float + Maximum lateral offset among valid keypoints. + midpoint_pc1 : float + AP reference midpoint (average of min and max PC1 projections). + pc1_min : float + Minimum PC1 projection among valid keypoints. + pc1_max : float + Maximum PC1 projection among valid keypoints. + midline_dist_norm : np.ndarray + Normalized distance from midpoint for each keypoint. + midline_dist_max : float + Maximum absolute distance from midpoint. + distal_pairs : np.ndarray + Array of distal pairs (both nodes at or above edge_thresh). + proximal_pairs : np.ndarray + Array of proximal pairs (at least one node below edge_thresh). + max_separation_distal_nodes : np.ndarray + Node indices of the maximum-separation distal pair, ordered + so that element 0 is posterior (lower AP coord) and element 1 + is anterior (higher AP coord). + max_separation_distal : float + Internode separation of the max-separation distal pair. + max_separation_nodes : np.ndarray + Node indices of the overall maximum-separation pair, ordered + so that element 0 is posterior (lower AP coord) and element 1 + is anterior (higher AP coord). + max_separation : float + Internode separation of the overall max-separation pair. 
+ + """ + + success: bool = False + failure_step: str = "" + failure_reason: str = "" + scenario: int = 0 + outcome: str = "" + warning_message: str = "" + + sorted_candidate_nodes: np.ndarray = field( + default_factory=lambda: np.array([], dtype=int) + ) + valid_pairs: np.ndarray = field( + default_factory=lambda: np.zeros((0, 2), dtype=int) + ) + valid_pairs_internode_dist: np.ndarray = field( + default_factory=lambda: np.array([]) + ) + + input_pair_in_candidates: bool = False + input_pair_opposite_sides: bool = False + input_pair_separation_abs: float = np.nan + input_pair_is_distal: bool = False + input_pair_rank: int = 0 + input_pair_order_matches_inference: bool = False + + pc1_coords: np.ndarray = field(default_factory=lambda: np.array([])) + ap_coords: np.ndarray = field(default_factory=lambda: np.array([])) + lateral_offsets: np.ndarray = field(default_factory=lambda: np.array([])) + lateral_offsets_norm: np.ndarray = field( + default_factory=lambda: np.array([]) + ) + lateral_offset_min: float = np.nan + lateral_offset_max: float = np.nan + midpoint_pc1: float = np.nan + pc1_min: float = np.nan + pc1_max: float = np.nan + midline_dist_norm: np.ndarray = field(default_factory=lambda: np.array([])) + midline_dist_max: float = np.nan + + distal_pairs: np.ndarray = field( + default_factory=lambda: np.zeros((0, 2), dtype=int) + ) + proximal_pairs: np.ndarray = field( + default_factory=lambda: np.zeros((0, 2), dtype=int) + ) + max_separation_distal_nodes: np.ndarray = field( + default_factory=lambda: np.array([], dtype=int) + ) + max_separation_distal: float = np.nan + max_separation_nodes: np.ndarray = field( + default_factory=lambda: np.array([], dtype=int) + ) + max_separation: float = np.nan + + def compute_polarization( data: xr.DataArray, body_axis_keypoints: tuple[Hashable, Hashable] | None = None, displacement_frames: int = 1, return_angle: bool = False, in_degrees: bool = False, + validate_ap: bool = True, + ap_validation_config: dict[str, Any] | None 
= None, ) -> xr.DataArray | tuple[xr.DataArray, xr.DataArray]: r"""Compute polarization (group alignment) of individuals. @@ -70,6 +325,13 @@ def compute_polarization( If True, the mean angle is returned in degrees. Otherwise, the angle is returned in radians. Only relevant when ``return_angle=True``. + validate_ap : bool, default=True + If True, run anterior-posterior axis validation when + ``body_axis_keypoints`` is provided. Validation is skipped for + displacement-based polarization. + ap_validation_config : dict, optional + Configuration overrides for anterior-posterior axis validation. + Passed to ``_ValidateAPConfig`` when validation is enabled. Returns ------- @@ -95,6 +357,10 @@ def compute_polarization( heading vector has zero magnitude (for example exact cancellation), the returned angle is NaN. + When ``validate_ap=True`` and ``body_axis_keypoints`` is provided, + anterior-posterior validation is run per individual and the result is + stored in ``polarization.attrs["ap_validation_result"]``. + Examples -------- Compute orientation polarization from body-axis keypoints: @@ -144,6 +410,14 @@ def compute_polarization( ... return_angle=True, ... ) + Run AP validation while computing body-axis polarization: + + >>> polarization = compute_polarization( + ... ds.position, + ... body_axis_keypoints=("tail_base", "neck"), + ... validate_ap=True, + ... 
) + """ _validate_type_data_array(data) normalized_keypoints = _validate_position_data( @@ -151,7 +425,12 @@ def compute_polarization( body_axis_keypoints=body_axis_keypoints, ) + ap_validation_result = None if normalized_keypoints is not None: + if validate_ap: + ap_validation_result = _run_ap_validation( + data, normalized_keypoints, ap_validation_config + ) heading_vectors = _compute_heading_from_keypoints( data=data, body_axis_keypoints=normalized_keypoints, @@ -162,10 +441,26 @@ def compute_polarization( displacement_frames=displacement_frames, ) - heading = _select_space(heading_vectors) + polarization = _compute_polarization_from_headings(heading_vectors) + if ap_validation_result is not None: + polarization.attrs["ap_validation_result"] = ap_validation_result + + if not return_angle: + return polarization + + mean_angle = _compute_mean_angle(heading_vectors, in_degrees) + return polarization, mean_angle + + +def _compute_polarization_from_headings( + heading_vectors: xr.DataArray, +) -> xr.DataArray: + """Compute polarization magnitude from heading vectors.""" + heading = _select_space(heading_vectors) unit_headings = convert_to_unit(heading) valid_mask = ~unit_headings.isnull().any(dim="space") + vector_sum = unit_headings.sum(dim="individuals", skipna=True) sum_magnitude = compute_norm(vector_sum) n_valid = valid_mask.sum(dim="individuals") @@ -175,15 +470,23 @@ def compute_polarization( sum_magnitude / n_valid, np.nan, ).clip(min=0.0, max=1.0) - polarization = polarization.rename("polarization") + return polarization.rename("polarization") - if not return_angle: - return polarization - # Normalize vector_sum to unit vector for angle computation - mean_unit_vector = vector_sum / sum_magnitude +def _compute_mean_angle( + heading_vectors: xr.DataArray, + in_degrees: bool = False, +) -> xr.DataArray: + """Compute mean heading angle from heading vectors.""" + heading = _select_space(heading_vectors) + unit_headings = convert_to_unit(heading) + valid_mask = 
~unit_headings.isnull().any(dim="space") + + vector_sum = unit_headings.sum(dim="individuals", skipna=True) + sum_magnitude = compute_norm(vector_sum) + n_valid = valid_mask.sum(dim="individuals") - # Compute angle from positive x-axis to mean unit vector + mean_unit_vector = vector_sum / sum_magnitude reference = np.array([1, 0]) angle_defined = (n_valid > 0) & (sum_magnitude > _ANGLE_EPS) mean_angle = xr.where( @@ -195,9 +498,65 @@ def compute_polarization( ) if in_degrees: mean_angle = np.rad2deg(mean_angle) - mean_angle = mean_angle.rename("mean_angle") + return mean_angle.rename("mean_angle") + +def _run_ap_validation( + data: xr.DataArray, + normalized_keypoints: tuple[Hashable, Hashable], + ap_validation_config: dict[str, Any] | None, +) -> dict: + """Run AP validation across all individuals, select best by R*M. + + Each individual is validated independently using the supplied keypoint + pair. R*M (resultant_length × vote_margin) is computed per individual + and depends only on the individual's motion and body shape, not on + the input pair. The best individual is the one with the highest R*M. 
+ """ + config = ( + _ValidateAPConfig(**ap_validation_config) + if ap_validation_config is not None + else None + ) + + if "individuals" not in data.dims: + single_result = _validate_ap( + data, + from_node=normalized_keypoints[0], + to_node=normalized_keypoints[1], + config=config, + verbose=False, + ) + return {"all_results": [single_result], "best_idx": 0} + + individuals = list(data.coords["individuals"].values) + all_results = [] + for individual in individuals: + result = _validate_ap( + data.sel(individuals=individual), + from_node=normalized_keypoints[0], + to_node=normalized_keypoints[1], + config=config, + verbose=False, + ) + result["individual"] = individual + all_results.append(result) + + best_idx = _find_best_individual_by_rxm(all_results) + return {"all_results": all_results, "best_idx": best_idx} - return polarization, mean_angle + +def _find_best_individual_by_rxm(all_results: list[dict]) -> int: + """Return index of the individual with highest R*M score.""" + best_idx = -1 + best_rxm = -1.0 + for i, result in enumerate(all_results): + if not result["success"]: + continue + rxm = result["resultant_length"] * result["vote_margin"] + if rxm > best_rxm: + best_rxm = rxm + best_idx = i + return best_idx def _compute_heading_from_keypoints( @@ -342,3 +701,2547 @@ def _validate_type_data_array(data: xr.DataArray) -> None: raise TypeError( f"Input data must be an xarray.DataArray, but got {type(data)}." ) + + +# Helper functions for _validate_ap + + +def _compute_tiered_validity( + keypoints: np.ndarray, + min_valid_frac: float, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute tiered validity masks for each frame. + + Parameters + ---------- + keypoints : np.ndarray + Keypoint positions with shape (n_frames, n_keypoints, 2). + min_valid_frac : float + Minimum fraction of keypoints required for tier-1 validity. + + Returns + ------- + tier1_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-1 valid frames. 
+ A frame is tier-1 valid if at least min_valid_frac of keypoints + are present AND at least 2 keypoints are present. + tier2_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-2 valid frames. + A frame is tier-2 valid if all keypoints are present. + frac_present : np.ndarray + Array of shape (n_frames,) with fraction of keypoints present. + + """ + n_frames, n_keypoints, _ = keypoints.shape + + # A keypoint is present if neither x nor y is NaN + # Shape: (n_frames, n_keypoints) + keypoint_present = ~np.any(np.isnan(keypoints), axis=2) + + # Count present keypoints per frame + n_present = np.sum(keypoint_present, axis=1) + frac_present = n_present / n_keypoints + + # Tier-2: all keypoints present + tier2_valid = n_present == n_keypoints + + # Tier-1: at least min_valid_frac present AND at least 2 present + tier1_valid = (frac_present >= min_valid_frac) & (n_present >= 2) + + return tier1_valid, tier2_valid, frac_present + + +def _compute_bbox_centroid( + keypoints: np.ndarray, + tier1_valid: np.ndarray, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute bounding-box centroids for tier-1 valid frames. + + The bounding-box centroid is the midpoint of the axis-aligned bounding + box enclosing all present keypoints. This is density-invariant, unlike + the arithmetic mean. + + Parameters + ---------- + keypoints : np.ndarray + Keypoint positions with shape (n_frames, n_keypoints, 2). + tier1_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-1 valid frames. + + Returns + ------- + bbox_centroids : np.ndarray + Array of shape (n_frames, 2) with bounding-box centroids. + NaN for non-tier-1-valid frames. + arith_centroids : np.ndarray + Array of shape (n_frames, 2) with arithmetic-mean centroids. + NaN for non-tier-1-valid frames. Used for diagnostic comparison. + centroid_discrepancy : np.ndarray + Array of shape (n_frames,) with normalized discrepancy between + bbox and arithmetic centroids (distance / bbox_diagonal). 
+ NaN for non-tier-1-valid frames. + + """ + n_frames = keypoints.shape[0] + + bbox_centroids = np.full((n_frames, 2), np.nan) + arith_centroids = np.full((n_frames, 2), np.nan) + centroid_discrepancy = np.full(n_frames, np.nan) + + for f in range(n_frames): + if not tier1_valid[f]: + continue + + kp_f = keypoints[f] # (n_keypoints, 2) + + # Find present keypoints (no NaN in either coordinate) + present_mask = ~np.any(np.isnan(kp_f), axis=1) + kp_present = kp_f[present_mask] + + # Bounding-box centroid + bbox_min = np.min(kp_present, axis=0) + bbox_max = np.max(kp_present, axis=0) + bbox_centroids[f] = (bbox_min + bbox_max) / 2 + + # Arithmetic-mean centroid + arith_centroids[f] = np.mean(kp_present, axis=0) + + # Centroid discrepancy: distance normalized by bbox diagonal + bbox_diag = np.linalg.norm(bbox_max - bbox_min) + if bbox_diag > 0: + discrepancy = np.linalg.norm( + bbox_centroids[f] - arith_centroids[f] + ) + centroid_discrepancy[f] = discrepancy / bbox_diag + else: + centroid_discrepancy[f] = 0.0 + + return bbox_centroids, arith_centroids, centroid_discrepancy + + +def _compute_frame_velocities( + bbox_centroids: np.ndarray, + tier1_valid: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Compute frame-to-frame centroid velocities and speeds. + + A velocity is valid only when both adjacent frames are tier-1 valid. + + Parameters + ---------- + bbox_centroids : np.ndarray + Array of shape (n_frames, 2) with bounding-box centroids. + tier1_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-1 valid frames. + + Returns + ------- + velocities : np.ndarray + Array of shape (n_frames - 1, 2) with velocity vectors. + Invalid velocities are NaN. + speeds : np.ndarray + Array of shape (n_frames - 1,) with speed scalars. + Invalid speeds are NaN. 
+ + """ + velocities = np.diff(bbox_centroids, axis=0) # (n_frames - 1, 2) + + # Velocity valid only if both adjacent frames are tier-1 valid + speed_valid = tier1_valid[:-1] & tier1_valid[1:] + + # Mask invalid velocities + velocities[~speed_valid] = np.nan + + speeds = np.linalg.norm(velocities, axis=1) + + return velocities, speeds + + +def _compute_sliding_window_medians( + speeds: np.ndarray, + window_len: int, + stride: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute median speeds for sliding windows. + + A window is valid only when every speed sample in that window is valid + (non-NaN), ensuring strict NaN-free content. + + Parameters + ---------- + speeds : np.ndarray + Array of shape (n_speed_samples,) with speed values. + window_len : int + Number of speed samples per sliding window. + stride : int + Step size between consecutive window start positions. + + Returns + ------- + window_starts : np.ndarray + Array of window start indices (0-indexed). + window_medians : np.ndarray + Median speed for each window. NaN for invalid windows. + window_all_valid : np.ndarray + Boolean array indicating which windows are fully valid. + + """ + num_speed = len(speeds) + window_starts = np.arange(0, num_speed - window_len + 1, stride) + num_windows = len(window_starts) + + window_medians = np.full(num_windows, np.nan) + window_all_valid = np.zeros(num_windows, dtype=bool) + + for k in range(num_windows): + s = window_starts[k] + e = s + window_len + w = speeds[s:e] + + # Window valid only if all samples are non-NaN + if np.all(~np.isnan(w)): + window_all_valid[k] = True + window_medians[k] = np.median(w) + + return window_starts, window_medians, window_all_valid + + +def _detect_high_motion_windows( + window_medians: np.ndarray, + window_all_valid: np.ndarray, + pct_thresh: float, +) -> np.ndarray: + """Identify high-motion windows based on percentile threshold. + + Parameters + ---------- + window_medians : np.ndarray + Median speed for each window. 
+ window_all_valid : np.ndarray + Boolean array indicating which windows are fully valid. + pct_thresh : float + Percentile threshold (0-100) for high-motion classification. + + Returns + ------- + high_motion : np.ndarray + Boolean array indicating high-motion windows. + + """ + valid_medians = window_medians[window_all_valid] + if len(valid_medians) == 0: + return np.zeros(len(window_medians), dtype=bool) + + thresh = np.percentile(valid_medians, pct_thresh) + high_motion = window_all_valid & (window_medians >= thresh) + + return high_motion + + +def _detect_runs( + high_motion: np.ndarray, + min_run_len: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Detect runs of consecutive high-motion windows. + + A run is a maximal sequence of consecutively indexed qualifying windows. + + Parameters + ---------- + high_motion : np.ndarray + Boolean array indicating high-motion windows. + min_run_len : int + Minimum number of consecutive qualifying windows for a valid run. + + Returns + ------- + run_starts : np.ndarray + Start indices of valid runs. + run_ends : np.ndarray + End indices (inclusive) of valid runs. + run_lengths : np.ndarray + Length of each valid run. + + """ + padded = np.concatenate([[False], high_motion, [False]]) + d = np.diff(padded.astype(int)) + + # Find run boundaries + run_starts_all = np.where(d == 1)[0] + run_ends_all = np.where(d == -1)[0] - 1 + run_lengths_all = run_ends_all - run_starts_all + 1 + + # Filter by minimum run length + valid_mask = run_lengths_all >= min_run_len + run_starts = run_starts_all[valid_mask] + run_ends = run_ends_all[valid_mask] + run_lengths = run_lengths_all[valid_mask] + + return run_starts, run_ends, run_lengths + + +def _convert_runs_to_segments( + run_starts: np.ndarray, + run_ends: np.ndarray, + window_starts: np.ndarray, + window_len: int, +) -> np.ndarray: + """Convert window runs to frame segments. 
+ + Each run is converted to a frame interval spanning from the start frame + of the first window to the end frame of the last window. + + Parameters + ---------- + run_starts : np.ndarray + Start indices of valid runs (indices into window arrays). + run_ends : np.ndarray + End indices (inclusive) of valid runs. + window_starts : np.ndarray + Start frame indices for each window. + window_len : int + Length of each window in frames. + + Returns + ------- + segments_raw : np.ndarray + Array of shape (n_runs, 2) with [frame_start, frame_end] for each run. + + """ + n_runs = len(run_starts) + segments_raw = np.zeros((n_runs, 2), dtype=int) + + for j in range(n_runs): + s_idx = run_starts[j] + e_idx = run_ends[j] + frame_start = window_starts[s_idx] + frame_end = window_starts[e_idx] + window_len + segments_raw[j] = [frame_start, frame_end] + + return segments_raw + + +def _merge_segments(segments_raw: np.ndarray) -> np.ndarray: + """Merge overlapping or abutting frame segments. + + Segments are first sorted by start frame, then merged if they overlap + or abut (next start <= current end + 1). + + Parameters + ---------- + segments_raw : np.ndarray + Array of shape (n_segments, 2) with [frame_start, frame_end]. + + Returns + ------- + segments : np.ndarray + Array of merged non-overlapping segments. 
+ + """ + if len(segments_raw) == 0: + return segments_raw + + # Sort by start frame + sorted_idx = np.argsort(segments_raw[:, 0]) + segments_sorted = segments_raw[sorted_idx] + + merged = [segments_sorted[0].tolist()] + + for j in range(1, len(segments_sorted)): + next_seg = segments_sorted[j] + curr_seg = merged[-1] + + # Merge if overlapping or abutting + if next_seg[0] <= curr_seg[1] + 1: + merged[-1][1] = max(curr_seg[1], next_seg[1]) + else: + merged.append(next_seg.tolist()) + + return np.array(merged, dtype=int) + + +def _filter_segments_tier2( + segments: np.ndarray, + tier2_valid: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Filter segment frames to retain only tier-2 valid frames. + + Parameters + ---------- + segments : np.ndarray + Array of shape (n_segments, 2) with [frame_start, frame_end]. + tier2_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-2 valid frames. + + Returns + ------- + selected_frames : np.ndarray + Array of tier-2 valid frame indices within segments. + selected_seg_id : np.ndarray + Segment ID (0-indexed) for each selected frame. 
+ + """ + # Collect all unique frames from all segments + all_segment_frames: list[int] = [] + for k in range(len(segments)): + frame_start, frame_end = segments[k] + seg_frames = np.arange(frame_start, frame_end + 1) + all_segment_frames.extend(seg_frames) + + segment_frames_all = np.unique(all_segment_frames) + + tier2_mask = tier2_valid[segment_frames_all] + selected_frames = segment_frames_all[tier2_mask] + + # Assign each selected frame to its segment + num_selected = len(selected_frames) + selected_seg_id = np.zeros(num_selected, dtype=int) + + for j in range(num_selected): + f = selected_frames[j] + for k in range(len(segments)): + if segments[k, 0] <= f <= segments[k, 1]: + selected_seg_id[j] = k + break + + return selected_frames, selected_seg_id + + +def _build_centered_skeletons( + keypoints: np.ndarray, + selected_frames: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Build centroid-centered skeletons for selected frames. + + Uses bounding-box centroid for centering, consistent with the + segmentation centroid. + + Parameters + ---------- + keypoints : np.ndarray + Keypoint positions with shape (n_frames, n_keypoints, 2). + selected_frames : np.ndarray + Array of selected frame indices. + + Returns + ------- + selected_centroids : np.ndarray + Array of shape (num_selected, 2) with bounding-box centroids. + centered_skeletons : np.ndarray + Array of shape (num_selected, n_keypoints, 2) with + centroid-centered skeleton coordinates. 
+ + """ + num_selected = len(selected_frames) + n_keypoints = keypoints.shape[1] + + selected_centroids = np.zeros((num_selected, 2)) + centered_skeletons = np.zeros((num_selected, n_keypoints, 2)) + + for j in range(num_selected): + f = selected_frames[j] + kp_f = keypoints[f] # (n_keypoints, 2) - all present for tier-2 + + # Bounding-box centroid + bbox_min = np.min(kp_f, axis=0) + bbox_max = np.max(kp_f, axis=0) + centroid_f = (bbox_min + bbox_max) / 2 + + selected_centroids[j] = centroid_f + centered_skeletons[j] = kp_f - centroid_f + + return selected_centroids, centered_skeletons + + +def _compute_pairwise_rmsd(centered_skeletons: np.ndarray) -> np.ndarray: + """Compute pairwise RMSD between all centered skeletons. + + RMSD is computed as the square root of the mean of squared entry-wise + differences between flattened skeleton vectors. + + Parameters + ---------- + centered_skeletons : np.ndarray + Array of shape (num_selected, n_keypoints, 2). + + Returns + ------- + rmsd_matrix : np.ndarray + Symmetric matrix of shape (num_selected, num_selected) with + pairwise RMSD values. Diagonal is zero. + + """ + num_selected = len(centered_skeletons) + + skel_flat = centered_skeletons.reshape(num_selected, -1) + + rmsd_matrix = np.zeros((num_selected, num_selected)) + + for i in range(num_selected): + for j in range(i + 1, num_selected): + d = skel_flat[i] - skel_flat[j] + rmsd_val = np.sqrt(np.mean(d**2)) + rmsd_matrix[i, j] = rmsd_val + rmsd_matrix[j, i] = rmsd_val + + return rmsd_matrix + + +def _compute_postural_variance_ratio( + rmsd_matrix: np.ndarray, + selected_seg_id: np.ndarray, +) -> tuple[float, np.ndarray, np.ndarray, bool]: + """Compute the between/within segment RMSD variance ratio. + + Parameters + ---------- + rmsd_matrix : np.ndarray + Pairwise RMSD matrix of shape (num_selected, num_selected). + selected_seg_id : np.ndarray + Segment ID for each selected frame. 
+ + Returns + ------- + var_ratio : float + Ratio of between-segment to within-segment RMSD variance. + Returns 0.0 if either distribution is empty or within variance is 0. + within_rmsds : np.ndarray + Array of within-segment RMSD values. + between_rmsds : np.ndarray + Array of between-segment RMSD values. + var_ratio_override : bool + True if variance ratio was set to 0 due to edge cases. + + """ + num_selected = len(selected_seg_id) + within_rmsds_list: list[float] = [] + between_rmsds_list: list[float] = [] + + for i in range(num_selected): + for j in range(i + 1, num_selected): + if selected_seg_id[i] == selected_seg_id[j]: + within_rmsds_list.append(rmsd_matrix[i, j]) + else: + between_rmsds_list.append(rmsd_matrix[i, j]) + + within_rmsds = np.array(within_rmsds_list) + between_rmsds = np.array(between_rmsds_list) + + # Compute variance ratio with edge case handling + var_ratio_override = False + if ( + len(within_rmsds) > 0 + and len(between_rmsds) > 0 + and np.var(within_rmsds) > 0 + ): + var_ratio = np.var(between_rmsds) / np.var(within_rmsds) + else: + var_ratio = 0.0 + var_ratio_override = True + + return var_ratio, within_rmsds, between_rmsds, var_ratio_override + + +def _kmedoids( + data: np.ndarray, + k: int, + max_iter: int = 100, + n_init: int = 5, + random_state: int | None = None, +) -> tuple[np.ndarray, np.ndarray, float]: + """Perform k-medoids clustering. + + Parameters + ---------- + data : np.ndarray + Array of shape (n_samples, n_features). + k : int + Number of clusters. + max_iter : int, default=100 + Maximum number of iterations. + n_init : int, default=5 + Number of random initializations. + random_state : int, optional + Random seed for reproducibility. + + Returns + ------- + labels : np.ndarray + Cluster labels for each sample (0-indexed). + medoid_indices : np.ndarray + Indices of medoid samples. + inertia : float + Sum of distances from samples to their medoids. 
+ + """ + from scipy.spatial.distance import cdist + + rng = np.random.default_rng(random_state) + n_samples = len(data) + + dist_matrix = cdist(data, data, metric="euclidean") + + best_labels: np.ndarray | None = None + best_medoids: np.ndarray | None = None + best_inertia = np.inf + + for _ in range(n_init): + medoids = rng.choice(n_samples, size=k, replace=False) + + for _ in range(max_iter): + distances_to_medoids = dist_matrix[:, medoids] + labels = np.argmin(distances_to_medoids, axis=1) + + # Update medoids + new_medoids = np.zeros(k, dtype=int) + for cluster in range(k): + cluster_mask = labels == cluster + if not np.any(cluster_mask): + # Empty cluster - keep old medoid + new_medoids[cluster] = medoids[cluster] + continue + + cluster_indices = np.where(cluster_mask)[0] + # Find point that minimizes sum of distances within cluster + cluster_dists = dist_matrix[ + np.ix_(cluster_indices, cluster_indices) + ] + total_dists = np.sum(cluster_dists, axis=1) + best_idx = np.argmin(total_dists) + new_medoids[cluster] = cluster_indices[best_idx] + + if np.array_equal(np.sort(medoids), np.sort(new_medoids)): + break + medoids = new_medoids + + distances_to_medoids = dist_matrix[:, medoids] + labels = np.argmin(distances_to_medoids, axis=1) + inertia = np.sum(distances_to_medoids[np.arange(n_samples), labels]) + + if inertia < best_inertia: + best_inertia = inertia + best_labels = labels.copy() + best_medoids = medoids.copy() + + # These are guaranteed to be set after at least one iteration + assert best_labels is not None and best_medoids is not None + return best_labels, best_medoids, best_inertia + + +def _silhouette_score(data: np.ndarray, labels: np.ndarray) -> float: + """Compute mean silhouette score. + + Parameters + ---------- + data : np.ndarray + Array of shape (n_samples, n_features). + labels : np.ndarray + Cluster labels for each sample. + + Returns + ------- + score : float + Mean silhouette score across all samples. 
+ Returns 0.0 if clustering is degenerate. + + """ + from scipy.spatial.distance import cdist + + n_samples = len(data) + unique_labels = np.unique(labels) + n_clusters = len(unique_labels) + + if n_clusters <= 1 or n_clusters >= n_samples: + return 0.0 + + dist_matrix = cdist(data, data, metric="euclidean") + + silhouette_vals = np.zeros(n_samples) + + for i in range(n_samples): + own_cluster = labels[i] + own_mask = labels == own_cluster + + # a(i) = mean distance to points in same cluster + if np.sum(own_mask) > 1: + a_i = np.mean( + dist_matrix[i, own_mask & (np.arange(n_samples) != i)] + ) + else: + a_i = 0.0 + + # b(i) = min over other clusters of mean distance to that cluster + b_i = np.inf + for cluster in unique_labels: + if cluster == own_cluster: + continue + cluster_mask = labels == cluster + if np.any(cluster_mask): + mean_dist = np.mean(dist_matrix[i, cluster_mask]) + b_i = min(b_i, mean_dist) + + if b_i == np.inf: + silhouette_vals[i] = 0.0 + else: + silhouette_vals[i] = ( + (b_i - a_i) / max(a_i, b_i) if max(a_i, b_i) > 0 else 0.0 + ) + + return float(np.mean(silhouette_vals)) + + +def _perform_postural_clustering( + centered_skeletons: np.ndarray, + max_clusters: int, + min_silhouette: float = 0.2, +) -> tuple[np.ndarray, int, int, float, list[tuple[int, float]]]: + """Perform postural clustering using k-medoids with silhouette selection. + + Parameters + ---------- + centered_skeletons : np.ndarray + Array of shape (num_selected, n_keypoints, 2). + max_clusters : int + Maximum number of clusters to evaluate. + min_silhouette : float, default=0.2 + Minimum silhouette score to accept clustering. + + Returns + ------- + cluster_labels : np.ndarray + Cluster labels for each frame (0-indexed). + num_clusters : int + Number of clusters (1 if clustering not accepted). + primary_cluster : int + Index of largest cluster (0-indexed). + best_silhouette : float + Best silhouette score achieved. 
+ silhouette_scores : list of (k, score) + Silhouette scores for each k evaluated. + + """ + num_selected = len(centered_skeletons) + skel_flat = centered_skeletons.reshape(num_selected, -1) + + best_k = 1 + best_sil = -np.inf + silhouette_scores = [] + + # Evaluate k from 2 to max_clusters (capped at num_selected // 2) + max_k = min(max_clusters, num_selected // 2) + + for k in range(2, max_k + 1): + try: + labels, _, _ = _kmedoids(skel_flat, k, n_init=5) + sil = _silhouette_score(skel_flat, labels) + silhouette_scores.append((k, sil)) + + if sil > best_sil: + best_sil = sil + best_k = k + except Exception: + silhouette_scores.append((k, np.nan)) + + # Accept clustering only if best_sil > min_silhouette + if best_k > 1 and best_sil > min_silhouette: + # Re-run with more initializations for final result + cluster_labels, _, _ = _kmedoids(skel_flat, best_k, n_init=10) + num_clusters = best_k + + # Primary cluster = largest + cluster_counts = np.bincount(cluster_labels, minlength=num_clusters) + primary_cluster = int(np.argmax(cluster_counts)) + else: + cluster_labels = np.zeros(num_selected, dtype=int) + num_clusters = 1 + primary_cluster = 0 + + return ( + cluster_labels, + num_clusters, + primary_cluster, + best_sil, + silhouette_scores, + ) + + +def _compute_cluster_velocities( + selected_frames: np.ndarray, + selected_seg_id: np.ndarray, + cluster_mask: np.ndarray, + segments: np.ndarray, + bbox_centroids: np.ndarray, +) -> np.ndarray: + """Compute velocities between adjacent consecutive frames. + + Only considers frames in the same segment and cluster. Frame pairs + where both frames are consecutive (frame[i] == frame[i-1] + 1), + in the same segment, and in the same cluster contribute a velocity + vector. + This prevents spanning temporal gaps or mixing postures across clusters. + + Returns + ------- + np.ndarray + Array of shape (n_velocities, 2). Empty (0, 2) if no valid pairs. 
+ + """ + frames_c = selected_frames[cluster_mask] + seg_ids_c = selected_seg_id[cluster_mask] + velocities_list: list[np.ndarray] = [] + + for seg_k in range(len(segments)): + seg_mask = seg_ids_c == seg_k + seg_frames = np.sort(frames_c[seg_mask]) + for fi in range(1, len(seg_frames)): + if seg_frames[fi] != seg_frames[fi - 1] + 1: + continue + curr_frame = seg_frames[fi] + prev_frame = seg_frames[fi - 1] + v = bbox_centroids[curr_frame] - bbox_centroids[prev_frame] + if np.all(~np.isnan(v)): + velocities_list.append(v) + + return np.array(velocities_list) if velocities_list else np.zeros((0, 2)) + + +def _infer_anterior_from_velocities( + velocities: np.ndarray, + PC1: np.ndarray, +) -> dict: + """Infer anterior direction from velocity projections onto PC1. + + Uses strict majority vote on PC1 projection signs: anterior = +PC1 + if n_positive > n_negative, else −PC1 (ties default to −PC1). + + Also computes circular statistics on velocity angles: + - resultant_length R = √(C² + S²) where C = mean(cos θ), S = mean(sin θ) + - vote_margin M = |n₊ − n₋| / (n₊ + n₋) + + Returns dict with resultant_length, circ_mean_dir, vel_projs_pc1, + num_positive, num_negative, vote_margin, anterior_sign. 
+ """ + result: dict = { + "resultant_length": 0.0, + "circ_mean_dir": np.nan, + "vel_projs_pc1": np.array([]), + "num_positive": 0, + "num_negative": 0, + "vote_margin": 0.0, + "anterior_sign": -1, + } + if len(velocities) == 0: + return result + + vel_angles = np.arctan2(velocities[:, 1], velocities[:, 0]) + circ_C = np.mean(np.cos(vel_angles)) + circ_S = np.mean(np.sin(vel_angles)) + result["resultant_length"] = np.sqrt(circ_C**2 + circ_S**2) + result["circ_mean_dir"] = np.arctan2(circ_S, circ_C) + + vel_projs = velocities @ PC1 + num_pos = int(np.sum(vel_projs > 0)) + num_neg = int(np.sum(vel_projs < 0)) + result["vel_projs_pc1"] = vel_projs + result["num_positive"] = num_pos + result["num_negative"] = num_neg + result["vote_margin"] = abs(num_pos - num_neg) / max(num_pos + num_neg, 1) + result["anterior_sign"] = +1 if num_pos > num_neg else -1 + return result + + +def _compute_cluster_pca_and_anterior( + centered_skeletons: np.ndarray, + cluster_mask: np.ndarray, + selected_frames: np.ndarray, + selected_seg_id: np.ndarray, + segments: np.ndarray, + bbox_centroids: np.ndarray, +) -> dict: + """Compute SVD-based PCA and velocity-based anterior inference. + + Performs inference for one cluster. + + Performs SVD on the cluster's average centered skeleton to extract PC1/PC2, + applies the geometric sign convention, then infers the anterior direction + via velocity voting on centroid displacements projected onto PC1. + + Returns + ------- + dict + Keys: valid, n_frames, avg_skeleton, valid_shape_rows, + PC1, PC2, anterior_sign, vote_margin, resultant_length, + circ_mean_dir, velocities, vel_projs_pc1, and others. 
+ + """ + n_keypoints = centered_skeletons.shape[1] + n_c = int(np.sum(cluster_mask)) + + result: dict = { + "valid": False, + "n_frames": n_c, + "avg_skeleton": np.full((n_keypoints, 2), np.nan), + "valid_shape_rows": np.zeros(n_keypoints, dtype=bool), + "PC1": np.array([1.0, 0.0]), + "PC2": np.array([0.0, 1.0]), + "proj_pc1": np.full(n_keypoints, np.nan), + "proj_pc2": np.full(n_keypoints, np.nan), + "anterior_sign": -1, + "num_positive": 0, + "num_negative": 0, + "vote_margin": 0.0, + "resultant_length": 0.0, + "circ_mean_dir": np.nan, + "velocities": np.zeros((0, 2)), + "vel_projs_pc1": np.array([]), + } + + if n_c == 0: + return result + + skels_c = centered_skeletons[cluster_mask] + avg_skel_c = np.mean(skels_c, axis=0) + valid_shape_rows = ~np.any(np.isnan(avg_skel_c), axis=1) + + if np.sum(valid_shape_rows) < 2: + return result + + result["avg_skeleton"] = avg_skel_c + result["valid_shape_rows"] = valid_shape_rows + + # SVD on valid (non-NaN) rows of the average centered skeleton + D_valid = avg_skel_c[valid_shape_rows] + _U, _S, Vt = np.linalg.svd(D_valid, full_matrices=False) + PC1 = Vt[0] + PC2 = Vt[1] if len(Vt) > 1 else np.array([0.0, 1.0]) + + # Geometric sign convention (reproducible across runs, decoupled + # from anatomical AP assignment which is determined by velocity voting): + # PC1 flipped so y-component >= 0 + # PC2 flipped so x-component >= 0 + + if PC1[1] < 0: + PC1 = -PC1 + if PC2[0] < 0: + PC2 = -PC2 + + result["PC1"] = PC1 + result["PC2"] = PC2 + + proj_pc1 = np.full(n_keypoints, np.nan) + proj_pc2 = np.full(n_keypoints, np.nan) + proj_pc1[valid_shape_rows] = avg_skel_c[valid_shape_rows] @ PC1 + proj_pc2[valid_shape_rows] = avg_skel_c[valid_shape_rows] @ PC2 + result["proj_pc1"] = proj_pc1 + result["proj_pc2"] = proj_pc2 + + velocities = _compute_cluster_velocities( + selected_frames, + selected_seg_id, + cluster_mask, + segments, + bbox_centroids, + ) + result["velocities"] = velocities + 
result.update(_infer_anterior_from_velocities(velocities, PC1)) + result["valid"] = True + return result + + +def _compute_node_projections( + report: _APNodePairReport, + avg_skeleton: np.ndarray, + PC1: np.ndarray, + anterior_sign: int, + valid_shape_rows: np.ndarray, + from_node: int, + to_node: int, +) -> None: + """Compute raw PC1, AP-oriented, and lateral projections. + + Computes projections for all valid keypoints. + + Populates the report's coordinate arrays and determines: + - pc1_coords: raw projection onto PC1 (sign-convention only) + - ap_coords: projection onto anterior_sign × PC1 (positive = more + anterior) + - lateral_offsets: unsigned distance from the AP axis + - midpoint_pc1: average of min/max PC1 projections (AP reference point) + - input_pair_order_matches_inference: True if from_node's AP coord < + to_node's + """ + pc1 = PC1 / np.linalg.norm(PC1) + # AP unit vector: anterior_sign * PC1, so positive projection = anterior + e_ap = anterior_sign * pc1 + # Lateral unit vector: 90° CCW rotation of e_ap + e_lat = np.array([-e_ap[1], e_ap[0]]) + + D_valid = avg_skeleton[valid_shape_rows] + report.pc1_coords[valid_shape_rows] = D_valid @ pc1 + report.ap_coords[valid_shape_rows] = D_valid @ e_ap + report.lateral_offsets[valid_shape_rows] = np.abs(D_valid @ e_lat) + + if valid_shape_rows[from_node] and valid_shape_rows[to_node]: + report.input_pair_order_matches_inference = ( + report.ap_coords[from_node] < report.ap_coords[to_node] + ) + + proj_pc1_valid = report.pc1_coords[valid_shape_rows] + report.pc1_min = float(np.min(proj_pc1_valid)) + report.pc1_max = float(np.max(proj_pc1_valid)) + report.midpoint_pc1 = (report.pc1_min + report.pc1_max) / 2 + + +def _apply_lateral_filter( + report: _APNodePairReport, + valid_idx: np.ndarray, + lateral_thresh: float, +) -> np.ndarray | None: + """Step 1: Filter keypoints by normalized lateral offset. + + Returns sorted candidate node indices, or None on failure. 
+ """ + d_valid = report.lateral_offsets[valid_idx] + d_min = float(np.min(d_valid)) + d_max = float(np.max(d_valid)) + report.lateral_offset_min = d_min + report.lateral_offset_max = d_max + + if d_max > d_min: + d_norm = (d_valid - d_min) / (d_max - d_min) + report.lateral_offsets_norm[valid_idx] = d_norm + keep_mask = d_norm <= lateral_thresh + else: + report.lateral_offsets_norm[valid_idx] = np.zeros(len(d_valid)) + keep_mask = np.ones(len(d_valid), dtype=bool) + + candidate_idx = np.where(keep_mask)[0] + C = valid_idx[candidate_idx] + sorted_order = np.argsort(d_valid[candidate_idx]) + C = C[sorted_order] + report.sorted_candidate_nodes = C.copy() + + if len(C) < 2: + report.failure_step = "Step 1: lateral alignment filter" + report.failure_reason = ( + "Fewer than 2 candidates remained after filtering." + ) + return None + return C + + +def _find_opposite_side_pairs( + report: _APNodePairReport, + C: np.ndarray, + from_node: int, + to_node: int, + valid_shape_rows: np.ndarray, +) -> tuple[np.ndarray, np.ndarray] | None: + """Step 2: Find candidate pairs on opposite sides of the AP midpoint. + + Returns (pairs, seps) arrays, or None on failure. 
+ """ + m = report.midpoint_pc1 + report.input_pair_in_candidates = (from_node in C) and (to_node in C) + + pairs_list: list[list[int]] = [] + seps_list: list[float] = [] + for ii in range(len(C)): + for jj in range(ii + 1, len(C)): + i, j = C[ii], C[jj] + if (report.pc1_coords[i] - m) * (report.pc1_coords[j] - m) < 0: + pairs_list.append([i, j]) + seps_list.append( + abs(report.ap_coords[i] - report.ap_coords[j]) + ) + + pairs = ( + np.array(pairs_list, dtype=int) + if pairs_list + else np.zeros((0, 2), dtype=int) + ) + seps = np.array(seps_list) if seps_list else np.array([]) + report.valid_pairs = pairs + report.valid_pairs_internode_dist = seps + + if valid_shape_rows[from_node] and valid_shape_rows[to_node]: + report.input_pair_opposite_sides = ( + (report.pc1_coords[from_node] - m) + * (report.pc1_coords[to_node] - m) + ) < 0 + report.input_pair_separation_abs = abs( + report.ap_coords[from_node] - report.ap_coords[to_node] + ) + + if len(pairs) == 0: + report.failure_step = "Step 2: opposite-sides constraint" + report.failure_reason = ( + "No candidate pair lies on opposite sides of the midpoint." + ) + return None + return pairs, seps + + +def _order_pair_by_ap( + pair: np.ndarray, + ap_coords: np.ndarray, +) -> np.ndarray: + """Order a node pair so element 0 is posterior (lower AP coord). + + This ensures that suggested pairs always encode the + posterior→anterior direction, matching the convention used by + ``body_axis_keypoints=(from_node, to_node)`` where from_node is + posterior and to_node is anterior. + + Parameters + ---------- + pair : np.ndarray + Two-element array of node indices. + ap_coords : np.ndarray + AP coordinates for all keypoints (anterior_sign already applied). + + Returns + ------- + np.ndarray + The same two indices, ordered so that + ``ap_coords[result[0]] <= ap_coords[result[1]]``. 
+ + """ + i, j = pair + if ap_coords[i] <= ap_coords[j]: + return np.array([i, j], dtype=int) + return np.array([j, i], dtype=int) + + +def _classify_distal_proximal( + report: _APNodePairReport, + pairs: np.ndarray, + seps: np.ndarray, + valid_shape_rows: np.ndarray, + edge_thresh: float, +) -> np.ndarray: + """Step 3: Classify pairs as distal or proximal. Returns pair_is_distal.""" + m = report.midpoint_pc1 + midline_dist = np.abs(report.pc1_coords - m) + d_max_midline = float(np.nanmax(midline_dist[valid_shape_rows])) + report.midline_dist_max = d_max_midline + + if d_max_midline > 0: + report.midline_dist_norm = midline_dist / d_max_midline + else: + report.midline_dist_norm = np.zeros(len(report.pc1_coords)) + + pair_is_distal = np.zeros(len(pairs), dtype=bool) + for k in range(len(pairs)): + i, j = pairs[k] + pair_is_distal[k] = ( + min(report.midline_dist_norm[i], report.midline_dist_norm[j]) + >= edge_thresh + ) + + report.distal_pairs = pairs[pair_is_distal] + report.proximal_pairs = pairs[~pair_is_distal] + + if len(seps) > 0: + idx_max = int(np.argmax(seps)) + report.max_separation_nodes = _order_pair_by_ap( + pairs[idx_max], report.ap_coords + ) + report.max_separation = seps[idx_max] + + if np.any(pair_is_distal): + distal_seps = seps[pair_is_distal] + distal_pairs_only = pairs[pair_is_distal] + idx_max_distal = int(np.argmax(distal_seps)) + report.max_separation_distal_nodes = _order_pair_by_ap( + distal_pairs_only[idx_max_distal], report.ap_coords + ) + report.max_separation_distal = distal_seps[idx_max_distal] + + return pair_is_distal + + +def _check_input_pair_in_valid( + report: _APNodePairReport, + pairs: np.ndarray, + seps: np.ndarray, + pair_is_distal: np.ndarray, + from_node: int, + to_node: int, +) -> tuple[bool, int]: + """Check whether input pair is among valid pairs. 
Returns (found, idx).""" + input_pair_sorted = tuple(sorted([from_node, to_node])) + input_in_valid = False + input_idx = -1 + + for k in range(len(pairs)): + if tuple(sorted(pairs[k])) == input_pair_sorted: + input_in_valid = True + input_idx = k + break + + if input_in_valid: + report.input_pair_is_distal = pair_is_distal[input_idx] + rank_order = np.argsort(seps)[::-1] + report.input_pair_rank = ( + int(np.where(rank_order == input_idx)[0][0]) + 1 + ) + return input_in_valid, input_idx + + +def _evaluate_ap_node_pair( + avg_skeleton: np.ndarray, + PC1: np.ndarray, + anterior_sign: int, + valid_shape_rows: np.ndarray, + from_node: int, + to_node: int, + config: _ValidateAPConfig, +) -> _APNodePairReport: + """Evaluate an AP node pair through the 3-step filter cascade. + + Parameters + ---------- + avg_skeleton : np.ndarray + Average centered skeleton of shape (n_keypoints, 2). + PC1 : np.ndarray + First principal component vector of shape (2,). + anterior_sign : int + Inferred anterior direction (+1 or -1 relative to PC1). + valid_shape_rows : np.ndarray + Boolean array indicating valid (non-NaN) keypoints. + from_node : int + Index of the input from_node (body_axis_keypoints origin, + claimed posterior). 0-indexed. + to_node : int + Index of the input to_node (body_axis_keypoints target, + claimed anterior). 0-indexed. + config : _ValidateAPConfig + Configuration with ``lateral_thresh`` and ``edge_thresh``. + + Returns + ------- + _APNodePairReport + Complete evaluation report. 
+ + """ + n_keypoints = len(avg_skeleton) + report = _APNodePairReport() + report.pc1_coords = np.full(n_keypoints, np.nan) + report.ap_coords = np.full(n_keypoints, np.nan) + report.lateral_offsets = np.full(n_keypoints, np.nan) + report.lateral_offsets_norm = np.full(n_keypoints, np.nan) + report.midline_dist_norm = np.full(n_keypoints, np.nan) + + for node, label in [(from_node, "from_node"), (to_node, "to_node")]: + if node < 0 or node >= n_keypoints: + report.failure_step = "Input validation" + report.failure_reason = ( + f"{label} must be a valid index in 0..{n_keypoints - 1}." + ) + return report + + valid_idx = np.where(valid_shape_rows)[0] + if len(valid_idx) < 2: + report.failure_step = "Step 1: lateral alignment filter" + report.failure_reason = "Fewer than 2 valid nodes are available." + return report + + _compute_node_projections( + report, + avg_skeleton, + PC1, + anterior_sign, + valid_shape_rows, + from_node, + to_node, + ) + + C = _apply_lateral_filter(report, valid_idx, config.lateral_thresh) + if C is None: + return report + + step2 = _find_opposite_side_pairs( + report, + C, + from_node, + to_node, + valid_shape_rows, + ) + if step2 is None: + return report + pairs, seps = step2 + + pair_is_distal = _classify_distal_proximal( + report, + pairs, + seps, + valid_shape_rows, + config.edge_thresh, + ) + + input_in_valid, input_idx = _check_input_pair_in_valid( + report, + pairs, + seps, + pair_is_distal, + from_node, + to_node, + ) + + report = _assign_scenario( + report, pairs, seps, pair_is_distal, input_in_valid, input_idx + ) + report.success = True + return report + + +def _assign_single_pair_scenario( + report: _APNodePairReport, + pairs: np.ndarray, + pair_is_distal: np.ndarray, + input_in_valid: bool, +) -> _APNodePairReport: + """Assign scenario when exactly one valid pair exists (scenarios 1-4).""" + if input_in_valid: + if pair_is_distal[0]: + report.scenario = 1 + report.outcome = "accept" + else: + report.scenario = 2 + report.outcome = 
"warn" + report.warning_message = "Input pair has proximal node(s)." + elif pair_is_distal[0]: + report.scenario = 3 + report.outcome = "warn" + report.warning_message = ( + f"Input invalid. Suggest pair [{pairs[0, 0]}, {pairs[0, 1]}]." + ) + else: + report.scenario = 4 + report.outcome = "warn" + report.warning_message = ( + f"Input invalid. Only option " + f"[{pairs[0, 0]}, {pairs[0, 1]}] has proximal node(s)." + ) + return report + + +def _assign_multi_input_distal_scenario( + report: _APNodePairReport, + pairs: np.ndarray, + input_idx: int, +) -> _APNodePairReport: + """Assign scenario for distal input in multi-pair case (5, 6, 7).""" + input_pair_sorted = tuple( + sorted([pairs[input_idx, 0], pairs[input_idx, 1]]) + ) + max_distal_sorted = ( + tuple(sorted(report.max_separation_distal_nodes)) + if len(report.max_separation_distal_nodes) > 0 + else () + ) + + if report.input_pair_rank == 1: + report.scenario = 5 + report.outcome = "accept" + elif input_pair_sorted == max_distal_sorted: + report.scenario = 7 + report.outcome = "accept" + else: + report.scenario = 6 + report.outcome = "warn" + d = report.max_separation_distal_nodes + report.warning_message = ( + f"Distal pair with greater separation exists: [{d[0]}, {d[1]}]." + ) + return report + + +def _assign_multi_input_proximal_scenario( + report: _APNodePairReport, + pair_is_distal: np.ndarray, +) -> _APNodePairReport: + """Assign scenario for proximal input in multi-pair case (8-11).""" + has_distal = np.any(pair_is_distal) + is_max_sep = report.input_pair_rank == 1 + + if is_max_sep and has_distal: + report.scenario = 8 + d = report.max_separation_distal_nodes + report.warning_message = ( + f"Input has proximal node(s). " + f"Distal alternative: [{d[0]}, {d[1]}]." + ) + elif is_max_sep: + report.scenario = 9 + report.warning_message = ( + "Input has proximal node(s). All pairs have proximal node(s)." 
+ ) + elif has_distal: + report.scenario = 10 + d = report.max_separation_distal_nodes + report.warning_message = ( + f"Input has proximal node(s). " + f"Distal pair with greater separation: [{d[0]}, {d[1]}]." + ) + else: + report.scenario = 11 + report.warning_message = ( + "Input has proximal node(s). All pairs have proximal node(s)." + ) + + report.outcome = "warn" + return report + + +def _assign_multi_input_invalid_scenario( + report: _APNodePairReport, + pair_is_distal: np.ndarray, +) -> _APNodePairReport: + """Assign scenario when input not in valid pairs (12-13).""" + has_distal = np.any(pair_is_distal) + report.outcome = "warn" + + if has_distal: + report.scenario = 12 + d = report.max_separation_distal_nodes + report.warning_message = ( + f"Input invalid. Suggest max separation distal pair: " + f"[{d[0]}, {d[1]}]." + ) + else: + report.scenario = 13 + m = report.max_separation_nodes + report.warning_message = ( + f"Input invalid. All pairs have proximal node(s). " + f"Max separation: [{m[0]}, {m[1]}]." + ) + return report + + +def _assign_scenario( + report: _APNodePairReport, + pairs: np.ndarray, + seps: np.ndarray, + pair_is_distal: np.ndarray, + input_in_valid: bool, + input_idx: int, +) -> _APNodePairReport: + """Assign one of 13 mutually exclusive scenarios. + + Parameters + ---------- + report : _APNodePairReport + The report to update with scenario information. + pairs : np.ndarray + Valid pairs array of shape (n_pairs, 2). + seps : np.ndarray + Internode separations for each pair. + pair_is_distal : np.ndarray + Boolean array indicating distal pairs. + input_in_valid : bool + Whether input pair is among valid pairs. + input_idx : int + Index of input pair in valid pairs (-1 if not present). + + Returns + ------- + _APNodePairReport + Updated report with scenario, outcome, and warning_message. 
def _resolve_node_index(node: Hashable, names: list) -> int:
    """Map a keypoint identifier (name or index) to an integer index.

    String names are looked up in ``names`` (ValueError if absent);
    ints pass through; anything else is coerced via ``int()``.
    """
    if isinstance(node, str):
        try:
            return names.index(node)
        except ValueError:
            raise ValueError(
                f"Keypoint '{node}' not found in {names}."
            ) from None
    if isinstance(node, int):
        return node
    return int(node)  # type: ignore[call-overload]


def _prepare_validation_inputs(
    data: xr.DataArray,
    from_node: Hashable,
    to_node: Hashable,
) -> tuple[np.ndarray, int, int, str, str, list[str], int]:
    """Validate inputs and extract numpy arrays for AP validation.

    Returns
    -------
    tuple
        (keypoints, from_idx, to_idx, from_name, to_name,
        keypoint_names, num_frames)

    Raises
    ------
    TypeError
        If data is not an xarray.DataArray.
    ValueError
        If dimensions or indices are invalid.

    """
    _validate_type_data_array(data)

    required_dims = {"time", "space", "keypoints"}
    if not required_dims.issubset(set(data.dims)):
        raise ValueError(
            f"data must have dimensions {required_dims}, "
            f"but has {set(data.dims)}."
        )

    # A singleton "individuals" dimension is tolerated and squeezed away.
    if "individuals" in data.dims:
        if data.sizes["individuals"] != 1:
            raise ValueError(
                "data must be for a single individual. "
                "Use data.sel(individuals='name') to select one."
            )
        data = data.squeeze("individuals", drop=True)

    n_keypoints = data.sizes["keypoints"]
    # Fall back to synthetic names when no keypoint coordinate exists.
    if "keypoints" in data.coords:
        keypoint_names = list(data.coords["keypoints"].values)
    else:
        keypoint_names = [f"node_{i}" for i in range(n_keypoints)]

    from_idx = _resolve_node_index(from_node, keypoint_names)
    to_idx = _resolve_node_index(to_node, keypoint_names)

    for label, idx in (("from_node", from_idx), ("to_node", to_idx)):
        if idx < 0 or idx >= n_keypoints:
            raise ValueError(
                f"{label} index {idx} out of range [0, {n_keypoints - 1}]."
            )

    keypoints = (
        data.sel(space=["x", "y"])
        .transpose("time", "keypoints", "space")
        .values
    )

    return (
        keypoints,
        from_idx,
        to_idx,
        keypoint_names[from_idx],
        keypoint_names[to_idx],
        keypoint_names,
        keypoints.shape[0],
    )
+ """ + tier1_valid, tier2_valid, _frac = _compute_tiered_validity( + keypoints, config.min_valid_frac + ) + num_tier1 = int(np.sum(tier1_valid)) + num_tier2 = int(np.sum(tier2_valid)) + + log_info("────────────────────────────────────────────────────────────") + log_info("Tiered Validity Report") + log_info("────────────────────────────────────────────────────────────") + log_info( + "Tier 1 (>= %.0f%% keypoints): %d / %d frames (%.2f%%)", + config.min_valid_frac * 100, + num_tier1, + num_frames, + 100 * num_tier1 / num_frames, + ) + log_info( + "Tier 2 (100%% keypoints): %d / %d frames (%.2f%%)", + num_tier2, + num_frames, + 100 * num_tier2 / num_frames, + ) + + if num_tier1 < 2: + logger.error("Not enough tier-1 valid frames.") + return None + + bbox_centroids, _arith, centroid_disc = _compute_bbox_centroid( + keypoints, tier1_valid + ) + valid_disc = centroid_disc[tier1_valid & ~np.isnan(centroid_disc)] + if len(valid_disc) > 0: + log_info("") + log_info( + "────────────────────────────────────────────────────────────" + ) + log_info("Centroid Discrepancy Diagnostic") + log_info( + "────────────────────────────────────────────────────────────" + ) + log_info("BBox vs arithmetic centroid (normalized by bbox diagonal):") + log_info( + " Median: %.4f | Mean: %.4f | Max: %.4f", + np.median(valid_disc), + np.mean(valid_disc), + np.max(valid_disc), + ) + if np.median(valid_disc) > 0.05: + log_warning( + "Median discrepancy > 5%% - annotation density " + "is likely asymmetric." 
+ ) + + segments = _detect_motion_segments( + bbox_centroids, tier1_valid, config, log_info, log_warning + ) + if segments is None: + return None + + return { + "tier1_valid": tier1_valid, + "tier2_valid": tier2_valid, + "bbox_centroids": bbox_centroids, + "segments": segments, + } + + +def _detect_motion_segments( + bbox_centroids: np.ndarray, + tier1_valid: np.ndarray, + config: _ValidateAPConfig, + log_info, + log_warning, +) -> np.ndarray | None: + """Detect high-motion segments from centroid velocities. + + Returns merged segments array, or None on failure. + """ + velocities, speeds = _compute_frame_velocities(bbox_centroids, tier1_valid) + num_speed = len(speeds) + + if num_speed < config.window_len: + logger.error( + "window_len=%d exceeds available speed samples=%d.", + config.window_len, + num_speed, + ) + return None + + window_starts, window_medians, window_all_valid = ( + _compute_sliding_window_medians( + speeds, config.window_len, config.stride + ) + ) + num_valid_windows = int(np.sum(window_all_valid)) + if num_valid_windows == 0: + logger.error("No fully valid sliding windows found.") + return None + + high_motion = _detect_high_motion_windows( + window_medians, window_all_valid, config.pct_thresh + ) + num_high_motion = int(np.sum(high_motion)) + + log_info("") + log_info("────────────────────────────────────────────────────────────") + log_info("High-Motion Window Detection") + log_info("────────────────────────────────────────────────────────────") + log_info( + "Sliding windows (len=%d, stride=%d): " + "%d total, %d fully valid (NaN-free), " + "%d high-motion (median speed >= %dth percentile)", + config.window_len, + config.stride, + len(window_starts), + num_valid_windows, + num_high_motion, + int(config.pct_thresh), + ) + + if num_high_motion == 0: + logger.error("No high-motion windows found.") + return None + + run_starts, run_ends, _run_lengths = _detect_runs( + high_motion, config.min_run_len + ) + if len(run_starts) == 0: + 
logger.error("No runs met min_run_len=%d.", config.min_run_len) + return None + + segments_raw = _convert_runs_to_segments( + run_starts, run_ends, window_starts, config.window_len + ) + segments = _merge_segments(segments_raw) + + log_info("Detected %d merged high-motion segment(s):", len(segments)) + for i, (start, end) in enumerate(segments): + log_info(" Segment %d: frames %d - %d", i + 1, start, end) + + return segments + + +def _select_tier2_frames( + segments: np.ndarray, + tier2_valid: np.ndarray, + num_frames: int, + log_info, + log_warning, +) -> tuple[np.ndarray, np.ndarray, int] | None: + """Filter segment frames to tier-2 valid only. + + Returns (selected_frames, selected_seg_id, num_selected) or None. + """ + selected_frames, selected_seg_id = _filter_segments_tier2( + segments, tier2_valid + ) + + num_tier1_in_segs = sum( + np.sum( + (np.arange(num_frames) >= s[0]) & (np.arange(num_frames) <= s[1]) + ) + for s in segments + ) + num_selected = len(selected_frames) + + log_info("") + log_info("────────────────────────────────────────────────────────────") + log_info("Tier-2 Filtering on High-Motion Segments") + log_info("────────────────────────────────────────────────────────────") + log_info( + "Frames in high-motion segments (any tier): %d", num_tier1_in_segs + ) + log_info( + "Tier-2 valid frames retained (all keypoints present): " + "%d (%.1f%% of segment frames)", + num_selected, + 100 * num_selected / max(num_tier1_in_segs, 1), + ) + + retention = num_selected / max(num_tier1_in_segs, 1) + if retention < 0.3: + log_warning( + "Tier 2 discards > 70%% of segment frames - " + "body model may be unrepresentative." 
def _run_clustering_and_pca(
    centered_skeletons: np.ndarray,
    frame_sel: _FrameSelection,
    config: _ValidateAPConfig,
    log_info,
    log_warning,
) -> dict | None:
    """Run postural analysis, clustering, and per-cluster PCA.

    Returns dict with primary_result, cluster_results,
    num_clusters, primary_cluster, or None on failure.
    """
    rmsd_matrix = _compute_pairwise_rmsd(centered_skeletons)
    var_ratio, within_rmsds, between_rmsds, var_ratio_override = (
        _compute_postural_variance_ratio(rmsd_matrix, frame_sel.seg_ids)
    )

    _log_postural_consistency(
        {
            "within": within_rmsds,
            "between": between_rmsds,
            "var_ratio": var_ratio,
            "override": var_ratio_override,
        },
        config,
        frame_sel.count,
        log_info,
        log_warning,
    )

    cluster_labels, num_clusters, primary_cluster = (
        _decide_and_run_clustering(
            centered_skeletons,
            var_ratio,
            frame_sel.count,
            config,
            log_info,
        )
    )

    # One PCA + anterior inference per cluster.
    cluster_results = [
        _compute_cluster_pca_and_anterior(
            centered_skeletons,
            cluster_labels == c,
            frame_sel.frames,
            frame_sel.seg_ids,
            frame_sel.segments,
            frame_sel.bbox_centroids,
        )
        for c in range(num_clusters)
    ]

    primary_result = cluster_results[primary_cluster]
    if not primary_result["valid"]:
        logger.error("Primary cluster has invalid PCA result.")
        return None

    return {
        "primary_result": primary_result,
        "cluster_results": cluster_results,
        "num_clusters": num_clusters,
        "primary_cluster": primary_cluster,
    }


def _log_postural_consistency(
    rmsd_stats,
    config,
    num_selected,
    log_info,
    log_warning,
):
    """Log postural consistency check results (RMSD stats, variance ratio)."""
    within = rmsd_stats["within"]
    between = rmsd_stats["between"]
    var_ratio = rmsd_stats["var_ratio"]

    log_info("")
    log_info("────────────────────────────────────────────────────────────")
    log_info("Postural Consistency Check")
    log_info("────────────────────────────────────────────────────────────")

    if len(within) > 0:
        log_info(
            "Within-segment RMSD: mean=%.4f, std=%.4f (n=%d pairs)",
            np.mean(within),
            np.std(within),
            len(within),
        )
    else:
        log_info("Within-segment RMSD: N/A (no within-segment pairs)")

    if len(between) > 0:
        log_info(
            "Between-segment RMSD: mean=%.4f, std=%.4f (n=%d pairs)",
            np.mean(between),
            np.std(between),
            len(between),
        )
        log_info(
            "Variance ratio (between/within): %.2f (threshold=%.2f)",
            var_ratio,
            config.postural_var_ratio_thresh,
        )
        if rmsd_stats["override"]:
            log_info(
                " (Conservative override to zero: within-segment variance "
                "is zero or no within-segment pairs)"
            )
    else:
        log_info("Between-segment RMSD: N/A (single segment)")
        log_info("Variance ratio: N/A")

    ratio_exceeds = var_ratio > config.postural_var_ratio_thresh
    if ratio_exceeds and num_selected >= 6:
        log_info(" -> Variance ratio exceeds threshold. Running clustering.")
    elif ratio_exceeds:
        log_info(
            " -> Variance ratio exceeds threshold but too few frames (%d) "
            "for clustering.",
            num_selected,
        )
    else:
        log_info(" -> Postural consistency acceptable. Using global average.")
def _decide_and_run_clustering(
    centered_skeletons,
    var_ratio,
    num_selected,
    config,
    log_info,
):
    """Decide whether postural clustering is warranted; run k-medoids if so.

    Returns (cluster_labels, num_clusters, primary_cluster). When the
    variance ratio is not above threshold or fewer than 6 frames are
    available, a single trivial cluster covering all frames is returned.
    """
    # NaN-safe guard: keep the original `not (a and b)` form rather than
    # De Morgan on the float comparison.
    if not (
        var_ratio > config.postural_var_ratio_thresh and num_selected >= 6
    ):
        return np.zeros(num_selected, dtype=int), 1, 0

    (
        cluster_labels,
        num_clusters,
        primary_cluster,
        best_silhouette,
        silhouette_scores,
    ) = _perform_postural_clustering(centered_skeletons, config.max_clusters)

    for k, sil in silhouette_scores:
        if np.isnan(sil):
            log_info(" k=%d: clustering failed.", k)
        else:
            log_info(" k=%d: mean silhouette = %.4f", k, sil)

    if num_clusters > 1:
        cluster_counts = np.bincount(cluster_labels, minlength=num_clusters)
        log_info(
            " Selected k=%d clusters (silhouette=%.4f). "
            "Primary cluster=%d (%d frames)",
            num_clusters,
            best_silhouette,
            primary_cluster + 1,
            cluster_counts[primary_cluster],
        )
    else:
        log_info(
            " Clustering did not improve separation (best_sil=%.4f). "
            "Using global average.",
            best_silhouette,
        )

    return cluster_labels, num_clusters, primary_cluster


def _log_anterior_report(
    pr,
    cluster_results,
    num_clusters,
    primary_cluster,
    config,
    log_info,
    log_warning,
):
    """Log anterior direction detection and inter-cluster agreement."""
    log_info("")
    log_info("────────────────────────────────────────────────────────────")
    log_info("Anterior Direction Inference (Velocity Voting)")
    log_info("────────────────────────────────────────────────────────────")
    log_info(
        "Centroid velocity projections onto PC1: "
        "%d positive (+PC1), %d negative (−PC1)",
        pr["num_positive"],
        pr["num_negative"],
    )

    margin_line = f"Vote margin M: {pr['vote_margin']:.4f}"
    if pr["vote_margin"] < config.confidence_floor:
        margin_line += (
            f" ** BELOW CONFIDENCE FLOOR ({config.confidence_floor:.2f}) "
            "— anterior assignment is unreliable **"
        )
        log_warning(
            "Vote margin M = %.4f is below confidence floor %.2f — "
            "anterior assignment is unreliable.",
            pr["vote_margin"],
            config.confidence_floor,
        )
    log_info(margin_line)
    log_info(
        "Resultant length R: %.4f (0 = omnidirectional, 1 = unidirectional)",
        pr["resultant_length"],
    )
    log_info(
        "Inferred anterior direction: %sPC1 "
        "(strict majority; ties default to −PC1)",
        "+" if pr["anterior_sign"] > 0 else "−",
    )

    if num_clusters <= 1:
        return

    log_info("")
    log_info("────────────────────────────────────────────────────────────")
    log_info("Inter-Cluster Anterior Polarity Agreement")
    log_info("────────────────────────────────────────────────────────────")
    polarities = {
        cr["anterior_sign"] for cr in cluster_results if cr["valid"]
    }
    if len(polarities) == 1:
        log_info(
            "All %d clusters AGREE on anterior polarity.",
            num_clusters,
        )
    else:
        log_info(
            "DISAGREEMENT: clusters assign different anterior polarities."
        )
        for c, cr in enumerate(cluster_results):
            if cr["valid"]:
                log_info(
                    " Cluster %d (%d frames): anterior = %sPC1, "
                    "vote_margin M = %.4f, resultant_length R = %.4f",
                    c + 1,
                    cr["n_frames"],
                    "+" if cr["anterior_sign"] > 0 else "−",
                    cr["vote_margin"],
                    cr["resultant_length"],
                )
        log_info(
            " Primary result from cluster %d (largest).",
            primary_cluster + 1,
        )
def _log_step1_report(pair_report, config, valid_nodes, log_info):
    """Log Step 1 (lateral alignment filter) results."""
    num_valid = len(valid_nodes)
    num_candidates = len(pair_report.sorted_candidate_nodes)
    step1_loss = 1 - num_candidates / max(num_valid, 1)

    passed, failed = [], []
    for node in valid_nodes:
        offset = pair_report.lateral_offsets_norm[node]
        bucket = passed if offset <= config.lateral_thresh else failed
        bucket.append(f"{node}({offset:.2f})")

    log_info("")
    log_info(
        "Step 1 — Lateral Alignment Filter (lateral_thresh=%.2f): "
        "%d of %d valid nodes pass [loss=%.0f%%]",
        config.lateral_thresh,
        num_candidates,
        num_valid,
        100 * step1_loss,
    )
    log_info(
        " Scale: 0.00 = nearest to body axis, 1.00 = farthest from body axis"
    )
    if passed:
        log_info(" PASS: %s", ", ".join(passed))
    if failed:
        log_info(" FAIL: %s", ", ".join(failed))


def _log_step2_report(pair_report, config, log_info):
    """Log Step 2 (opposite-sides constraint) results."""
    num_candidates = len(pair_report.sorted_candidate_nodes)
    num_possible_pairs = num_candidates * (num_candidates - 1) // 2
    num_valid_pairs = len(pair_report.valid_pairs)
    step2_loss = 1 - num_valid_pairs / max(num_possible_pairs, 1)
    midpoint = pair_report.midpoint_pc1

    anterior_side, posterior_side = [], []
    for node in pair_report.sorted_candidate_nodes:
        rel = pair_report.pc1_coords[node] - midpoint
        side = anterior_side if rel > 0 else posterior_side
        side.append(f"{node}({rel:+.1f})")

    log_info("")
    log_info(
        "Step 2 — Opposite-Sides Constraint (AP midpoint=%.2f): "
        "%d of %d candidate pairs on opposite sides [loss=%.0f%%]",
        midpoint,
        num_valid_pairs,
        num_possible_pairs,
        100 * step2_loss,
    )
    if anterior_side:
        log_info(
            " + side (anterior of midpoint): %s", ", ".join(anterior_side)
        )
    if posterior_side:
        log_info(
            " - side (posterior of midpoint): %s", ", ".join(posterior_side)
        )


def _log_step3_report(pair_report, config, log_info):
    """Log Step 3 (distal/proximal classification) results."""
    num_distal = len(pair_report.distal_pairs)
    num_proximal = len(pair_report.proximal_pairs)
    num_valid_pairs = len(pair_report.valid_pairs)
    distal_fraction = num_distal / max(num_valid_pairs, 1)

    log_info("")
    log_info(
        "Step 3 — Distal/Proximal Classification (edge_thresh=%.2f): "
        "%d distal, %d proximal [distal fraction=%.0f%%]",
        config.edge_thresh,
        num_distal,
        num_proximal,
        100 * distal_fraction,
    )

    for idx, (node_i, node_j) in enumerate(pair_report.valid_pairs):
        min_dist = min(
            pair_report.midline_dist_norm[node_i],
            pair_report.midline_dist_norm[node_j],
        )
        separation = pair_report.valid_pairs_internode_dist[idx]
        label = "DISTAL" if min_dist >= config.edge_thresh else "PROXIMAL"
        log_info(
            " [%d,%d]: min_d=%.2f, sep=%.2f [%s]",
            node_i,
            node_j,
            min_dist,
            separation,
            label,
        )


def _log_loss_summary(
    step1_loss, step2_loss, step3_frac, step1_failed, step2_failed, log_info
):
    """Log cumulative filtering loss summary across the 3-step cascade."""
    log_info("")
    log_info("────────────────────────────────────────────────────────────")
    log_info("Filtering Loss Summary")
    log_info("────────────────────────────────────────────────────────────")
    log_info(
        "Step 1 (Lateral Filter): %.0f%% of valid nodes eliminated",
        100 * step1_loss,
    )
    if step1_failed:
        # Later steps were never evaluated; nothing more to report.
        return
    log_info(
        "Step 2 (Opposite-Sides): %.0f%% of candidate pairs eliminated",
        100 * step2_loss,
    )
    if not step2_failed:
        log_info(
            "Step 3 (Distal/Proximal): %.0f%% of surviving pairs are distal",
            100 * step3_frac,
        )
def _log_order_check(
    pair_report, from_idx, to_idx, from_name, to_name, log_info
):
    """Log whether the claimed posterior→anterior order matches inference."""
    log_info("")
    log_info("────────────────────────────────────────────────────────────")
    log_info("Order Check: is from_node posterior to to_node?")
    log_info("────────────────────────────────────────────────────────────")
    ap_from = pair_report.ap_coords[from_idx]
    ap_to = pair_report.ap_coords[to_idx]
    if np.isnan(ap_from) or np.isnan(ap_to):
        log_info("Order check: cannot evaluate (invalid node coordinates)")
        return

    log_info(
        "AP coords: from_node %s[%d]=%.2f, to_node %s[%d]=%.2f",
        from_name,
        from_idx,
        ap_from,
        to_name,
        to_idx,
        ap_to,
    )
    # ANSI colors: green for agreement, red for disagreement.
    green = "\033[38;5;46m"
    red = "\033[38;5;9m"
    reset = "\033[0m"
    if pair_report.input_pair_order_matches_inference:
        log_info(
            f"{green}[%d, %d]: CONSISTENT — inference agrees that "
            f"from_node is posterior (lower AP coord), "
            f"to_node is anterior{reset}",
            from_idx,
            to_idx,
        )
    else:
        log_info(
            f"{red}[%d, %d]: INCONSISTENT — inference suggests "
            f"from_node is anterior (higher AP coord), "
            f"to_node is posterior{reset}",
            from_idx,
            to_idx,
        )
        log_info(
            " -> Inferred posterior→anterior order would be [%d, %d]",
            to_idx,
            from_idx,
        )


def _log_input_node_status(
    pair_report,
    config,
    from_idx,
    to_idx,
    log_info,
):
    """Log input nodes that failed the Step 1 lateral filter, if any."""
    offsets = pair_report.lateral_offsets_norm
    failures = []
    # A node fails when its offset is NaN or exceeds the threshold.
    for idx in (from_idx, to_idx):
        offset = offsets[idx]
        if np.isnan(offset) or offset > config.lateral_thresh:
            failures.append(f"{idx}({offset:.2f})")

    if failures:
        log_info(
            " -> Input node(s) FAILED lateral filter: %s",
            ", ".join(failures),
        )


def _log_step2_step3_details(
    pair_report,
    config,
    from_idx,
    to_idx,
    num_candidates,
    log_info,
):
    """Log Step 2 and Step 3 results when Step 1 succeeded.

    Returns (step2_loss, step3_frac, step2_failed).
    """
    _log_step2_report(pair_report, config, log_info)

    if (
        pair_report.input_pair_in_candidates
        and not pair_report.input_pair_opposite_sides
    ):
        log_info(" -> Input nodes on SAME side of AP midpoint")

    num_possible = num_candidates * (num_candidates - 1) // 2
    step2_loss = 1 - len(pair_report.valid_pairs) / max(num_possible, 1)

    if pair_report.failure_step.startswith("Step 2"):
        log_info("")
        log_info("Step 3: not evaluated (Step 2 failed)")
        return step2_loss, 0.0, True

    step3_frac = _log_step3_with_proximal_check(
        pair_report, config, from_idx, to_idx, log_info
    )
    return step2_loss, step3_frac, False


def _log_step3_with_proximal_check(
    pair_report,
    config,
    from_idx,
    to_idx,
    log_info,
):
    """Log Step 3 results and flag a proximal input pair.

    Returns step3_frac (fraction of surviving pairs that are distal).
    """
    _log_step3_report(pair_report, config, log_info)
    num_valid_pairs = len(pair_report.valid_pairs)
    step3_frac = len(pair_report.distal_pairs) / max(num_valid_pairs, 1)

    input_is_proximal_candidate = (
        pair_report.input_pair_in_candidates
        and pair_report.input_pair_opposite_sides
        and not pair_report.input_pair_is_distal
    )
    if input_is_proximal_candidate:
        min_dist = min(
            pair_report.midline_dist_norm[from_idx],
            pair_report.midline_dist_norm[to_idx],
        )
        log_info(
            " -> Input pair is PROXIMAL (min_d=%.2f < %.2f)",
            min_dist,
            config.edge_thresh,
        )

    return step3_frac


def _log_pair_evaluation(
    pair_report,
    config,
    from_idx,
    to_idx,
    from_name,
    to_name,
    log_info,
):
    """Log the complete AP node-pair evaluation report (3-step cascade)."""
    log_info("")
    log_info("────────────────────────────────────────────────────────────")
    log_info("AP Node-Pair Filter Cascade (3-Step Evaluation)")
    log_info("────────────────────────────────────────────────────────────")
    log_info(
        "Input pair: [%d, %d] (%s → %s, claimed posterior → anterior)",
        from_idx,
        to_idx,
        from_name,
        to_name,
    )

    step1_failed = pair_report.failure_step.startswith("Step 1")
    valid_nodes = np.where(~np.isnan(pair_report.lateral_offsets_norm))[0]
    num_candidates = len(pair_report.sorted_candidate_nodes)
    step1_loss = 1 - num_candidates / max(len(valid_nodes), 1)

    _log_step1_report(pair_report, config, valid_nodes, log_info)
    _log_input_node_status(pair_report, config, from_idx, to_idx, log_info)

    step2_loss, step3_frac, step2_failed = 0.0, 0.0, False
    if step1_failed:
        log_info("")
        log_info("Step 2-3: not evaluated (Step 1 failed)")
    else:
        step2_loss, step3_frac, step2_failed = _log_step2_step3_details(
            pair_report, config, from_idx, to_idx, num_candidates, log_info
        )

    _log_loss_summary(
        step1_loss,
        step2_loss,
        step3_frac,
        step1_failed,
        step2_failed,
        log_info,
    )
    _log_order_check(
        pair_report, from_idx, to_idx, from_name, to_name, log_info
    )
def _validate_ap(
    data: xr.DataArray,
    from_node: Hashable,
    to_node: Hashable,
    config: _ValidateAPConfig | None = None,
    verbose: bool = True,
) -> dict:
    """Validate an anterior-posterior keypoint pair using body-axis inference.

    This function implements a prior-free body-axis inference pipeline that:
    1. Identifies high-motion segments using tiered validity and sliding
       windows
    2. Optionally performs postural clustering via k-medoids
    3. Infers the anterior direction using velocity projection voting
    4. Evaluates the candidate AP keypoint pair through a 3-step filter
       cascade

    Parameters
    ----------
    data : xarray.DataArray
        Position data for a single individual.
    from_node : int or str
        Index or name of the posterior keypoint.
    to_node : int or str
        Index or name of the anterior keypoint.
    config : _ValidateAPConfig, optional
        Configuration parameters. If None, uses defaults.
    verbose : bool, default=True
        If True, log detailed validation output to console.

    Returns
    -------
    dict
        Validation results including success, anterior_sign,
        vote_margin, resultant_length, pair_report, etc.

    """
    if config is None:
        config = _ValidateAPConfig()

    log_lines: list[str] = []

    def _log_info(msg, *args):
        """Record (and optionally print) an informational line."""
        rendered = msg % args if args else msg
        log_lines.append(rendered)
        if verbose:
            print(rendered)

    def _log_warning(msg, *args):
        """Record (and optionally print) an ANSI-colored warning line."""
        orange = "\033[38;5;214m"
        reset = "\033[0m"
        rendered = f"{orange}WARNING: {msg % args if args else msg}{reset}"
        log_lines.append(rendered)
        if verbose:
            print(rendered)

    # Validate the input DataArray and pull out plain numpy arrays.
    (
        keypoints,
        from_idx,
        to_idx,
        from_name,
        to_name,
        _keypoint_names,
        num_frames,
    ) = _prepare_validation_inputs(data, from_node, to_node)

    n_keypoints = keypoints.shape[1]
    result: dict = {
        "success": False,
        "anterior_sign": 0,
        "vote_margin": 0.0,
        "resultant_length": 0.0,
        "num_selected_frames": 0,
        "num_clusters": 1,
        "primary_cluster": 0,
        "pair_report": _APNodePairReport(),
        "PC1": np.array([1.0, 0.0]),
        "PC2": np.array([0.0, 1.0]),
        "avg_skeleton": np.full((n_keypoints, 2), np.nan),
        "error_msg": "",
        "log_lines": log_lines,
    }

    # Stage 1: motion segmentation.
    seg = _run_motion_segmentation(
        keypoints, num_frames, config, _log_info, _log_warning
    )
    if seg is None:
        result["error_msg"] = "Motion segmentation failed."
        return result

    # Stage 2: tier-2 frame selection within the detected segments.
    t2 = _select_tier2_frames(
        seg["segments"],
        seg["tier2_valid"],
        num_frames,
        _log_info,
        _log_warning,
    )
    if t2 is None:
        result["error_msg"] = "Not enough tier-2 valid frames."
        return result
    selected_frames, selected_seg_id, num_selected = t2
    result["num_selected_frames"] = num_selected

    # Stage 3: centroid-centered skeletons for the selected frames.
    _selected_centroids, centered_skeletons = _build_centered_skeletons(
        keypoints, selected_frames
    )

    frame_sel = _FrameSelection(
        frames=selected_frames,
        seg_ids=selected_seg_id,
        segments=seg["segments"],
        bbox_centroids=seg["bbox_centroids"],
        count=num_selected,
    )

    # Stage 4: postural clustering + PCA + anterior inference.
    pca = _run_clustering_and_pca(
        centered_skeletons, frame_sel, config, _log_info, _log_warning
    )
    if pca is None:
        result["error_msg"] = "Primary cluster PCA failed."
        return result

    pr = pca["primary_result"]
    for key in (
        "anterior_sign",
        "vote_margin",
        "resultant_length",
        "circ_mean_dir",
        "vel_projs_pc1",
        "PC1",
        "PC2",
        "avg_skeleton",
    ):
        result[key] = pr[key]
    result["num_clusters"] = pca["num_clusters"]
    result["primary_cluster"] = pca["primary_cluster"]

    _log_anterior_report(
        pr,
        pca["cluster_results"],
        pca["num_clusters"],
        pca["primary_cluster"],
        config,
        _log_info,
        _log_warning,
    )

    # Stage 5: 3-step AP node-pair evaluation against the inferred axis.
    pair_report = _evaluate_ap_node_pair(
        pr["avg_skeleton"],
        pr["PC1"],
        pr["anterior_sign"],
        pr["valid_shape_rows"],
        from_idx,
        to_idx,
        config,
    )
    result["pair_report"] = pair_report

    _log_pair_evaluation(
        pair_report,
        config,
        from_idx,
        to_idx,
        from_name,
        to_name,
        _log_info,
    )

    result["success"] = True
    return result
b/tests/test_unit/test_kinematics/test_collective.py @@ -1,6 +1,8 @@ # test_collective.py """Tests for the collective behavior metrics module.""" +from typing import Any + import numpy as np import pytest import xarray as xr @@ -17,6 +19,33 @@ def _get_space_labels(n_space: int, space: list[str] | None) -> list[str]: raise ValueError("Provide explicit `space` labels for non-2D data.") +def _build_coords( + data: np.ndarray, + time: list | None, + space: list[str] | None, + individuals: list | None, + keypoints: list[str] | None = None, +) -> dict: + """Build coordinate dict for a position DataArray.""" + n_time, n_space = data.shape[0], data.shape[1] + coords: dict = { + "time": time if time is not None else list(range(n_time)), + "space": _get_space_labels(n_space, space), + } + if data.ndim == 4: + n_keypoints = data.shape[2] + coords["keypoints"] = keypoints or [ + f"kp_{i}" for i in range(n_keypoints) + ] + n_individuals = data.shape[3] + else: + n_individuals = data.shape[2] + coords["individuals"] = individuals or [ + f"id_{i}" for i in range(n_individuals) + ] + return coords + + def _make_position_dataarray( data: np.ndarray, *, @@ -27,41 +56,22 @@ def _make_position_dataarray( ) -> xr.DataArray: """Create a position DataArray for tests.""" data = np.asarray(data, dtype=float) - n_time, n_space = data.shape[0], data.shape[1] - if data.ndim == 3: - n_individuals = data.shape[2] - ind = individuals or [f"id_{i}" for i in range(n_individuals)] - return xr.DataArray( - data, - dims=["time", "space", "individuals"], - coords={ - "time": time if time else list(range(n_time)), - "space": _get_space_labels(n_space, space), - "individuals": ind, - }, - name="position", + dims_map = { + 3: ["time", "space", "individuals"], + 4: ["time", "space", "keypoints", "individuals"], + } + if data.ndim not in dims_map: + raise ValueError( + "Expected data with shape (time, space, individuals) or " + "(time, space, keypoints, individuals)." 
) - if data.ndim == 4: - n_keypoints, n_individuals = data.shape[2], data.shape[3] - kp = keypoints or [f"kp_{i}" for i in range(n_keypoints)] - ind = individuals or [f"id_{i}" for i in range(n_individuals)] - return xr.DataArray( - data, - dims=["time", "space", "keypoints", "individuals"], - coords={ - "time": time if time else list(range(n_time)), - "space": _get_space_labels(n_space, space), - "keypoints": kp, - "individuals": ind, - }, - name="position", - ) - - raise ValueError( - "Expected data with shape (time, space, individuals) or " - "(time, space, keypoints, individuals)." + return xr.DataArray( + data, + dims=dims_map[data.ndim], + coords=_build_coords(data, time, space, individuals, keypoints), + name="position", ) @@ -261,7 +271,7 @@ def test_body_axis_keypoints_must_be_distinct(self, keypoint_positions): ) @pytest.mark.parametrize( - "displacement_frames,expected_exception", + ("displacement_frames", "expected_exception"), [ (0, ValueError), (-1, ValueError), @@ -329,6 +339,64 @@ def test_empty_keypoints_dimension_raises_in_displacement_mode(self): kinematics.compute_polarization(data) +class TestValidateAPConfig: + """Tests for the _ValidateAPConfig dataclass parameter validation.""" + + @pytest.mark.parametrize( + ("field", "value"), + [ + ("min_valid_frac", -0.1), + ("min_valid_frac", 1.1), + ("window_len", 0), + ("window_len", -5), + ("window_len", 2.5), + ("stride", 0), + ("stride", -1), + ("stride", 1.5), + ("pct_thresh", -1), + ("pct_thresh", 101), + ("min_run_len", 0), + ("min_run_len", -1), + ("min_run_len", 1.5), + ("postural_var_ratio_thresh", 0), + ("postural_var_ratio_thresh", -1), + ("max_clusters", 0), + ("max_clusters", 2.5), + ("confidence_floor", -0.1), + ("confidence_floor", 1.1), + ("lateral_thresh", -0.1), + ("lateral_thresh", 1.1), + ("edge_thresh", -0.1), + ("edge_thresh", 1.1), + ], + ) + def test_invalid_config_values_raise(self, field: str, value: Any) -> None: + """Invalid config values should raise ValueError.""" + from 
movement.kinematics.collective import _ValidateAPConfig + + kwargs = {field: value} + with pytest.raises(ValueError, match="must be"): + _ValidateAPConfig(**kwargs) + + def test_valid_config_does_not_raise(self) -> None: + """Valid config values should not raise any error.""" + from movement.kinematics.collective import _ValidateAPConfig + + # Should not raise + _ValidateAPConfig( + min_valid_frac=0.5, + window_len=10, + stride=2, + pct_thresh=50.0, + min_run_len=2, + postural_var_ratio_thresh=1.5, + max_clusters=3, + confidence_floor=0.2, + lateral_thresh=0.3, + edge_thresh=0.2, + ) + + class TestComputePolarizationBehavior: """Tests for polarization computation behavior.""" @@ -612,113 +680,92 @@ def test_polarization_is_invariant_to_global_rotation( equal_nan=True, ) - def test_body_axis_invariance_to_translation_scaling_rotation( - self, - ): - """Body-axis polarization is invariant to translation/scaling/rotation. + def _body_axis_baseline(self): + """Build body-axis test data and return (da, pol, angle). - Mean body angle is invariant to translation and positive scaling, and - rotates by the same amount under global planar rotation. + Three individuals with body axes: +x, +x, +y. + Vector sum = (2, 1), polarization = sqrt(5)/3, angle = atan2(1,2). + Absolute positions differ across frames to test body-axis + independence from location. """ - # Three individuals with body axes: +x, +x, +y. - # This gives a nontrivial baseline: - # vector sum = (2, 1) - # polarization = sqrt(5) / 3 - # mean angle = atan2(1, 2) - # - # Absolute positions differ across frames to ensure we are really - # testing body-axis heading (target - origin), not any accidental - # dependence on absolute location. 
data = np.array( [ [ - [[0.0, 10.0, -2.0], [1.0, 11.0, -2.0]], # x - [[0.0, 5.0, 3.0], [0.0, 5.0, 4.0]], # y + [[0.0, 10.0, -2.0], [1.0, 11.0, -2.0]], + [[0.0, 5.0, 3.0], [0.0, 5.0, 4.0]], ], [ - [[100.0, 50.0, 7.0], [101.0, 51.0, 7.0]], # x - [[-1.0, 20.0, -3.0], [-1.0, 20.0, -2.0]], # y + [[100.0, 50.0, 7.0], [101.0, 51.0, 7.0]], + [[-1.0, 20.0, -3.0], [-1.0, 20.0, -2.0]], ], ], dtype=float, ) da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) - - pol_base, angle_base = kinematics.compute_polarization( + pol, angle = kinematics.compute_polarization( da, body_axis_keypoints=("tail_base", "neck"), return_angle=True, ) + return da, pol, angle - expected_pol = np.sqrt(5) / 3 - expected_angle = np.arctan2(1.0, 2.0) - - np.testing.assert_allclose(pol_base.values, expected_pol, atol=1e-10) + def test_body_axis_baseline_matches_expected_values(self): + """Body-axis polarization and angle match hand-computed values.""" + _da, pol, angle = self._body_axis_baseline() + np.testing.assert_allclose(pol.values, np.sqrt(5) / 3, atol=1e-10) np.testing.assert_allclose( - angle_base.values, expected_angle, atol=1e-10 + angle.values, np.arctan2(1.0, 2.0), atol=1e-10 ) - # Global translation: should not affect body-axis vectors. 
+ def test_body_axis_invariance_to_translation(self): + """Global translation does not change body-axis polarization.""" + da, pol_base, angle_base = self._body_axis_baseline() translated = da.copy() translated.loc[{"space": "x"}] = translated.sel(space="x") + 123.4 translated.loc[{"space": "y"}] = translated.sel(space="y") - 56.7 - pol_translated, angle_translated = kinematics.compute_polarization( + pol, angle = kinematics.compute_polarization( translated, body_axis_keypoints=("tail_base", "neck"), return_angle=True, ) + np.testing.assert_allclose(pol.values, pol_base.values, atol=1e-10) + np.testing.assert_allclose(angle.values, angle_base.values, atol=1e-10) - np.testing.assert_allclose( - pol_translated.values, pol_base.values, atol=1e-10 - ) - np.testing.assert_allclose( - angle_translated.values, angle_base.values, atol=1e-10 - ) + def test_body_axis_invariance_to_positive_scaling(self): + """Positive scaling preserves body-axis polarization and angle.""" + da, pol_base, angle_base = self._body_axis_baseline() - # Positive scaling: should preserve directions and therefore preserve - # polarization and angle. - scaled = da * 4.2 - - pol_scaled, angle_scaled = kinematics.compute_polarization( - scaled, + pol, angle = kinematics.compute_polarization( + da * 4.2, body_axis_keypoints=("tail_base", "neck"), return_angle=True, ) + np.testing.assert_allclose(pol.values, pol_base.values, atol=1e-10) + np.testing.assert_allclose(angle.values, angle_base.values, atol=1e-10) - np.testing.assert_allclose( - pol_scaled.values, pol_base.values, atol=1e-10 - ) - np.testing.assert_allclose( - angle_scaled.values, angle_base.values, atol=1e-10 - ) + def test_body_axis_angle_rotates_under_global_rotation(self): + """Polarization preserved, angle shifts by pi/2. - # Global 90-degree rotation: polarization magnitude should be - # unchanged, and mean angle should rotate by +pi/2 (with wraparound). + Tests behavior under 90-degree rotation. 
+ """ + da, pol_base, angle_base = self._body_axis_baseline() rotated = da.copy() x = da.sel(space="x") y = da.sel(space="y") rotated.loc[{"space": "x"}] = -y rotated.loc[{"space": "y"}] = x - pol_rotated, angle_rotated = kinematics.compute_polarization( + pol, angle = kinematics.compute_polarization( rotated, body_axis_keypoints=("tail_base", "neck"), return_angle=True, ) + np.testing.assert_allclose(pol.values, pol_base.values, atol=1e-10) - np.testing.assert_allclose( - pol_rotated.values, pol_base.values, atol=1e-10 - ) - - expected_rotated_angle = angle_base.values + (np.pi / 2) - expected_rotated_angle = ( - (expected_rotated_angle + np.pi) % (2 * np.pi) - ) - np.pi - - np.testing.assert_allclose( - angle_rotated.values, expected_rotated_angle, atol=1e-10 - ) + expected = angle_base.values + (np.pi / 2) + expected = (expected + np.pi) % (2 * np.pi) - np.pi + np.testing.assert_allclose(angle.values, expected, atol=1e-10) class TestHeadingSourceSelection: @@ -852,10 +899,8 @@ def test_first_n_frames_are_nan(self, aligned_positions): displacement_frames=2, return_angle=True, ) - assert np.isnan(polarization.values[0]) - assert np.isnan(polarization.values[1]) - assert np.isnan(mean_angle.values[0]) - assert np.isnan(mean_angle.values[1]) + assert np.all(np.isnan(polarization.values[:2])) + assert np.all(np.isnan(mean_angle.values[:2])) assert np.allclose(polarization.values[2:], 1.0, atol=1e-10) assert np.allclose(mean_angle.values[2:], 0.0, atol=1e-10) @@ -936,13 +981,15 @@ def test_return_angle_true_returns_named_pair(self, aligned_positions): ) assert isinstance(polarization, xr.DataArray) assert isinstance(mean_angle, xr.DataArray) - assert polarization.name == "polarization" - assert mean_angle.name == "mean_angle" + assert (polarization.name, mean_angle.name) == ( + "polarization", + "mean_angle", + ) assert polarization.dims == ("time",) assert mean_angle.dims == ("time",) @pytest.mark.parametrize( - "data,expected_angle,use_abs", + ("data", 
"expected_angle", "use_abs"), [ ( np.array( From 1bd1618e7adc11c56eeb7df63edaf30ffbd4f926 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Apr 2026 05:21:38 +0000 Subject: [PATCH 17/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- movement/kinematics/collective.py | 1 + 1 file changed, 1 insertion(+) diff --git a/movement/kinematics/collective.py b/movement/kinematics/collective.py index fa136bedb..3cb0b8f3c 100644 --- a/movement/kinematics/collective.py +++ b/movement/kinematics/collective.py @@ -500,6 +500,7 @@ def _compute_mean_angle( mean_angle = np.rad2deg(mean_angle) return mean_angle.rename("mean_angle") + def _run_ap_validation( data: xr.DataArray, normalized_keypoints: tuple[Hashable, Hashable], From 14e2674ae7edf118c16d753f0a7931dbdf6c7003 Mon Sep 17 00:00:00 2001 From: khan-u Date: Sat, 4 Apr 2026 01:27:21 -0700 Subject: [PATCH 18/21] refactor(collective) extract out AP validation into body_axis.py --- movement/kinematics/body_axis.py | 2910 ++++++++++++++++ movement/kinematics/collective.py | 2911 +---------------- .../test_kinematics/test_body_axis.py | 62 + .../test_kinematics/test_collective.py | 60 - 4 files changed, 2991 insertions(+), 2952 deletions(-) create mode 100644 movement/kinematics/body_axis.py create mode 100644 tests/test_unit/test_kinematics/test_body_axis.py diff --git a/movement/kinematics/body_axis.py b/movement/kinematics/body_axis.py new file mode 100644 index 000000000..054af965e --- /dev/null +++ b/movement/kinematics/body_axis.py @@ -0,0 +1,2910 @@ +"""Body-axis inference and anterior-posterior validation for pose data. + +This module provides infrastructure for validating user-supplied body-axis +keypoint pairs by inferring the anterior-posterior (AP) axis from motion +data. It uses a prior-free approach combining: + +1. High-motion segment detection via tiered validity and sliding windows +2. 
Postural clustering via k-medoids (when posture varies across segments) +3. PCA-based body-axis extraction from centered skeletons +4. Velocity projection voting to infer anterior direction +5. A 3-step filter cascade to evaluate candidate AP keypoint pairs + +""" + +from collections.abc import Hashable +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +import xarray as xr + +from movement.utils.logging import logger + +# Separator line for log output formatting +_LOG_SEPARATOR = "\u2500" * 60 + + +# Configuration and Data Classes +# ────────────────────────────── + + +@dataclass +class ValidateAPConfig: + """Configuration for the validate_ap function. + + Parameters + ---------- + min_valid_frac : float, default=0.6 + Minimum fraction of keypoints that must be present for a frame + to qualify as tier-1 valid. + window_len : int, default=50 + Number of speed samples per sliding window. + stride : int, default=5 + Step size between consecutive sliding window start positions. + pct_thresh : float, default=85.0 + Percentile threshold applied to valid-window median speeds for + high-motion classification. + min_run_len : int, default=1 + Minimum number of consecutive qualifying windows required to + form a valid run. + postural_var_ratio_thresh : float, default=2.0 + Between-segment to within-segment RMSD variance ratio above which + postural clustering is triggered. + max_clusters : int, default=4 + Upper bound on the number of clusters to evaluate during k-medoids. + confidence_floor : float, default=0.1 + Vote margin below which the anterior inference is flagged as + unreliable. + lateral_thresh : float, default=0.4 + Normalized lateral offset ceiling for the Step 1 lateral alignment + filter. + edge_thresh : float, default=0.3 + Normalized midpoint distance floor for the Step 3 distal/proximal + classification. 
+ + """ + + min_valid_frac: float = 0.6 + window_len: int = 50 + stride: int = 5 + pct_thresh: float = 85.0 + min_run_len: int = 1 + postural_var_ratio_thresh: float = 2.0 + max_clusters: int = 4 + confidence_floor: float = 0.1 + lateral_thresh: float = 0.4 + edge_thresh: float = 0.3 + + def __post_init__(self) -> None: + """Validate configuration parameters.""" + for name in ( + "min_valid_frac", + "confidence_floor", + "lateral_thresh", + "edge_thresh", + ): + value = getattr(self, name) + if not (0 <= value <= 1): + raise ValueError( + f"{name} must be between 0 and 1, got {value}" + ) + + for name in ("window_len", "stride", "min_run_len", "max_clusters"): + value = getattr(self, name) + if not isinstance(value, int) or value <= 0: + raise ValueError( + f"{name} must be a positive integer, got {value}" + ) + + if not (0 <= self.pct_thresh <= 100): + raise ValueError( + f"pct_thresh must be between 0 and 100, got {self.pct_thresh}" + ) + + if self.postural_var_ratio_thresh <= 0: + raise ValueError( + f"postural_var_ratio_thresh must be positive, " + f"got {self.postural_var_ratio_thresh}" + ) + + +@dataclass +class FrameSelection: + """Selected frames from high-motion segmentation and tier-2 filtering. + + Bundles the frame indices, segment assignments, and related arrays + produced by the segmentation pipeline for downstream consumption + (skeleton construction, postural clustering, velocity recomputation). + + Attributes + ---------- + frames : np.ndarray + Array of selected frame indices (tier-2 valid, within segments). + seg_ids : np.ndarray + Segment ID (0-indexed) for each selected frame. + segments : np.ndarray + Array of shape (n_segments, 2) with [frame_start, frame_end]. + bbox_centroids : np.ndarray + Array of shape (n_frames, 2) with bounding-box centroids. + count : int + Number of selected frames. 
+ + """ + + frames: np.ndarray + seg_ids: np.ndarray + segments: np.ndarray + bbox_centroids: np.ndarray + count: int + + +@dataclass +class APNodePairReport: + """Report from the AP node-pair evaluation pipeline. + + This dataclass holds all results from the 3-step filter cascade + used to evaluate a candidate anterior-posterior keypoint pair. + + Attributes + ---------- + success : bool + Whether the evaluation pipeline completed successfully. + failure_step : str + Name of the step at which evaluation failed, if any. + failure_reason : str + Reason for failure, if any. + scenario : int + Scenario number (1-13) from the mutually exclusive outcomes. + outcome : str + Either "accept" or "warn". + warning_message : str + Warning message, if applicable. + sorted_candidate_nodes : np.ndarray + Indices of candidate nodes after Step 1 filtering, sorted by + ascending normalized lateral offset. + valid_pairs : np.ndarray + Array of shape (n_pairs, 2) containing valid node pairs after + Step 2 filtering. + valid_pairs_internode_dist : np.ndarray + Internode separation (AP distance) for each valid pair. + input_pair_in_candidates : bool + Whether the input pair survived Step 1 filtering. + input_pair_opposite_sides : bool + Whether the input pair lies on opposite sides of the midpoint. + input_pair_separation_abs : float + Absolute AP separation of the input pair. + input_pair_is_distal : bool + Whether the input pair is classified as distal in Step 3. + input_pair_rank : int + Rank of the input pair by internode separation (1 = largest). + input_pair_order_matches_inference : bool + Whether from_node has a lower AP coordinate than to_node + (i.e. from_node is more posterior). True means the input pair + ordering is consistent with the inferred AP axis. + pc1_coords : np.ndarray + PC1 coordinates for each keypoint. + ap_coords : np.ndarray + AP (anterior-posterior) coordinates for each keypoint. 
+ lateral_offsets : np.ndarray + Unsigned lateral offset from body axis for each keypoint. + lateral_offsets_norm : np.ndarray + Normalized lateral offsets (0 = nearest to axis, 1 = farthest). + lateral_offset_min : float + Minimum lateral offset among valid keypoints. + lateral_offset_max : float + Maximum lateral offset among valid keypoints. + midpoint_pc1 : float + AP reference midpoint (average of min and max PC1 projections). + pc1_min : float + Minimum PC1 projection among valid keypoints. + pc1_max : float + Maximum PC1 projection among valid keypoints. + midline_dist_norm : np.ndarray + Normalized distance from midpoint for each keypoint. + midline_dist_max : float + Maximum absolute distance from midpoint. + distal_pairs : np.ndarray + Array of distal pairs (both nodes at or above edge_thresh). + proximal_pairs : np.ndarray + Array of proximal pairs (at least one node below edge_thresh). + max_separation_distal_nodes : np.ndarray + Node indices of the maximum-separation distal pair, ordered + so that element 0 is posterior (lower AP coord) and element 1 + is anterior (higher AP coord). + max_separation_distal : float + Internode separation of the max-separation distal pair. + max_separation_nodes : np.ndarray + Node indices of the overall maximum-separation pair, ordered + so that element 0 is posterior (lower AP coord) and element 1 + is anterior (higher AP coord). + max_separation : float + Internode separation of the overall max-separation pair. 
+ + """ + + success: bool = False + failure_step: str = "" + failure_reason: str = "" + scenario: int = 0 + outcome: str = "" + warning_message: str = "" + + sorted_candidate_nodes: np.ndarray = field( + default_factory=lambda: np.array([], dtype=int) + ) + valid_pairs: np.ndarray = field( + default_factory=lambda: np.zeros((0, 2), dtype=int) + ) + valid_pairs_internode_dist: np.ndarray = field( + default_factory=lambda: np.array([]) + ) + + input_pair_in_candidates: bool = False + input_pair_opposite_sides: bool = False + input_pair_separation_abs: float = np.nan + input_pair_is_distal: bool = False + input_pair_rank: int = 0 + input_pair_order_matches_inference: bool = False + + pc1_coords: np.ndarray = field(default_factory=lambda: np.array([])) + ap_coords: np.ndarray = field(default_factory=lambda: np.array([])) + lateral_offsets: np.ndarray = field(default_factory=lambda: np.array([])) + lateral_offsets_norm: np.ndarray = field( + default_factory=lambda: np.array([]) + ) + lateral_offset_min: float = np.nan + lateral_offset_max: float = np.nan + midpoint_pc1: float = np.nan + pc1_min: float = np.nan + pc1_max: float = np.nan + midline_dist_norm: np.ndarray = field(default_factory=lambda: np.array([])) + midline_dist_max: float = np.nan + + distal_pairs: np.ndarray = field( + default_factory=lambda: np.zeros((0, 2), dtype=int) + ) + proximal_pairs: np.ndarray = field( + default_factory=lambda: np.zeros((0, 2), dtype=int) + ) + max_separation_distal_nodes: np.ndarray = field( + default_factory=lambda: np.array([], dtype=int) + ) + max_separation_distal: float = np.nan + max_separation_nodes: np.ndarray = field( + default_factory=lambda: np.array([], dtype=int) + ) + max_separation: float = np.nan + + +# Tiered Validity and Centroid Computation +# ───────────────────────────────────────── + + +def compute_tiered_validity( + keypoints: np.ndarray, + min_valid_frac: float, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute tiered validity masks for each 
frame. + + Parameters + ---------- + keypoints : np.ndarray + Keypoint positions with shape (n_frames, n_keypoints, 2). + min_valid_frac : float + Minimum fraction of keypoints required for tier-1 validity. + + Returns + ------- + tier1_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-1 valid frames. + A frame is tier-1 valid if at least min_valid_frac of keypoints + are present AND at least 2 keypoints are present. + tier2_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-2 valid frames. + A frame is tier-2 valid if all keypoints are present. + frac_present : np.ndarray + Array of shape (n_frames,) with fraction of keypoints present. + + """ + _, n_keypoints, _ = keypoints.shape + + keypoint_present = ~np.any(np.isnan(keypoints), axis=2) + n_present = np.sum(keypoint_present, axis=1) + frac_present = n_present / n_keypoints + + tier2_valid = n_present == n_keypoints + tier1_valid = (frac_present >= min_valid_frac) & (n_present >= 2) + + return tier1_valid, tier2_valid, frac_present + + +def compute_bbox_centroid( + keypoints: np.ndarray, + tier1_valid: np.ndarray, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute bounding-box centroids for tier-1 valid frames. + + The bounding-box centroid is the midpoint of the axis-aligned bounding + box enclosing all present keypoints. This is density-invariant, unlike + the arithmetic mean. + + Parameters + ---------- + keypoints : np.ndarray + Keypoint positions with shape (n_frames, n_keypoints, 2). + tier1_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-1 valid frames. + + Returns + ------- + bbox_centroids : np.ndarray + Array of shape (n_frames, 2) with bounding-box centroids. + NaN for non-tier-1-valid frames. + arith_centroids : np.ndarray + Array of shape (n_frames, 2) with arithmetic-mean centroids. + NaN for non-tier-1-valid frames. Used for diagnostic comparison. 
+ centroid_discrepancy : np.ndarray + Array of shape (n_frames,) with normalized discrepancy between + bbox and arithmetic centroids (distance / bbox_diagonal). + NaN for non-tier-1-valid frames. + + """ + n_frames = keypoints.shape[0] + + bbox_centroids = np.full((n_frames, 2), np.nan) + arith_centroids = np.full((n_frames, 2), np.nan) + centroid_discrepancy = np.full(n_frames, np.nan) + + for f in range(n_frames): + if not tier1_valid[f]: + continue + + kp_f = keypoints[f] + present_mask = ~np.any(np.isnan(kp_f), axis=1) + kp_present = kp_f[present_mask] + + bbox_min = np.min(kp_present, axis=0) + bbox_max = np.max(kp_present, axis=0) + bbox_centroids[f] = (bbox_min + bbox_max) / 2 + + arith_centroids[f] = np.mean(kp_present, axis=0) + + bbox_diag = np.linalg.norm(bbox_max - bbox_min) + if bbox_diag > 0: + discrepancy = np.linalg.norm( + bbox_centroids[f] - arith_centroids[f] + ) + centroid_discrepancy[f] = discrepancy / bbox_diag + else: + centroid_discrepancy[f] = 0.0 + + return bbox_centroids, arith_centroids, centroid_discrepancy + + +# Velocity and Motion Detection +# ────────────────────────────── + + +def compute_frame_velocities( + bbox_centroids: np.ndarray, + tier1_valid: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Compute frame-to-frame centroid velocities and speeds. + + A velocity is valid only when both adjacent frames are tier-1 valid. + + Parameters + ---------- + bbox_centroids : np.ndarray + Array of shape (n_frames, 2) with bounding-box centroids. + tier1_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-1 valid frames. + + Returns + ------- + velocities : np.ndarray + Array of shape (n_frames - 1, 2) with velocity vectors. + Invalid velocities are NaN. + speeds : np.ndarray + Array of shape (n_frames - 1,) with speed scalars. + Invalid speeds are NaN. 
+ + """ + velocities = np.diff(bbox_centroids, axis=0) + speed_valid = tier1_valid[:-1] & tier1_valid[1:] + velocities[~speed_valid] = np.nan + speeds = np.linalg.norm(velocities, axis=1) + + return velocities, speeds + + +def compute_sliding_window_medians( + speeds: np.ndarray, + window_len: int, + stride: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute median speeds for sliding windows. + + A window is valid only when every speed sample in that window is valid + (non-NaN), ensuring strict NaN-free content. + + Parameters + ---------- + speeds : np.ndarray + Array of shape (n_speed_samples,) with speed values. + window_len : int + Number of speed samples per sliding window. + stride : int + Step size between consecutive window start positions. + + Returns + ------- + window_starts : np.ndarray + Array of window start indices (0-indexed). + window_medians : np.ndarray + Median speed for each window. NaN for invalid windows. + window_all_valid : np.ndarray + Boolean array indicating which windows are fully valid. + + """ + num_speed = len(speeds) + window_starts = np.arange(0, num_speed - window_len + 1, stride) + num_windows = len(window_starts) + + window_medians = np.full(num_windows, np.nan) + window_all_valid = np.zeros(num_windows, dtype=bool) + + for k in range(num_windows): + s = window_starts[k] + e = s + window_len + w = speeds[s:e] + + if np.all(~np.isnan(w)): + window_all_valid[k] = True + window_medians[k] = np.median(w) + + return window_starts, window_medians, window_all_valid + + +def detect_high_motion_windows( + window_medians: np.ndarray, + window_all_valid: np.ndarray, + pct_thresh: float, +) -> np.ndarray: + """Identify high-motion windows based on percentile threshold. + + Parameters + ---------- + window_medians : np.ndarray + Median speed for each window. + window_all_valid : np.ndarray + Boolean array indicating which windows are fully valid. 
+ pct_thresh : float + Percentile threshold (0-100) for high-motion classification. + + Returns + ------- + high_motion : np.ndarray + Boolean array indicating high-motion windows. + + """ + valid_medians = window_medians[window_all_valid] + if len(valid_medians) == 0: + return np.zeros(len(window_medians), dtype=bool) + + thresh = np.percentile(valid_medians, pct_thresh) + high_motion = window_all_valid & (window_medians >= thresh) + + return high_motion + + +# Run and Segment Detection +# ────────────────────────── + + +def detect_runs( + high_motion: np.ndarray, + min_run_len: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Detect runs of consecutive high-motion windows. + + A run is a maximal sequence of consecutively indexed qualifying windows. + + Parameters + ---------- + high_motion : np.ndarray + Boolean array indicating high-motion windows. + min_run_len : int + Minimum number of consecutive qualifying windows for a valid run. + + Returns + ------- + run_starts : np.ndarray + Start indices of valid runs. + run_ends : np.ndarray + End indices (inclusive) of valid runs. + run_lengths : np.ndarray + Length of each valid run. + + """ + padded = np.concatenate([[False], high_motion, [False]]) + d = np.diff(padded.astype(int)) + + run_starts_all = np.nonzero(d == 1)[0] + run_ends_all = np.nonzero(d == -1)[0] - 1 + run_lengths_all = run_ends_all - run_starts_all + 1 + + valid_mask = run_lengths_all >= min_run_len + run_starts = run_starts_all[valid_mask] + run_ends = run_ends_all[valid_mask] + run_lengths = run_lengths_all[valid_mask] + + return run_starts, run_ends, run_lengths + + +def convert_runs_to_segments( + run_starts: np.ndarray, + run_ends: np.ndarray, + window_starts: np.ndarray, + window_len: int, +) -> np.ndarray: + """Convert window runs to frame segments. + + Each run is converted to a frame interval spanning from the start frame + of the first window to the end frame of the last window. 
+ + Parameters + ---------- + run_starts : np.ndarray + Start indices of valid runs (indices into window arrays). + run_ends : np.ndarray + End indices (inclusive) of valid runs. + window_starts : np.ndarray + Start frame indices for each window. + window_len : int + Length of each window in frames. + + Returns + ------- + segments_raw : np.ndarray + Array of shape (n_runs, 2) with [frame_start, frame_end] for each run. + + """ + n_runs = len(run_starts) + segments_raw = np.zeros((n_runs, 2), dtype=int) + + for j in range(n_runs): + s_idx = run_starts[j] + e_idx = run_ends[j] + frame_start = window_starts[s_idx] + frame_end = window_starts[e_idx] + window_len + segments_raw[j] = [frame_start, frame_end] + + return segments_raw + + +def merge_segments(segments_raw: np.ndarray) -> np.ndarray: + """Merge overlapping or abutting frame segments. + + Segments are first sorted by start frame, then merged if they overlap + or abut (next start <= current end + 1). + + Parameters + ---------- + segments_raw : np.ndarray + Array of shape (n_segments, 2) with [frame_start, frame_end]. + + Returns + ------- + segments : np.ndarray + Array of merged non-overlapping segments. + + """ + if len(segments_raw) == 0: + return segments_raw + + sorted_idx = np.argsort(segments_raw[:, 0]) + segments_sorted = segments_raw[sorted_idx] + + merged = [segments_sorted[0].tolist()] + + for j in range(1, len(segments_sorted)): + next_seg = segments_sorted[j] + curr_seg = merged[-1] + + if next_seg[0] <= curr_seg[1] + 1: + merged[-1][1] = max(curr_seg[1], next_seg[1]) + else: + merged.append(next_seg.tolist()) + + return np.array(merged, dtype=int) + + +def filter_segments_tier2( + segments: np.ndarray, + tier2_valid: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Filter segment frames to retain only tier-2 valid frames. + + Parameters + ---------- + segments : np.ndarray + Array of shape (n_segments, 2) with [frame_start, frame_end]. 
+ tier2_valid : np.ndarray + Boolean array of shape (n_frames,) indicating tier-2 valid frames. + + Returns + ------- + selected_frames : np.ndarray + Array of tier-2 valid frame indices within segments. + selected_seg_id : np.ndarray + Segment ID (0-indexed) for each selected frame. + + """ + all_segment_frames: list[int] = [] + for k in range(len(segments)): + frame_start, frame_end = segments[k] + seg_frames = np.arange(frame_start, frame_end + 1) + all_segment_frames.extend(seg_frames) + + segment_frames_all = np.unique(all_segment_frames) + + tier2_mask = tier2_valid[segment_frames_all] + selected_frames = segment_frames_all[tier2_mask] + + num_selected = len(selected_frames) + selected_seg_id = np.zeros(num_selected, dtype=int) + + for j in range(num_selected): + f = selected_frames[j] + for k in range(len(segments)): + if segments[k, 0] <= f <= segments[k, 1]: + selected_seg_id[j] = k + break + + return selected_frames, selected_seg_id + + +# Skeleton Analysis +# ────────────────── + + +def build_centered_skeletons( + keypoints: np.ndarray, + selected_frames: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Build centroid-centered skeletons for selected frames. + + Uses bounding-box centroid for centering, consistent with the + segmentation centroid. + + Parameters + ---------- + keypoints : np.ndarray + Keypoint positions with shape (n_frames, n_keypoints, 2). + selected_frames : np.ndarray + Array of selected frame indices. + + Returns + ------- + selected_centroids : np.ndarray + Array of shape (num_selected, 2) with bounding-box centroids. + centered_skeletons : np.ndarray + Array of shape (num_selected, n_keypoints, 2) with + centroid-centered skeleton coordinates. 
+ + """ + num_selected = len(selected_frames) + n_keypoints = keypoints.shape[1] + + selected_centroids = np.zeros((num_selected, 2)) + centered_skeletons = np.zeros((num_selected, n_keypoints, 2)) + + for j in range(num_selected): + f = selected_frames[j] + kp_f = keypoints[f] + + bbox_min = np.min(kp_f, axis=0) + bbox_max = np.max(kp_f, axis=0) + centroid_f = (bbox_min + bbox_max) / 2 + + selected_centroids[j] = centroid_f + centered_skeletons[j] = kp_f - centroid_f + + return selected_centroids, centered_skeletons + + +def compute_pairwise_rmsd(centered_skeletons: np.ndarray) -> np.ndarray: + """Compute pairwise RMSD between all centered skeletons. + + RMSD is computed as the square root of the mean of squared entry-wise + differences between flattened skeleton vectors. + + Parameters + ---------- + centered_skeletons : np.ndarray + Array of shape (num_selected, n_keypoints, 2). + + Returns + ------- + rmsd_matrix : np.ndarray + Symmetric matrix of shape (num_selected, num_selected) with + pairwise RMSD values. Diagonal is zero. + + """ + num_selected = len(centered_skeletons) + skel_flat = centered_skeletons.reshape(num_selected, -1) + rmsd_matrix = np.zeros((num_selected, num_selected)) + + for i in range(num_selected): + for j in range(i + 1, num_selected): + d = skel_flat[i] - skel_flat[j] + rmsd_val = np.sqrt(np.mean(d**2)) + rmsd_matrix[i, j] = rmsd_val + rmsd_matrix[j, i] = rmsd_val + + return rmsd_matrix + + +def compute_postural_variance_ratio( + rmsd_matrix: np.ndarray, + selected_seg_id: np.ndarray, +) -> tuple[float, np.ndarray, np.ndarray, bool]: + """Compute the between/within segment RMSD variance ratio. + + Parameters + ---------- + rmsd_matrix : np.ndarray + Pairwise RMSD matrix of shape (num_selected, num_selected). + selected_seg_id : np.ndarray + Segment ID for each selected frame. + + Returns + ------- + var_ratio : float + Ratio of between-segment to within-segment RMSD variance. 
+ Returns 0.0 if either distribution is empty or within variance is 0. + within_rmsds : np.ndarray + Array of within-segment RMSD values. + between_rmsds : np.ndarray + Array of between-segment RMSD values. + var_ratio_override : bool + True if variance ratio was set to 0 due to edge cases. + + """ + num_selected = len(selected_seg_id) + within_rmsds_list: list[float] = [] + between_rmsds_list: list[float] = [] + + for i in range(num_selected): + for j in range(i + 1, num_selected): + if selected_seg_id[i] == selected_seg_id[j]: + within_rmsds_list.append(rmsd_matrix[i, j]) + else: + between_rmsds_list.append(rmsd_matrix[i, j]) + + within_rmsds = np.array(within_rmsds_list) + between_rmsds = np.array(between_rmsds_list) + + var_ratio_override = False + if ( + len(within_rmsds) > 0 + and len(between_rmsds) > 0 + and np.var(within_rmsds) > 0 + ): + var_ratio = np.var(between_rmsds) / np.var(within_rmsds) + else: + var_ratio = 0.0 + var_ratio_override = True + + return var_ratio, within_rmsds, between_rmsds, var_ratio_override + + +# Clustering (k-medoids) +# ─────────────────────── + + +def _update_medoid_for_cluster( + cluster: int, + labels: np.ndarray, + medoids: np.ndarray, + dist_matrix: np.ndarray, +) -> int: + """Find the optimal medoid for a single cluster.""" + cluster_mask = labels == cluster + if not np.any(cluster_mask): + return medoids[cluster] + + cluster_indices = np.nonzero(cluster_mask)[0] + cluster_dists = dist_matrix[np.ix_(cluster_indices, cluster_indices)] + total_dists = np.sum(cluster_dists, axis=1) + best_idx = np.argmin(total_dists) + return cluster_indices[best_idx] + + +def kmedoids( + data: np.ndarray, + k: int, + max_iter: int = 100, + n_init: int = 5, + random_state: int | None = None, +) -> tuple[np.ndarray, np.ndarray, float]: + """Perform k-medoids clustering. + + Parameters + ---------- + data : np.ndarray + Array of shape (n_samples, n_features). + k : int + Number of clusters. 
+ max_iter : int, default=100 + Maximum number of iterations. + n_init : int, default=5 + Number of random initializations. + random_state : int, optional + Random seed for reproducibility. + + Returns + ------- + labels : np.ndarray + Cluster labels for each sample (0-indexed). + medoid_indices : np.ndarray + Indices of medoid samples. + inertia : float + Sum of distances from samples to their medoids. + + """ + from scipy.spatial.distance import cdist + + rng = np.random.default_rng(random_state) + n_samples = len(data) + + dist_matrix = cdist(data, data, metric="euclidean") + + best_labels: np.ndarray | None = None + best_medoids: np.ndarray | None = None + best_inertia = np.inf + + for _ in range(n_init): + medoids = rng.choice(n_samples, size=k, replace=False) + + for _ in range(max_iter): + distances_to_medoids = dist_matrix[:, medoids] + labels = np.argmin(distances_to_medoids, axis=1) + + new_medoids = np.array( + [ + _update_medoid_for_cluster(c, labels, medoids, dist_matrix) + for c in range(k) + ] + ) + + if np.array_equal(np.sort(medoids), np.sort(new_medoids)): + break + medoids = new_medoids + + distances_to_medoids = dist_matrix[:, medoids] + labels = np.argmin(distances_to_medoids, axis=1) + inertia = np.sum(distances_to_medoids[np.arange(n_samples), labels]) + + if inertia < best_inertia: + best_inertia = inertia + best_labels = labels.copy() + best_medoids = medoids.copy() + + assert best_labels is not None and best_medoids is not None + return best_labels, best_medoids, best_inertia + + +def _compute_intra_cluster_dist( + i: int, + labels: np.ndarray, + dist_matrix: np.ndarray, + n_samples: int, +) -> float: + """Compute mean distance from sample i to other samples in its cluster.""" + own_cluster = labels[i] + own_mask = labels == own_cluster + if np.sum(own_mask) > 1: + return np.mean(dist_matrix[i, own_mask & (np.arange(n_samples) != i)]) + return 0.0 + + +def _compute_nearest_cluster_dist( + i: int, + labels: np.ndarray, + dist_matrix: 
def silhouette_score(data: np.ndarray, labels: np.ndarray) -> float:
    """Compute mean silhouette score.

    Parameters
    ----------
    data : np.ndarray
        Array of shape (n_samples, n_features).
    labels : np.ndarray
        Cluster labels for each sample.

    Returns
    -------
    score : float
        Mean silhouette score across all samples.
        Returns 0.0 if clustering is degenerate.

    """
    from scipy.spatial.distance import cdist

    n_samples = len(data)
    unique_labels = np.unique(labels)

    # Degenerate clusterings (a single cluster, or at least one cluster
    # per sample) carry no separation information.
    if len(unique_labels) <= 1 or len(unique_labels) >= n_samples:
        return 0.0

    dist_matrix = cdist(data, data, metric="euclidean")
    sample_indices = np.arange(n_samples)
    scores = np.zeros(n_samples)

    for idx in range(n_samples):
        own = labels[idx]

        # a: mean distance to the other members of idx's own cluster
        # (0.0 for a singleton cluster).
        peers = (labels == own) & (sample_indices != idx)
        a_val = np.mean(dist_matrix[idx, peers]) if np.any(peers) else 0.0

        # b: mean distance to the nearest other (non-empty) cluster.
        b_val = np.inf
        for other in unique_labels:
            if other == own:
                continue
            members = labels == other
            if np.any(members):
                b_val = min(b_val, np.mean(dist_matrix[idx, members]))

        denom = max(a_val, b_val)
        if b_val == np.inf or denom <= 0:
            scores[idx] = 0.0
        else:
            scores[idx] = (b_val - a_val) / denom

    return float(np.mean(scores))
+ + Returns + ------- + cluster_labels : np.ndarray + Cluster labels for each frame (0-indexed). + num_clusters : int + Number of clusters (1 if clustering not accepted). + primary_cluster : int + Index of largest cluster (0-indexed). + best_silhouette : float + Best silhouette score achieved. + silhouette_scores : list of (k, score) + Silhouette scores for each k evaluated. + + """ + num_selected = len(centered_skeletons) + skel_flat = centered_skeletons.reshape(num_selected, -1) + + best_k = 1 + best_sil = -np.inf + silhouette_scores = [] + + max_k = min(max_clusters, num_selected // 2) + + for k in range(2, max_k + 1): + try: + labels, _, _ = kmedoids(skel_flat, k, n_init=5) + sil = silhouette_score(skel_flat, labels) + silhouette_scores.append((k, sil)) + + if sil > best_sil: + best_sil = sil + best_k = k + except Exception: + silhouette_scores.append((k, np.nan)) + + if best_k > 1 and best_sil > min_silhouette: + cluster_labels, _, _ = kmedoids(skel_flat, best_k, n_init=10) + num_clusters = best_k + + cluster_counts = np.bincount(cluster_labels, minlength=num_clusters) + primary_cluster = int(np.argmax(cluster_counts)) + else: + cluster_labels = np.zeros(num_selected, dtype=int) + num_clusters = 1 + primary_cluster = 0 + + return ( + cluster_labels, + num_clusters, + primary_cluster, + best_sil, + silhouette_scores, + ) + + +# PCA and Anterior Inference +# ─────────────────────────── + + +def compute_cluster_velocities( + selected_frames: np.ndarray, + selected_seg_id: np.ndarray, + cluster_mask: np.ndarray, + segments: np.ndarray, + bbox_centroids: np.ndarray, +) -> np.ndarray: + """Compute velocities between adjacent consecutive frames. + + Only considers frames in the same segment and cluster. Frame pairs + where both frames are consecutive (frame[i] == frame[i-1] + 1), + in the same segment, and in the same cluster contribute a velocity + vector. + + Returns + ------- + np.ndarray + Array of shape (n_velocities, 2). Empty (0, 2) if no valid pairs. 
+ + """ + frames_c = selected_frames[cluster_mask] + seg_ids_c = selected_seg_id[cluster_mask] + velocities_list: list[np.ndarray] = [] + + for seg_k in range(len(segments)): + seg_mask = seg_ids_c == seg_k + seg_frames = np.sort(frames_c[seg_mask]) + for fi in range(1, len(seg_frames)): + if seg_frames[fi] != seg_frames[fi - 1] + 1: + continue + curr_frame = seg_frames[fi] + prev_frame = seg_frames[fi - 1] + v = bbox_centroids[curr_frame] - bbox_centroids[prev_frame] + if np.all(~np.isnan(v)): + velocities_list.append(v) + + return np.array(velocities_list) if velocities_list else np.zeros((0, 2)) + + +def infer_anterior_from_velocities( + velocities: np.ndarray, + pc1: np.ndarray, +) -> dict: + """Infer anterior direction from velocity projections onto PC1. + + Uses strict majority vote on PC1 projection signs: anterior = +PC1 + if n_positive > n_negative, else -PC1 (ties default to -PC1). + + Also computes circular statistics on velocity angles: + - resultant_length R = sqrt(C^2 + S^2) where C = mean(cos theta), + S = mean(sin theta) + - vote_margin M = |n+ - n-| / (n+ + n-) + + Returns dict with resultant_length, circ_mean_dir, vel_projs_pc1, + num_positive, num_negative, vote_margin, anterior_sign. 
+ + """ + result: dict = { + "resultant_length": 0.0, + "circ_mean_dir": np.nan, + "vel_projs_pc1": np.array([]), + "num_positive": 0, + "num_negative": 0, + "vote_margin": 0.0, + "anterior_sign": -1, + } + if len(velocities) == 0: + return result + + vel_angles = np.arctan2(velocities[:, 1], velocities[:, 0]) + cos_mean = np.mean(np.cos(vel_angles)) + sin_mean = np.mean(np.sin(vel_angles)) + result["resultant_length"] = np.sqrt(cos_mean**2 + sin_mean**2) + result["circ_mean_dir"] = np.arctan2(sin_mean, cos_mean) + + vel_projs = velocities @ pc1 + num_pos = int(np.sum(vel_projs > 0)) + num_neg = int(np.sum(vel_projs < 0)) + result["vel_projs_pc1"] = vel_projs + result["num_positive"] = num_pos + result["num_negative"] = num_neg + result["vote_margin"] = abs(num_pos - num_neg) / max(num_pos + num_neg, 1) + result["anterior_sign"] = +1 if num_pos > num_neg else -1 + return result + + +def compute_cluster_pca_and_anterior( + centered_skeletons: np.ndarray, + cluster_mask: np.ndarray, + selected_frames: np.ndarray, + selected_seg_id: np.ndarray, + segments: np.ndarray, + bbox_centroids: np.ndarray, +) -> dict: + """Compute SVD-based PCA and velocity-based anterior inference. + + Performs inference for one cluster. + + Performs SVD on the cluster's average centered skeleton to extract PC1/PC2, + applies the geometric sign convention, then infers the anterior direction + via velocity voting on centroid displacements projected onto PC1. + + Returns + ------- + dict + Keys: valid, n_frames, avg_skeleton, valid_shape_rows, + PC1, PC2, anterior_sign, vote_margin, resultant_length, + circ_mean_dir, velocities, vel_projs_pc1, and others. 
+ + """ + n_keypoints = centered_skeletons.shape[1] + n_c = int(np.sum(cluster_mask)) + + result: dict = { + "valid": False, + "n_frames": n_c, + "avg_skeleton": np.full((n_keypoints, 2), np.nan), + "valid_shape_rows": np.zeros(n_keypoints, dtype=bool), + "PC1": np.array([1.0, 0.0]), + "PC2": np.array([0.0, 1.0]), + "proj_pc1": np.full(n_keypoints, np.nan), + "proj_pc2": np.full(n_keypoints, np.nan), + "anterior_sign": -1, + "num_positive": 0, + "num_negative": 0, + "vote_margin": 0.0, + "resultant_length": 0.0, + "circ_mean_dir": np.nan, + "velocities": np.zeros((0, 2)), + "vel_projs_pc1": np.array([]), + } + + if n_c == 0: + return result + + skels_c = centered_skeletons[cluster_mask] + avg_skel_c = np.mean(skels_c, axis=0) + valid_shape_rows = ~np.any(np.isnan(avg_skel_c), axis=1) + + if np.sum(valid_shape_rows) < 2: + return result + + result["avg_skeleton"] = avg_skel_c + result["valid_shape_rows"] = valid_shape_rows + + valid_rows = avg_skel_c[valid_shape_rows] + _u, _s, vt = np.linalg.svd(valid_rows, full_matrices=False) + PC1 = vt[0] + PC2 = vt[1] if len(vt) > 1 else np.array([0.0, 1.0]) + + # Geometric sign convention: + # PC1 flipped so y-component >= 0 + # PC2 flipped so x-component >= 0 + if PC1[1] < 0: + PC1 = -PC1 + if PC2[0] < 0: + PC2 = -PC2 + + result["PC1"] = PC1 + result["PC2"] = PC2 + + proj_pc1 = np.full(n_keypoints, np.nan) + proj_pc2 = np.full(n_keypoints, np.nan) + proj_pc1[valid_shape_rows] = avg_skel_c[valid_shape_rows] @ PC1 + proj_pc2[valid_shape_rows] = avg_skel_c[valid_shape_rows] @ PC2 + result["proj_pc1"] = proj_pc1 + result["proj_pc2"] = proj_pc2 + + velocities = compute_cluster_velocities( + selected_frames, + selected_seg_id, + cluster_mask, + segments, + bbox_centroids, + ) + result["velocities"] = velocities + result.update(infer_anterior_from_velocities(velocities, PC1)) + result["valid"] = True + return result + + +# AP Node-Pair Evaluation (3-Step Filter Cascade) +# ──────────────────────────────────────────────── + + +def 
compute_node_projections( + report: APNodePairReport, + avg_skeleton: np.ndarray, + pc1_vec: np.ndarray, + anterior_sign: int, + valid_shape_rows: np.ndarray, + from_node: int, + to_node: int, +) -> None: + """Compute raw PC1, AP-oriented, and lateral projections. + + Computes projections for all valid keypoints. + + Populates the report's coordinate arrays and determines: + - pc1_coords: raw projection onto PC1 (sign-convention only) + - ap_coords: projection onto anterior_sign * PC1 (positive = more + anterior) + - lateral_offsets: unsigned distance from the AP axis + - midpoint_pc1: average of min/max PC1 projections (AP reference point) + - input_pair_order_matches_inference: True if from_node's AP coord < + to_node's + + """ + pc1 = pc1_vec / np.linalg.norm(pc1_vec) + e_ap = anterior_sign * pc1 + e_lat = np.array([-e_ap[1], e_ap[0]]) + + valid_rows = avg_skeleton[valid_shape_rows] + report.pc1_coords[valid_shape_rows] = valid_rows @ pc1 + report.ap_coords[valid_shape_rows] = valid_rows @ e_ap + report.lateral_offsets[valid_shape_rows] = np.abs(valid_rows @ e_lat) + + if valid_shape_rows[from_node] and valid_shape_rows[to_node]: + report.input_pair_order_matches_inference = ( + report.ap_coords[from_node] < report.ap_coords[to_node] + ) + + proj_pc1_valid = report.pc1_coords[valid_shape_rows] + report.pc1_min = float(np.min(proj_pc1_valid)) + report.pc1_max = float(np.max(proj_pc1_valid)) + report.midpoint_pc1 = (report.pc1_min + report.pc1_max) / 2 + + +def apply_lateral_filter( + report: APNodePairReport, + valid_idx: np.ndarray, + lateral_thresh: float, +) -> np.ndarray | None: + """Step 1: Filter keypoints by normalized lateral offset. + + Returns sorted candidate node indices, or None on failure. 
+ + """ + d_valid = report.lateral_offsets[valid_idx] + d_min = float(np.min(d_valid)) + d_max = float(np.max(d_valid)) + report.lateral_offset_min = d_min + report.lateral_offset_max = d_max + + if d_max > d_min: + d_norm = (d_valid - d_min) / (d_max - d_min) + report.lateral_offsets_norm[valid_idx] = d_norm + keep_mask = d_norm <= lateral_thresh + else: + report.lateral_offsets_norm[valid_idx] = np.zeros(len(d_valid)) + keep_mask = np.ones(len(d_valid), dtype=bool) + + candidate_idx = np.nonzero(keep_mask)[0] + candidates = valid_idx[candidate_idx] + sorted_order = np.argsort(d_valid[candidate_idx]) + candidates = candidates[sorted_order] + report.sorted_candidate_nodes = candidates.copy() + + if len(candidates) < 2: + report.failure_step = "Step 1: lateral alignment filter" + report.failure_reason = ( + "Fewer than 2 candidates remained after filtering." + ) + return None + return candidates + + +def find_opposite_side_pairs( + report: APNodePairReport, + candidates: np.ndarray, + from_node: int, + to_node: int, + valid_shape_rows: np.ndarray, +) -> tuple[np.ndarray, np.ndarray] | None: + """Step 2: Find candidate pairs on opposite sides of the AP midpoint. + + Returns (pairs, seps) arrays, or None on failure. 
def order_pair_by_ap(
    pair: np.ndarray,
    ap_coords: np.ndarray,
) -> np.ndarray:
    """Order a node pair so element 0 is posterior (lower AP coord).

    This ensures that suggested pairs always encode the
    posterior->anterior direction, matching the convention used by
    ``body_axis_keypoints=(from_node, to_node)`` where from_node is
    posterior and to_node is anterior.

    """
    first, second = pair
    # Keep the original <= comparison so ties (and any NaN coords)
    # resolve exactly as before.
    if ap_coords[first] <= ap_coords[second]:
        ordered = (first, second)
    else:
        ordered = (second, first)
    return np.array(ordered, dtype=int)
Returns pair_is_distal.""" + m = report.midpoint_pc1 + midline_dist = np.abs(report.pc1_coords - m) + d_max_midline = float(np.nanmax(midline_dist[valid_shape_rows])) + report.midline_dist_max = d_max_midline + + if d_max_midline > 0: + report.midline_dist_norm = midline_dist / d_max_midline + else: + report.midline_dist_norm = np.zeros(len(report.pc1_coords)) + + pair_is_distal = np.zeros(len(pairs), dtype=bool) + for k in range(len(pairs)): + i, j = pairs[k] + pair_is_distal[k] = ( + min(report.midline_dist_norm[i], report.midline_dist_norm[j]) + >= edge_thresh + ) + + report.distal_pairs = pairs[pair_is_distal] + report.proximal_pairs = pairs[~pair_is_distal] + + if len(seps) > 0: + idx_max = int(np.argmax(seps)) + report.max_separation_nodes = order_pair_by_ap( + pairs[idx_max], report.ap_coords + ) + report.max_separation = seps[idx_max] + + if np.any(pair_is_distal): + distal_seps = seps[pair_is_distal] + distal_pairs_only = pairs[pair_is_distal] + idx_max_distal = int(np.argmax(distal_seps)) + report.max_separation_distal_nodes = order_pair_by_ap( + distal_pairs_only[idx_max_distal], report.ap_coords + ) + report.max_separation_distal = distal_seps[idx_max_distal] + + return pair_is_distal + + +def check_input_pair_in_valid( + report: APNodePairReport, + pairs: np.ndarray, + seps: np.ndarray, + pair_is_distal: np.ndarray, + from_node: int, + to_node: int, +) -> tuple[bool, int]: + """Check whether input pair is among valid pairs. 
Returns (found, idx).""" + input_pair_sorted = tuple(sorted([from_node, to_node])) + input_in_valid = False + input_idx = -1 + + for k in range(len(pairs)): + if tuple(sorted(pairs[k])) == input_pair_sorted: + input_in_valid = True + input_idx = k + break + + if input_in_valid: + report.input_pair_is_distal = pair_is_distal[input_idx] + rank_order = np.argsort(seps)[::-1] + report.input_pair_rank = ( + int(np.nonzero(rank_order == input_idx)[0][0]) + 1 + ) + return input_in_valid, input_idx + + +# Scenario Assignment +# ──────────────────── + + +def assign_single_pair_scenario( + report: APNodePairReport, + pairs: np.ndarray, + pair_is_distal: np.ndarray, + input_in_valid: bool, +) -> APNodePairReport: + """Assign scenario when exactly one valid pair exists (scenarios 1-4).""" + if input_in_valid: + if pair_is_distal[0]: + report.scenario = 1 + report.outcome = "accept" + else: + report.scenario = 2 + report.outcome = "warn" + report.warning_message = "Input pair has proximal node(s)." + elif pair_is_distal[0]: + report.scenario = 3 + report.outcome = "warn" + report.warning_message = ( + f"Input invalid. Suggest pair [{pairs[0, 0]}, {pairs[0, 1]}]." + ) + else: + report.scenario = 4 + report.outcome = "warn" + report.warning_message = ( + f"Input invalid. Only option " + f"[{pairs[0, 0]}, {pairs[0, 1]}] has proximal node(s)." 
def assign_multi_input_distal_scenario(
    report: APNodePairReport,
    pairs: np.ndarray,
    input_idx: int,
) -> APNodePairReport:
    """Assign scenario for distal input in multi-pair case (5, 6, 7)."""
    # Compare node pairs order-insensitively via sorted tuples.
    input_key = tuple(sorted((pairs[input_idx, 0], pairs[input_idx, 1])))
    if len(report.max_separation_distal_nodes) > 0:
        best_distal_key = tuple(sorted(report.max_separation_distal_nodes))
    else:
        best_distal_key = ()

    if report.input_pair_rank == 1:
        # Input pair has the greatest separation of all valid pairs.
        report.scenario = 5
        report.outcome = "accept"
    elif input_key == best_distal_key:
        # Not top-ranked overall, but the widest among distal pairs.
        report.scenario = 7
        report.outcome = "accept"
    else:
        report.scenario = 6
        report.outcome = "warn"
        d = report.max_separation_distal_nodes
        report.warning_message = (
            f"Distal pair with greater separation exists: [{d[0]}, {d[1]}]."
        )
    return report
def assign_multi_input_invalid_scenario(
    report: APNodePairReport,
    pair_is_distal: np.ndarray,
) -> APNodePairReport:
    """Assign scenario when input not in valid pairs (12-13)."""
    report.outcome = "warn"

    if np.any(pair_is_distal):
        # At least one distal pair exists: suggest the widest of those.
        d = report.max_separation_distal_nodes
        report.scenario = 12
        report.warning_message = (
            f"Input invalid. Suggest max separation distal pair: "
            f"[{d[0]}, {d[1]}]."
        )
    else:
        # No distal pairs at all: fall back to the widest pair overall.
        m = report.max_separation_nodes
        report.scenario = 13
        report.warning_message = (
            f"Input invalid. All pairs have proximal node(s). "
            f"Max separation: [{m[0]}, {m[1]}]."
        )
    return report
+ + """ + if len(pairs) == 1: + return assign_single_pair_scenario( + report, + pairs, + pair_is_distal, + input_in_valid, + ) + + if not input_in_valid: + return assign_multi_input_invalid_scenario(report, pair_is_distal) + + if report.input_pair_is_distal: + return assign_multi_input_distal_scenario( + report, + pairs, + input_idx, + ) + + return assign_multi_input_proximal_scenario(report, pair_is_distal) + + +def evaluate_ap_node_pair( + avg_skeleton: np.ndarray, + pc1_vec: np.ndarray, + anterior_sign: int, + valid_shape_rows: np.ndarray, + from_node: int, + to_node: int, + config: ValidateAPConfig, +) -> APNodePairReport: + """Evaluate an AP node pair through the 3-step filter cascade. + + Parameters + ---------- + avg_skeleton : np.ndarray + Average centered skeleton of shape (n_keypoints, 2). + pc1_vec : np.ndarray + First principal component vector of shape (2,). + anterior_sign : int + Inferred anterior direction (+1 or -1 relative to PC1). + valid_shape_rows : np.ndarray + Boolean array indicating valid (non-NaN) keypoints. + from_node : int + Index of the input from_node (body_axis_keypoints origin, + claimed posterior). 0-indexed. + to_node : int + Index of the input to_node (body_axis_keypoints target, + claimed anterior). 0-indexed. + config : ValidateAPConfig + Configuration with ``lateral_thresh`` and ``edge_thresh``. + + Returns + ------- + APNodePairReport + Complete evaluation report. 
+ + """ + n_keypoints = len(avg_skeleton) + report = APNodePairReport() + report.pc1_coords = np.full(n_keypoints, np.nan) + report.ap_coords = np.full(n_keypoints, np.nan) + report.lateral_offsets = np.full(n_keypoints, np.nan) + report.lateral_offsets_norm = np.full(n_keypoints, np.nan) + report.midline_dist_norm = np.full(n_keypoints, np.nan) + + for node, label in [(from_node, "from_node"), (to_node, "to_node")]: + if node < 0 or node >= n_keypoints: + report.failure_step = "Input validation" + report.failure_reason = ( + f"{label} must be a valid index in 0..{n_keypoints - 1}." + ) + return report + + valid_idx = np.nonzero(valid_shape_rows)[0] + if len(valid_idx) < 2: + report.failure_step = "Step 1: lateral alignment filter" + report.failure_reason = "Fewer than 2 valid nodes are available." + return report + + compute_node_projections( + report, + avg_skeleton, + pc1_vec, + anterior_sign, + valid_shape_rows, + from_node, + to_node, + ) + + candidates = apply_lateral_filter(report, valid_idx, config.lateral_thresh) + if candidates is None: + return report + + step2 = find_opposite_side_pairs( + report, + candidates, + from_node, + to_node, + valid_shape_rows, + ) + if step2 is None: + return report + pairs, seps = step2 + + pair_is_distal = classify_distal_proximal( + report, + pairs, + seps, + valid_shape_rows, + config.edge_thresh, + ) + + input_in_valid, input_idx = check_input_pair_in_valid( + report, + pairs, + seps, + pair_is_distal, + from_node, + to_node, + ) + + report = assign_scenario( + report, pairs, seps, pair_is_distal, input_in_valid, input_idx + ) + report.success = True + return report + + +# Input Preparation and Validation +# ────────────────────────────────── + + +def resolve_node_index(node: Hashable, names: list) -> int: + """Resolve a node identifier to an integer index.""" + if isinstance(node, str): + if node in names: + return names.index(node) + raise ValueError(f"Keypoint '{node}' not found in {names}.") + if isinstance(node, 
int): + return node + return int(node) # type: ignore[call-overload] + + +def prepare_validation_inputs( + data: xr.DataArray, + from_node: Hashable, + to_node: Hashable, +) -> tuple[np.ndarray, int, int, str, str, list[str], int]: + """Validate inputs and extract numpy arrays for AP validation. + + Returns + ------- + tuple + (keypoints, from_idx, to_idx, from_name, to_name, + keypoint_names, num_frames) + + Raises + ------ + TypeError + If data is not an xarray.DataArray. + ValueError + If dimensions or indices are invalid. + + """ + if not isinstance(data, xr.DataArray): + raise TypeError( + f"Input data must be an xarray.DataArray, but got {type(data)}." + ) + + required_dims = {"time", "space", "keypoints"} + if not required_dims.issubset(set(data.dims)): + raise ValueError( + f"data must have dimensions {required_dims}, " + f"but has {set(data.dims)}." + ) + + if "individuals" in data.dims: + if data.sizes["individuals"] != 1: + raise ValueError( + "data must be for a single individual. " + "Use data.sel(individuals='name') to select one." + ) + data = data.squeeze("individuals", drop=True) + + if "keypoints" in data.coords: + keypoint_names = list(data.coords["keypoints"].values) + else: + keypoint_names = [f"node_{i}" for i in range(data.sizes["keypoints"])] + + n_keypoints = data.sizes["keypoints"] + from_idx = resolve_node_index(from_node, keypoint_names) + to_idx = resolve_node_index(to_node, keypoint_names) + + if from_idx < 0 or from_idx >= n_keypoints: + raise ValueError( + f"from_node index {from_idx} out of range [0, {n_keypoints - 1}]." + ) + if to_idx < 0 or to_idx >= n_keypoints: + raise ValueError( + f"to_node index {to_idx} out of range [0, {n_keypoints - 1}]." 
+ ) + + data_xy = data.sel(space=["x", "y"]) + keypoints = data_xy.transpose("time", "keypoints", "space").values + + from_name = keypoint_names[from_idx] + to_name = keypoint_names[to_idx] + num_frames = keypoints.shape[0] + + return ( + keypoints, + from_idx, + to_idx, + from_name, + to_name, + keypoint_names, + num_frames, + ) + + +# Pipeline Orchestration Functions +# ────────────────────────────────── + + +def run_motion_segmentation( + keypoints: np.ndarray, + num_frames: int, + config: ValidateAPConfig, + log_info, + log_warning, +) -> dict | None: + """Run tiered validity through segment detection. + + Returns a dict with tier1_valid, tier2_valid, bbox_centroids, + segments, or None on failure (error logged). + + """ + tier1_valid, tier2_valid, _frac = compute_tiered_validity( + keypoints, config.min_valid_frac + ) + num_tier1 = int(np.sum(tier1_valid)) + num_tier2 = int(np.sum(tier2_valid)) + + log_info(_LOG_SEPARATOR) + log_info("Tiered Validity Report") + log_info(_LOG_SEPARATOR) + log_info( + "Tier 1 (>= %.0f%% keypoints): %d / %d frames (%.2f%%)", + config.min_valid_frac * 100, + num_tier1, + num_frames, + 100 * num_tier1 / num_frames, + ) + log_info( + "Tier 2 (100%% keypoints): %d / %d frames (%.2f%%)", + num_tier2, + num_frames, + 100 * num_tier2 / num_frames, + ) + + if num_tier1 < 2: + logger.error("Not enough tier-1 valid frames.") + return None + + bbox_centroids, _arith, centroid_disc = compute_bbox_centroid( + keypoints, tier1_valid + ) + valid_disc = centroid_disc[tier1_valid & ~np.isnan(centroid_disc)] + if len(valid_disc) > 0: + log_info("") + log_info(_LOG_SEPARATOR) + log_info("Centroid Discrepancy Diagnostic") + log_info(_LOG_SEPARATOR) + log_info("BBox vs arithmetic centroid (normalized by bbox diagonal):") + log_info( + " Median: %.4f | Mean: %.4f | Max: %.4f", + np.median(valid_disc), + np.mean(valid_disc), + np.max(valid_disc), + ) + if np.median(valid_disc) > 0.05: + log_warning( + "Median discrepancy > 5%% - annotation density " + 
"is likely asymmetric." + ) + + segments = detect_motion_segments( + bbox_centroids, tier1_valid, config, log_info + ) + if segments is None: + return None + + return { + "tier1_valid": tier1_valid, + "tier2_valid": tier2_valid, + "bbox_centroids": bbox_centroids, + "segments": segments, + } + + +def detect_motion_segments( + bbox_centroids: np.ndarray, + tier1_valid: np.ndarray, + config: ValidateAPConfig, + log_info, +) -> np.ndarray | None: + """Detect high-motion segments from centroid velocities. + + Returns merged segments array, or None on failure. + + """ + _, speeds = compute_frame_velocities(bbox_centroids, tier1_valid) + num_speed = len(speeds) + + if num_speed < config.window_len: + logger.error( + "window_len=%d exceeds available speed samples=%d.", + config.window_len, + num_speed, + ) + return None + + window_starts, window_medians, window_all_valid = ( + compute_sliding_window_medians( + speeds, config.window_len, config.stride + ) + ) + num_valid_windows = int(np.sum(window_all_valid)) + if num_valid_windows == 0: + logger.error("No fully valid sliding windows found.") + return None + + high_motion = detect_high_motion_windows( + window_medians, window_all_valid, config.pct_thresh + ) + num_high_motion = int(np.sum(high_motion)) + + log_info("") + log_info(_LOG_SEPARATOR) + log_info("High-Motion Window Detection") + log_info(_LOG_SEPARATOR) + log_info( + "Sliding windows (len=%d, stride=%d): " + "%d total, %d fully valid (NaN-free), " + "%d high-motion (median speed >= %dth percentile)", + config.window_len, + config.stride, + len(window_starts), + num_valid_windows, + num_high_motion, + int(config.pct_thresh), + ) + + if num_high_motion == 0: + logger.error("No high-motion windows found.") + return None + + run_starts, run_ends, _run_lengths = detect_runs( + high_motion, config.min_run_len + ) + if len(run_starts) == 0: + logger.error("No runs met min_run_len=%d.", config.min_run_len) + return None + + segments_raw = convert_runs_to_segments( + 
def select_tier2_frames(
    segments: np.ndarray,
    tier2_valid: np.ndarray,
    num_frames: int,
    log_info,
    log_warning,
) -> tuple[np.ndarray, np.ndarray, int] | None:
    """Filter segment frames to tier-2 valid only.

    Returns (selected_frames, selected_seg_id, num_selected) or None.

    """
    selected_frames, selected_seg_id = filter_segments_tier2(
        segments, tier2_valid
    )

    # Count all frames inside the (inclusive) segment ranges, clipped to
    # the valid index range [0, num_frames - 1]. This is equivalent to
    # the previous per-segment boolean-mask sum over
    # np.arange(num_frames), but runs in O(n_segments) instead of
    # O(n_segments * num_frames).
    num_tier1_in_segs = sum(
        max(0, min(int(end), num_frames - 1) - max(int(start), 0) + 1)
        for start, end in segments
    )
    num_selected = len(selected_frames)

    log_info("")
    log_info(_LOG_SEPARATOR)
    log_info("Tier-2 Filtering on High-Motion Segments")
    log_info(_LOG_SEPARATOR)
    log_info(
        "Frames in high-motion segments (any tier): %d", num_tier1_in_segs
    )
    log_info(
        "Tier-2 valid frames retained (all keypoints present): "
        "%d (%.1f%% of segment frames)",
        num_selected,
        100 * num_selected / max(num_tier1_in_segs, 1),
    )

    # Warn when tier-2 filtering keeps fewer than 30% of segment frames.
    retention = num_selected / max(num_tier1_in_segs, 1)
    if retention < 0.3:
        log_warning(
            "Tier 2 discards > 70%% of segment frames - "
            "body model may be unrepresentative."
        )

    if num_selected < 2:
        logger.error("Not enough tier-2 valid frames in selected segments.")
        return None

    return selected_frames, selected_seg_id, num_selected
def run_clustering_and_pca(
    centered_skeletons: np.ndarray,
    frame_sel: FrameSelection,
    config: ValidateAPConfig,
    log_info,
    log_warning,
) -> dict | None:
    """Run postural analysis, clustering, and per-cluster PCA.

    Returns dict with primary_result, cluster_results,
    num_clusters, primary_cluster, or None on failure.

    """
    # Pairwise RMSD between centered skeletons feeds the postural
    # consistency check (between- vs within-segment variance).
    rmsd_matrix = compute_pairwise_rmsd(centered_skeletons)
    var_ratio, within_rmsds, between_rmsds, var_ratio_override = (
        compute_postural_variance_ratio(rmsd_matrix, frame_sel.seg_ids)
    )

    rmsd_stats = {
        "within": within_rmsds,
        "between": between_rmsds,
        "var_ratio": var_ratio,
        "override": var_ratio_override,
    }
    log_postural_consistency(
        rmsd_stats,
        config,
        frame_sel.count,
        log_info,
    )

    # Clustering runs only when the variance ratio exceeds the
    # configured threshold (decision logic in decide_and_run_clustering;
    # a single trivial cluster is returned otherwise).
    cluster_labels, num_clusters, primary_cluster = decide_and_run_clustering(
        centered_skeletons,
        var_ratio,
        frame_sel.count,
        config,
        log_info,
    )

    # Per-cluster PCA + anterior inference; the primary cluster
    # supplies the final result below.
    cluster_results = []
    for c in range(num_clusters):
        cluster_mask = cluster_labels == c
        cr = compute_cluster_pca_and_anterior(
            centered_skeletons,
            cluster_mask,
            frame_sel.frames,
            frame_sel.seg_ids,
            frame_sel.segments,
            frame_sel.bbox_centroids,
        )
        cluster_results.append(cr)

    pr = cluster_results[primary_cluster]
    if not pr["valid"]:
        logger.error("Primary cluster has invalid PCA result.")
        return None

    return {
        "primary_result": pr,
        "cluster_results": cluster_results,
        "num_clusters": num_clusters,
        "primary_cluster": primary_cluster,
    }
def log_postural_consistency(
    rmsd_stats,
    config,
    num_selected,
    log_info,
):
    """Log postural consistency check results.

    Parameters
    ----------
    rmsd_stats : dict
        Keys "within", "between" (RMSD pair arrays), "var_ratio"
        (between/within variance ratio) and "override" (bool flag).
    config : ValidateAPConfig
        Supplies ``postural_var_ratio_thresh``.
    num_selected : int
        Number of selected frames (clustering needs at least 6).
    log_info : callable
        Printf-style logging callback.

    """
    within_rmsds = rmsd_stats["within"]
    between_rmsds = rmsd_stats["between"]
    var_ratio = rmsd_stats["var_ratio"]
    var_ratio_override = rmsd_stats["override"]

    log_info("")
    log_info(_LOG_SEPARATOR)
    log_info("Postural Consistency Check")
    log_info(_LOG_SEPARATOR)

    if len(within_rmsds) > 0:
        log_info(
            "Within-segment RMSD: mean=%.4f, std=%.4f (n=%d pairs)",
            np.mean(within_rmsds),
            np.std(within_rmsds),
            len(within_rmsds),
        )
    else:
        log_info("Within-segment RMSD: N/A (no within-segment pairs)")

    if len(between_rmsds) > 0:
        log_info(
            "Between-segment RMSD: mean=%.4f, std=%.4f (n=%d pairs)",
            np.mean(between_rmsds),
            np.std(between_rmsds),
            len(between_rmsds),
        )
        log_info(
            "Variance ratio (between/within): %.2f (threshold=%.2f)",
            var_ratio,
            config.postural_var_ratio_thresh,
        )
        if var_ratio_override:
            log_info(
                " (Conservative override to zero: within-segment variance "
                "is zero or no within-segment pairs)"
            )
    else:
        # A single segment yields no between-segment pairs at all.
        log_info("Between-segment RMSD: N/A (single segment)")
        log_info("Variance ratio: N/A")

    # Mirrors the decision in decide_and_run_clustering: clustering
    # needs both a high variance ratio and >= 6 selected frames.
    do_clustering = (
        var_ratio > config.postural_var_ratio_thresh and num_selected >= 6
    )
    if do_clustering:
        log_info(" -> Variance ratio exceeds threshold. Running clustering.")
    elif var_ratio > config.postural_var_ratio_thresh and num_selected < 6:
        log_info(
            " -> Variance ratio exceeds threshold but too few frames (%d) "
            "for clustering.",
            num_selected,
        )
    else:
        log_info(" -> Postural consistency acceptable. Using global average.")
" + "Using global average.", + best_silhouette, + ) + + return cluster_labels, num_clusters, primary_cluster + + +def log_anterior_report( + pr, + cluster_results, + num_clusters, + primary_cluster, + config, + log_info, + log_warning, +): + """Log anterior direction detection and cluster agreement.""" + log_info("") + log_info(_LOG_SEPARATOR) + log_info("Anterior Direction Inference (Velocity Voting)") + log_info(_LOG_SEPARATOR) + log_info( + "Centroid velocity projections onto PC1: " + "%d positive (+PC1), %d negative (-PC1)", + pr["num_positive"], + pr["num_negative"], + ) + + vote_margin_str = f"Vote margin M: {pr['vote_margin']:.4f}" + if pr["vote_margin"] < config.confidence_floor: + vote_margin_str += ( + f" ** BELOW CONFIDENCE FLOOR ({config.confidence_floor:.2f}) " + "- anterior assignment is unreliable **" + ) + log_warning( + "Vote margin M = %.4f is below confidence floor %.2f - " + "anterior assignment is unreliable.", + pr["vote_margin"], + config.confidence_floor, + ) + log_info(vote_margin_str) + log_info( + "Resultant length R: %.4f (0 = omnidirectional, 1 = unidirectional)", + pr["resultant_length"], + ) + log_info( + "Inferred anterior direction: %sPC1 " + "(strict majority; ties default to -PC1)", + "+" if pr["anterior_sign"] > 0 else "-", + ) + + if num_clusters <= 1: + return + + log_info("") + log_info(_LOG_SEPARATOR) + log_info("Inter-Cluster Anterior Polarity Agreement") + log_info(_LOG_SEPARATOR) + signs = [cr["anterior_sign"] for cr in cluster_results if cr["valid"]] + if len(set(signs)) == 1: + log_info( + "All %d clusters AGREE on anterior polarity.", + num_clusters, + ) + else: + log_info( + "DISAGREEMENT: clusters assign different anterior polarities." 
def log_anterior_report(
    pr,
    cluster_results,
    num_clusters,
    primary_cluster,
    config,
    log_info,
    log_warning,
):
    """Log anterior direction detection and cluster agreement."""
    log_info("")
    log_info(_LOG_SEPARATOR)
    log_info("Anterior Direction Inference (Velocity Voting)")
    log_info(_LOG_SEPARATOR)
    log_info(
        "Centroid velocity projections onto PC1: "
        "%d positive (+PC1), %d negative (-PC1)",
        pr["num_positive"],
        pr["num_negative"],
    )

    # Flag low-confidence votes both inline and as a warning line.
    vote_margin_str = f"Vote margin M: {pr['vote_margin']:.4f}"
    if pr["vote_margin"] < config.confidence_floor:
        vote_margin_str += (
            f" ** BELOW CONFIDENCE FLOOR ({config.confidence_floor:.2f}) "
            "- anterior assignment is unreliable **"
        )
        log_warning(
            "Vote margin M = %.4f is below confidence floor %.2f - "
            "anterior assignment is unreliable.",
            pr["vote_margin"],
            config.confidence_floor,
        )
    log_info(vote_margin_str)
    log_info(
        "Resultant length R: %.4f (0 = omnidirectional, 1 = unidirectional)",
        pr["resultant_length"],
    )
    log_info(
        "Inferred anterior direction: %sPC1 "
        "(strict majority; ties default to -PC1)",
        "+" if pr["anterior_sign"] > 0 else "-",
    )

    # The agreement section below is meaningful only with >= 2 clusters.
    if num_clusters <= 1:
        return

    log_info("")
    log_info(_LOG_SEPARATOR)
    log_info("Inter-Cluster Anterior Polarity Agreement")
    log_info(_LOG_SEPARATOR)
    signs = [cr["anterior_sign"] for cr in cluster_results if cr["valid"]]
    if len(set(signs)) == 1:
        log_info(
            "All %d clusters AGREE on anterior polarity.",
            num_clusters,
        )
    else:
        log_info(
            "DISAGREEMENT: clusters assign different anterior polarities."
        )
    for c, cr in enumerate(cluster_results):
        if cr["valid"]:
            log_info(
                " Cluster %d (%d frames): anterior = %sPC1, "
                "vote_margin M = %.4f, resultant_length R = %.4f",
                c + 1,
                cr["n_frames"],
                "+" if cr["anterior_sign"] > 0 else "-",
                cr["vote_margin"],
                cr["resultant_length"],
            )
    log_info(
        " Primary result from cluster %d (largest).",
        primary_cluster + 1,
    )


def log_pair_evaluation(
    pair_report,
    config,
    from_idx,
    to_idx,
    from_name,
    to_name,
    log_info,
):
    """Log the complete AP node pair evaluation report."""
    log_info("")
    log_info(_LOG_SEPARATOR)
    log_info("AP Node-Pair Filter Cascade (3-Step Evaluation)")
    log_info(_LOG_SEPARATOR)
    log_info(
        "Input pair: [%d, %d] (%s -> %s, claimed posterior -> anterior)",
        from_idx,
        to_idx,
        from_name,
        to_name,
    )

    step1_failed = pair_report.failure_step.startswith("Step 1")

    # Valid nodes are those with a finite normalized lateral offset.
    valid_nodes = np.nonzero(~np.isnan(pair_report.lateral_offsets_norm))[0]
    num_candidates = len(pair_report.sorted_candidate_nodes)
    step1_loss = 1 - num_candidates / max(len(valid_nodes), 1)

    log_step1_report(pair_report, config, valid_nodes, log_info)
    log_input_node_status(pair_report, config, from_idx, to_idx, log_info)

    step2_loss = 0.0
    step3_frac = 0.0
    step2_failed = False

    # Steps 2-3 are only evaluated when Step 1 produced candidates.
    if step1_failed:
        log_info("")
        log_info("Step 2-3: not evaluated (Step 1 failed)")
    else:
        step2_loss, step3_frac, step2_failed = log_step2_step3_details(
            pair_report,
            config,
            from_idx,
            to_idx,
            num_candidates,
            log_info,
        )

    log_loss_summary(
        step1_loss,
        step2_loss,
        step3_frac,
        step1_failed,
        step2_failed,
        log_info,
    )
    log_order_check(
        pair_report,
        from_idx,
        to_idx,
        from_name,
        to_name,
        log_info,
    )
def log_step1_report(pair_report, config, valid_nodes, log_info):
    """Log Step 1 lateral filter results."""
    n_valid = len(valid_nodes)
    n_candidates = len(pair_report.sorted_candidate_nodes)
    loss_frac = 1 - n_candidates / max(n_valid, 1)

    # Partition valid nodes into pass/fail by normalized lateral offset.
    passed, failed = [], []
    for node in valid_nodes:
        offset = pair_report.lateral_offsets_norm[node]
        bucket = passed if offset <= config.lateral_thresh else failed
        bucket.append(f"{node}({offset:.2f})")

    log_info("")
    log_info(
        "Step 1 - Lateral Alignment Filter (lateral_thresh=%.2f): "
        "%d of %d valid nodes pass [loss=%.0f%%]",
        config.lateral_thresh,
        n_candidates,
        n_valid,
        100 * loss_frac,
    )
    log_info(
        " Scale: 0.00 = nearest to body axis, 1.00 = farthest from body axis"
    )
    if passed:
        log_info(" PASS: %s", ", ".join(passed))
    if failed:
        log_info(" FAIL: %s", ", ".join(failed))


def log_step2_report(pair_report, _config, log_info):
    """Log Step 2 opposite-sides results."""
    n_candidates = len(pair_report.sorted_candidate_nodes)
    n_possible = n_candidates * (n_candidates - 1) // 2
    n_valid = len(pair_report.valid_pairs)
    loss_frac = 1 - n_valid / max(n_possible, 1)
    midpoint = pair_report.midpoint_pc1

    # Split candidates by which side of the AP midpoint they fall on;
    # non-positive offsets go to the posterior (minus) side.
    anterior_side, posterior_side = [], []
    for node in pair_report.sorted_candidate_nodes:
        offset = pair_report.pc1_coords[node] - midpoint
        label = f"{node}({offset:+.1f})"
        if offset > 0:
            anterior_side.append(label)
        else:
            posterior_side.append(label)

    log_info("")
    log_info(
        "Step 2 - Opposite-Sides Constraint (AP midpoint=%.2f): "
        "%d of %d candidate pairs on opposite sides [loss=%.0f%%]",
        midpoint,
        n_valid,
        n_possible,
        100 * loss_frac,
    )
    if anterior_side:
        log_info(
            " + side (anterior of midpoint): %s", ", ".join(anterior_side)
        )
    if posterior_side:
        log_info(
            " - side (posterior of midpoint): %s", ", ".join(posterior_side)
        )
def log_step3_report(pair_report, config, log_info):
    """Log Step 3 distal/proximal classification results."""
    num_distal = len(pair_report.distal_pairs)
    num_proximal = len(pair_report.proximal_pairs)
    num_valid_pairs = len(pair_report.valid_pairs)
    step3_distal_frac = num_distal / max(num_valid_pairs, 1)

    log_info("")
    log_info(
        "Step 3 - Distal/Proximal Classification (edge_thresh=%.2f): "
        "%d distal, %d proximal [distal fraction=%.0f%%]",
        config.edge_thresh,
        num_distal,
        num_proximal,
        100 * step3_distal_frac,
    )

    # Re-derive each pair's status from the nearer-to-midline node so
    # the printed label matches the classification rule.
    for idx in range(num_valid_pairs):
        node_i, node_j = pair_report.valid_pairs[idx]
        d_i = pair_report.midline_dist_norm[node_i]
        d_j = pair_report.midline_dist_norm[node_j]
        min_d = min(d_i, d_j)
        sep = pair_report.valid_pairs_internode_dist[idx]
        status = "DISTAL" if min_d >= config.edge_thresh else "PROXIMAL"
        log_info(
            " [%d,%d]: min_d=%.2f, sep=%.2f [%s]",
            node_i,
            node_j,
            min_d,
            sep,
            status,
        )


def log_loss_summary(
    step1_loss, step2_loss, step3_frac, step1_failed, step2_failed, log_info
):
    """Log cumulative filtering loss summary.

    Later steps are only reported when every earlier step succeeded,
    since their loss figures are otherwise meaningless.
    """
    log_info("")
    log_info(_LOG_SEPARATOR)
    log_info("Filtering Loss Summary")
    log_info(_LOG_SEPARATOR)
    log_info(
        "Step 1 (Lateral Filter): %.0f%% of valid nodes eliminated",
        100 * step1_loss,
    )
    if not step1_failed:
        log_info(
            "Step 2 (Opposite-Sides): %.0f%% of candidate pairs eliminated",
            100 * step2_loss,
        )
    if not step1_failed and not step2_failed:
        log_info(
            "Step 3 (Distal/Proximal): %.0f%% of surviving pairs are distal",
            100 * step3_frac,
        )
def log_order_check(
    pair_report, from_idx, to_idx, from_name, to_name, log_info
):
    """Log AP ordering check for the input pair."""
    log_info("")
    log_info(_LOG_SEPARATOR)
    log_info("Order Check: is from_node posterior to to_node?")
    log_info(_LOG_SEPARATOR)

    from_coord = pair_report.ap_coords[from_idx]
    to_coord = pair_report.ap_coords[to_idx]
    # Either coordinate may be NaN when a node lacked valid data.
    if np.isnan(from_coord) or np.isnan(to_coord):
        log_info("Order check: cannot evaluate (invalid node coordinates)")
        return

    log_info(
        "AP coords: from_node %s[%d]=%.2f, to_node %s[%d]=%.2f",
        from_name,
        from_idx,
        from_coord,
        to_name,
        to_idx,
        to_coord,
    )

    if pair_report.input_pair_order_matches_inference:
        log_info(
            "[%d, %d]: CONSISTENT - inference agrees that "
            "from_node is posterior (lower AP coord), "
            "to_node is anterior",
            from_idx,
            to_idx,
        )
        return

    log_info(
        "[%d, %d]: INCONSISTENT - inference suggests "
        "from_node is anterior (higher AP coord), "
        "to_node is posterior",
        from_idx,
        to_idx,
    )
    log_info(
        " -> Inferred posterior->anterior order would be [%d, %d]",
        to_idx,
        from_idx,
    )


def log_input_node_status(
    pair_report,
    config,
    from_idx,
    to_idx,
    log_info,
):
    """Log whether each input node passed the lateral filter."""
    failed = []
    for idx in (from_idx, to_idx):
        offset = pair_report.lateral_offsets_norm[idx]
        # NaN offsets count as failures: the node had no valid data.
        if np.isnan(offset) or offset > config.lateral_thresh:
            failed.append(f"{idx}({offset:.2f})")

    if failed:
        log_info(
            " -> Input node(s) FAILED lateral filter: %s",
            ", ".join(failed),
        )
def log_step2_step3_details(
    pair_report,
    config,
    from_idx,
    to_idx,
    num_candidates,
    log_info,
):
    """Log Step 2 and Step 3 results when Step 1 succeeded.

    Returns (step2_loss, step3_frac, step2_failed).

    """
    log_step2_report(pair_report, config, log_info)

    # Special-case note when the input pair itself survived Step 1 but
    # both nodes landed on the same side of the midpoint.
    if (
        pair_report.input_pair_in_candidates
        and not pair_report.input_pair_opposite_sides
    ):
        log_info(" -> Input nodes on SAME side of AP midpoint")

    num_possible = num_candidates * (num_candidates - 1) // 2
    step2_loss = 1 - len(pair_report.valid_pairs) / max(num_possible, 1)
    step2_failed = pair_report.failure_step.startswith("Step 2")

    # Step 3 is skipped entirely when Step 2 produced no valid pairs.
    if step2_failed:
        log_info("")
        log_info("Step 3: not evaluated (Step 2 failed)")
        return step2_loss, 0.0, True

    step3_frac = log_step3_with_proximal_check(
        pair_report,
        config,
        from_idx,
        to_idx,
        log_info,
    )
    return step2_loss, step3_frac, False


def log_step3_with_proximal_check(
    pair_report,
    config,
    from_idx,
    to_idx,
    log_info,
):
    """Log Step 3 results and check input pair proximal status.

    Returns step3_frac.

    """
    log_step3_report(pair_report, config, log_info)
    num_distal = len(pair_report.distal_pairs)
    num_valid_pairs = len(pair_report.valid_pairs)
    step3_frac = num_distal / max(num_valid_pairs, 1)

    # Only flag the input pair as proximal when it actually reached
    # Step 3 (survived Step 1 and the opposite-sides constraint).
    is_candidate = pair_report.input_pair_in_candidates
    is_opposite = pair_report.input_pair_opposite_sides
    is_proximal = not pair_report.input_pair_is_distal
    if is_candidate and is_opposite and is_proximal:
        d_from = pair_report.midline_dist_norm[from_idx]
        d_to = pair_report.midline_dist_norm[to_idx]
        log_info(
            " -> Input pair is PROXIMAL (min_d=%.2f < %.2f)",
            min(d_from, d_to),
            config.edge_thresh,
        )

    return step3_frac
def validate_ap(
    data: xr.DataArray,
    from_node: Hashable,
    to_node: Hashable,
    config: ValidateAPConfig | None = None,
    verbose: bool = True,
) -> dict:
    """Validate an anterior-posterior keypoint pair using body-axis inference.

    This function implements a prior-free body-axis inference pipeline that:
    1. Identifies high-motion segments using tiered validity and sliding
       windows
    2. Optionally performs postural clustering via k-medoids
    3. Infers the anterior direction using velocity projection voting
    4. Evaluates the candidate AP keypoint pair through a 3-step filter
       cascade

    Parameters
    ----------
    data : xarray.DataArray
        Position data for a single individual.
    from_node : Hashable
        Index or name of the posterior keypoint.
    to_node : Hashable
        Index or name of the anterior keypoint.
    config : ValidateAPConfig, optional
        Configuration parameters. If None, uses defaults.
    verbose : bool, default=True
        If True, log detailed validation output to console.

    Returns
    -------
    dict
        Validation results including success, anterior_sign,
        vote_margin, resultant_length, pair_report, etc.

    """
    if config is None:
        config = ValidateAPConfig()

    # All log lines are collected so callers can inspect the report
    # even when verbose printing is disabled.
    log_lines: list[str] = []

    def _log_info(msg, *args):
        """Log an informational message."""
        # %-formatting is applied lazily, only when args are given.
        line = msg % args if args else msg
        log_lines.append(line)
        if verbose:
            print(line)

    def _log_warning(msg, *args):
        """Log a warning message."""
        line = f"WARNING: {msg % args if args else msg}"
        log_lines.append(line)
        if verbose:
            print(line)

    # Prepare inputs
    (
        keypoints,
        from_idx,
        to_idx,
        from_name,
        to_name,
        _keypoint_names,
        num_frames,
    ) = prepare_validation_inputs(data, from_node, to_node)

    # Pre-populate a failure-shaped result; fields are overwritten as
    # each pipeline stage succeeds.
    n_keypoints = keypoints.shape[1]
    result: dict = {
        "success": False,
        "anterior_sign": 0,
        "vote_margin": 0.0,
        "resultant_length": 0.0,
        "num_selected_frames": 0,
        "num_clusters": 1,
        "primary_cluster": 0,
        "pair_report": APNodePairReport(),
        "PC1": np.array([1.0, 0.0]),
        "PC2": np.array([0.0, 1.0]),
        "avg_skeleton": np.full((n_keypoints, 2), np.nan),
        "error_msg": "",
        "log_lines": log_lines,
    }

    # Motion segmentation
    seg = run_motion_segmentation(
        keypoints,
        num_frames,
        config,
        _log_info,
        _log_warning,
    )
    if seg is None:
        result["error_msg"] = "Motion segmentation failed."
        return result

    # Tier-2 frame selection
    t2 = select_tier2_frames(
        seg["segments"],
        seg["tier2_valid"],
        num_frames,
        _log_info,
        _log_warning,
    )
    if t2 is None:
        result["error_msg"] = "Not enough tier-2 valid frames."
        return result
    selected_frames, selected_seg_id, num_selected = t2
    result["num_selected_frames"] = num_selected

    # Build centered skeletons
    _selected_centroids, centered_skeletons = build_centered_skeletons(
        keypoints, selected_frames
    )

    # Bundle frame selection data
    frame_sel = FrameSelection(
        frames=selected_frames,
        seg_ids=selected_seg_id,
        segments=seg["segments"],
        bbox_centroids=seg["bbox_centroids"],
        count=num_selected,
    )

    # Postural clustering, PCA, and anterior inference
    pca = run_clustering_and_pca(
        centered_skeletons,
        frame_sel,
        config,
        _log_info,
        _log_warning,
    )
    if pca is None:
        result["error_msg"] = "Primary cluster PCA failed."
        return result

    pr = pca["primary_result"]
    result["anterior_sign"] = pr["anterior_sign"]
    result["vote_margin"] = pr["vote_margin"]
    result["resultant_length"] = pr["resultant_length"]
    result["circ_mean_dir"] = pr["circ_mean_dir"]
    result["vel_projs_pc1"] = pr["vel_projs_pc1"]
    result["PC1"] = pr["PC1"]
    result["PC2"] = pr["PC2"]
    result["avg_skeleton"] = pr["avg_skeleton"]
    result["num_clusters"] = pca["num_clusters"]
    result["primary_cluster"] = pca["primary_cluster"]

    # Log anterior inference
    log_anterior_report(
        pr,
        pca["cluster_results"],
        pca["num_clusters"],
        pca["primary_cluster"],
        config,
        _log_info,
        _log_warning,
    )

    # AP node-pair evaluation
    pair_report = evaluate_ap_node_pair(
        pr["avg_skeleton"],
        pr["PC1"],
        pr["anterior_sign"],
        pr["valid_shape_rows"],
        from_idx,
        to_idx,
        config,
    )
    result["pair_report"] = pair_report

    log_pair_evaluation(
        pair_report,
        config,
        from_idx,
        to_idx,
        from_name,
        to_name,
        _log_info,
    )

    result["success"] = True
    return result
def run_ap_validation(
    data: xr.DataArray,
    normalized_keypoints: tuple[Hashable, Hashable],
    ap_validation_config: dict[str, Any] | None,
) -> dict:
    """Run AP validation across all individuals, select best by R*M.

    Each individual is validated independently using the supplied keypoint
    pair. R*M (resultant_length * vote_margin) is computed per individual
    and depends only on the individual's motion and body shape, not on
    the input pair. The best individual is the one with the highest R*M.

    Parameters
    ----------
    data : xarray.DataArray
        Position data with individuals dimension.
    normalized_keypoints : tuple[Hashable, Hashable]
        The (from_node, to_node) keypoint pair.
    ap_validation_config : dict, optional
        Configuration overrides for ValidateAPConfig.

    Returns
    -------
    dict
        Dictionary with 'all_results' (list of per-individual results)
        and 'best_idx' (index of best individual by R*M).

    """
    if ap_validation_config is None:
        config = None
    else:
        config = ValidateAPConfig(**ap_validation_config)

    posterior, anterior = normalized_keypoints

    def _validate(single_individual):
        # Quiet per-individual validation; callers read log_lines.
        return validate_ap(
            single_individual,
            from_node=posterior,
            to_node=anterior,
            config=config,
            verbose=False,
        )

    # Single-individual data has no individuals dimension: validate
    # directly and report it as the trivially best candidate.
    if "individuals" not in data.dims:
        return {"all_results": [_validate(data)], "best_idx": 0}

    all_results = []
    for individual in list(data.coords["individuals"].values):
        res = _validate(data.sel(individuals=individual))
        res["individual"] = individual
        all_results.append(res)

    return {
        "all_results": all_results,
        "best_idx": find_best_individual_by_rxm(all_results),
    }


def find_best_individual_by_rxm(all_results: list[dict]) -> int:
    """Return index of the individual with highest R*M score."""
    best_idx, best_score = -1, -1.0
    for idx, res in enumerate(all_results):
        # Failed validations never compete for the best slot.
        if not res["success"]:
            continue
        score = res["resultant_length"] * res["vote_margin"]
        if score > best_score:
            best_score, best_idx = score, idx
    return best_idx
result in enumerate(all_results): + if not result["success"]: + continue + rxm = result["resultant_length"] * result["vote_margin"] + if rxm > best_rxm: + best_rxm = rxm + best_idx = i + return best_idx diff --git a/movement/kinematics/collective.py b/movement/kinematics/collective.py index 3cb0b8f3c..e9795d3fc 100644 --- a/movement/kinematics/collective.py +++ b/movement/kinematics/collective.py @@ -2,12 +2,12 @@ """Compute collective behavior metrics for multi-individual tracking data.""" from collections.abc import Hashable -from dataclasses import dataclass, field from typing import Any import numpy as np import xarray as xr +from movement.kinematics.body_axis import run_ap_validation from movement.utils.logging import logger from movement.utils.vector import ( compute_norm, @@ -19,265 +19,13 @@ _ANGLE_EPS = 1e-12 -@dataclass -class _ValidateAPConfig: - """Configuration for the _validate_ap function. - - Parameters - ---------- - min_valid_frac : float, default=0.6 - Minimum fraction of keypoints that must be present for a frame - to qualify as tier-1 valid. - window_len : int, default=50 - Number of speed samples per sliding window. - stride : int, default=5 - Step size between consecutive sliding window start positions. - pct_thresh : float, default=85.0 - Percentile threshold applied to valid-window median speeds for - high-motion classification. - min_run_len : int, default=1 - Minimum number of consecutive qualifying windows required to - form a valid run. - postural_var_ratio_thresh : float, default=2.0 - Between-segment to within-segment RMSD variance ratio above which - postural clustering is triggered. - max_clusters : int, default=4 - Upper bound on the number of clusters to evaluate during k-medoids. - confidence_floor : float, default=0.1 - Vote margin below which the anterior inference is flagged as - unreliable. - lateral_thresh : float, default=0.4 - Normalized lateral offset ceiling for the Step 1 lateral alignment - filter. 
- edge_thresh : float, default=0.3 - Normalized midpoint distance floor for the Step 3 distal/proximal - classification. - - """ - - min_valid_frac: float = 0.6 - window_len: int = 50 - stride: int = 5 - pct_thresh: float = 85.0 - min_run_len: int = 1 - postural_var_ratio_thresh: float = 2.0 - max_clusters: int = 4 - confidence_floor: float = 0.1 - lateral_thresh: float = 0.4 - edge_thresh: float = 0.3 - - def __post_init__(self) -> None: - """Validate configuration parameters.""" - # Validate fraction parameters (must be in [0, 1]) - for name in ( - "min_valid_frac", - "confidence_floor", - "lateral_thresh", - "edge_thresh", - ): - value = getattr(self, name) - if not (0 <= value <= 1): - raise ValueError( - f"{name} must be between 0 and 1, got {value}" - ) - - # Validate positive integer parameters - for name in ("window_len", "stride", "min_run_len", "max_clusters"): - value = getattr(self, name) - if not isinstance(value, int) or value <= 0: - raise ValueError( - f"{name} must be a positive integer, got {value}" - ) - - # Validate pct_thresh (must be in [0, 100]) - if not (0 <= self.pct_thresh <= 100): - raise ValueError( - f"pct_thresh must be between 0 and 100, got {self.pct_thresh}" - ) - - # Validate postural_var_ratio_thresh (must be positive) - if self.postural_var_ratio_thresh <= 0: - raise ValueError( - f"postural_var_ratio_thresh must be positive, " - f"got {self.postural_var_ratio_thresh}" - ) - - -@dataclass -class _FrameSelection: - """Selected frames from high-motion segmentation and tier-2 filtering. - - Bundles the frame indices, segment assignments, and related arrays - produced by the segmentation pipeline for downstream consumption - (skeleton construction, postural clustering, velocity recomputation). - - Attributes - ---------- - frames : np.ndarray - Array of selected frame indices (tier-2 valid, within segments). - seg_ids : np.ndarray - Segment ID (0-indexed) for each selected frame. 
- segments : np.ndarray - Array of shape (n_segments, 2) with [frame_start, frame_end]. - bbox_centroids : np.ndarray - Array of shape (n_frames, 2) with bounding-box centroids. - count : int - Number of selected frames. - - """ - - frames: np.ndarray - seg_ids: np.ndarray - segments: np.ndarray - bbox_centroids: np.ndarray - count: int - - -@dataclass -class _APNodePairReport: - """Report from the AP node-pair evaluation pipeline. - - This dataclass holds all results from the 3-step filter cascade - used to evaluate a candidate anterior-posterior keypoint pair. - - Attributes - ---------- - success : bool - Whether the evaluation pipeline completed successfully. - failure_step : str - Name of the step at which evaluation failed, if any. - failure_reason : str - Reason for failure, if any. - scenario : int - Scenario number (1-13) from the mutually exclusive outcomes. - outcome : str - Either "accept" or "warn". - warning_message : str - Warning message, if applicable. - sorted_candidate_nodes : np.ndarray - Indices of candidate nodes after Step 1 filtering, sorted by - ascending normalized lateral offset. - valid_pairs : np.ndarray - Array of shape (n_pairs, 2) containing valid node pairs after - Step 2 filtering. - valid_pairs_internode_dist : np.ndarray - Internode separation (AP distance) for each valid pair. - input_pair_in_candidates : bool - Whether the input pair survived Step 1 filtering. - input_pair_opposite_sides : bool - Whether the input pair lies on opposite sides of the midpoint. - input_pair_separation_abs : float - Absolute AP separation of the input pair. - input_pair_is_distal : bool - Whether the input pair is classified as distal in Step 3. - input_pair_rank : int - Rank of the input pair by internode separation (1 = largest). - input_pair_order_matches_inference : bool - Whether from_node has a lower AP coordinate than to_node - (i.e. from_node is more posterior). True means the input pair - ordering is consistent with the inferred AP axis. 
- pc1_coords : np.ndarray - PC1 coordinates for each keypoint. - ap_coords : np.ndarray - AP (anterior-posterior) coordinates for each keypoint. - lateral_offsets : np.ndarray - Unsigned lateral offset from body axis for each keypoint. - lateral_offsets_norm : np.ndarray - Normalized lateral offsets (0 = nearest to axis, 1 = farthest). - lateral_offset_min : float - Minimum lateral offset among valid keypoints. - lateral_offset_max : float - Maximum lateral offset among valid keypoints. - midpoint_pc1 : float - AP reference midpoint (average of min and max PC1 projections). - pc1_min : float - Minimum PC1 projection among valid keypoints. - pc1_max : float - Maximum PC1 projection among valid keypoints. - midline_dist_norm : np.ndarray - Normalized distance from midpoint for each keypoint. - midline_dist_max : float - Maximum absolute distance from midpoint. - distal_pairs : np.ndarray - Array of distal pairs (both nodes at or above edge_thresh). - proximal_pairs : np.ndarray - Array of proximal pairs (at least one node below edge_thresh). - max_separation_distal_nodes : np.ndarray - Node indices of the maximum-separation distal pair, ordered - so that element 0 is posterior (lower AP coord) and element 1 - is anterior (higher AP coord). - max_separation_distal : float - Internode separation of the max-separation distal pair. - max_separation_nodes : np.ndarray - Node indices of the overall maximum-separation pair, ordered - so that element 0 is posterior (lower AP coord) and element 1 - is anterior (higher AP coord). - max_separation : float - Internode separation of the overall max-separation pair. 
- - """ - - success: bool = False - failure_step: str = "" - failure_reason: str = "" - scenario: int = 0 - outcome: str = "" - warning_message: str = "" - - sorted_candidate_nodes: np.ndarray = field( - default_factory=lambda: np.array([], dtype=int) - ) - valid_pairs: np.ndarray = field( - default_factory=lambda: np.zeros((0, 2), dtype=int) - ) - valid_pairs_internode_dist: np.ndarray = field( - default_factory=lambda: np.array([]) - ) - - input_pair_in_candidates: bool = False - input_pair_opposite_sides: bool = False - input_pair_separation_abs: float = np.nan - input_pair_is_distal: bool = False - input_pair_rank: int = 0 - input_pair_order_matches_inference: bool = False - - pc1_coords: np.ndarray = field(default_factory=lambda: np.array([])) - ap_coords: np.ndarray = field(default_factory=lambda: np.array([])) - lateral_offsets: np.ndarray = field(default_factory=lambda: np.array([])) - lateral_offsets_norm: np.ndarray = field( - default_factory=lambda: np.array([]) - ) - lateral_offset_min: float = np.nan - lateral_offset_max: float = np.nan - midpoint_pc1: float = np.nan - pc1_min: float = np.nan - pc1_max: float = np.nan - midline_dist_norm: np.ndarray = field(default_factory=lambda: np.array([])) - midline_dist_max: float = np.nan - - distal_pairs: np.ndarray = field( - default_factory=lambda: np.zeros((0, 2), dtype=int) - ) - proximal_pairs: np.ndarray = field( - default_factory=lambda: np.zeros((0, 2), dtype=int) - ) - max_separation_distal_nodes: np.ndarray = field( - default_factory=lambda: np.array([], dtype=int) - ) - max_separation_distal: float = np.nan - max_separation_nodes: np.ndarray = field( - default_factory=lambda: np.array([], dtype=int) - ) - max_separation: float = np.nan - - def compute_polarization( data: xr.DataArray, body_axis_keypoints: tuple[Hashable, Hashable] | None = None, displacement_frames: int = 1, return_angle: bool = False, in_degrees: bool = False, - validate_ap: bool = True, + validate_ap: bool = False, 
ap_validation_config: dict[str, Any] | None = None, ) -> xr.DataArray | tuple[xr.DataArray, xr.DataArray]: r"""Compute polarization (group alignment) of individuals. @@ -325,13 +73,13 @@ def compute_polarization( If True, the mean angle is returned in degrees. Otherwise, the angle is returned in radians. Only relevant when ``return_angle=True``. - validate_ap : bool, default=True + validate_ap : bool, default=False If True, run anterior-posterior axis validation when ``body_axis_keypoints`` is provided. Validation is skipped for displacement-based polarization. ap_validation_config : dict, optional Configuration overrides for anterior-posterior axis validation. - Passed to ``_ValidateAPConfig`` when validation is enabled. + See ``movement.kinematics.body_axis.ValidateAPConfig`` for options. Returns ------- @@ -428,7 +176,7 @@ def compute_polarization( ap_validation_result = None if normalized_keypoints is not None: if validate_ap: - ap_validation_result = _run_ap_validation( + ap_validation_result = run_ap_validation( data, normalized_keypoints, ap_validation_config ) heading_vectors = _compute_heading_from_keypoints( @@ -441,26 +189,10 @@ def compute_polarization( displacement_frames=displacement_frames, ) - polarization = _compute_polarization_from_headings(heading_vectors) - - if ap_validation_result is not None: - polarization.attrs["ap_validation_result"] = ap_validation_result - - if not return_angle: - return polarization - - mean_angle = _compute_mean_angle(heading_vectors, in_degrees) - return polarization, mean_angle - - -def _compute_polarization_from_headings( - heading_vectors: xr.DataArray, -) -> xr.DataArray: - """Compute polarization magnitude from heading vectors.""" heading = _select_space(heading_vectors) + unit_headings = convert_to_unit(heading) valid_mask = ~unit_headings.isnull().any(dim="space") - vector_sum = unit_headings.sum(dim="individuals", skipna=True) sum_magnitude = compute_norm(vector_sum) n_valid = 
valid_mask.sum(dim="individuals") @@ -470,23 +202,18 @@ def _compute_polarization_from_headings( sum_magnitude / n_valid, np.nan, ).clip(min=0.0, max=1.0) - return polarization.rename("polarization") + polarization = polarization.rename("polarization") + if ap_validation_result is not None: + polarization.attrs["ap_validation_result"] = ap_validation_result -def _compute_mean_angle( - heading_vectors: xr.DataArray, - in_degrees: bool = False, -) -> xr.DataArray: - """Compute mean heading angle from heading vectors.""" - heading = _select_space(heading_vectors) - unit_headings = convert_to_unit(heading) - valid_mask = ~unit_headings.isnull().any(dim="space") - - vector_sum = unit_headings.sum(dim="individuals", skipna=True) - sum_magnitude = compute_norm(vector_sum) - n_valid = valid_mask.sum(dim="individuals") + if not return_angle: + return polarization + # Normalize vector_sum to unit vector for angle computation mean_unit_vector = vector_sum / sum_magnitude + + # Compute angle from positive x-axis to mean unit vector reference = np.array([1, 0]) angle_defined = (n_valid > 0) & (sum_magnitude > _ANGLE_EPS) mean_angle = xr.where( @@ -498,66 +225,9 @@ def _compute_mean_angle( ) if in_degrees: mean_angle = np.rad2deg(mean_angle) - return mean_angle.rename("mean_angle") - - -def _run_ap_validation( - data: xr.DataArray, - normalized_keypoints: tuple[Hashable, Hashable], - ap_validation_config: dict[str, Any] | None, -) -> dict: - """Run AP validation across all individuals, select best by R*M. - - Each individual is validated independently using the supplied keypoint - pair. R*M (resultant_length × vote_margin) is computed per individual - and depends only on the individual's motion and body shape, not on - the input pair. The best individual is the one with the highest R*M. 
- """ - config = ( - _ValidateAPConfig(**ap_validation_config) - if ap_validation_config is not None - else None - ) - - if "individuals" not in data.dims: - single_result = _validate_ap( - data, - from_node=normalized_keypoints[0], - to_node=normalized_keypoints[1], - config=config, - verbose=False, - ) - return {"all_results": [single_result], "best_idx": 0} - - individuals = list(data.coords["individuals"].values) - all_results = [] - for individual in individuals: - result = _validate_ap( - data.sel(individuals=individual), - from_node=normalized_keypoints[0], - to_node=normalized_keypoints[1], - config=config, - verbose=False, - ) - result["individual"] = individual - all_results.append(result) - - best_idx = _find_best_individual_by_rxm(all_results) - return {"all_results": all_results, "best_idx": best_idx} - + mean_angle = mean_angle.rename("mean_angle") -def _find_best_individual_by_rxm(all_results: list[dict]) -> int: - """Return index of the individual with highest R*M score.""" - best_idx = -1 - best_rxm = -1.0 - for i, result in enumerate(all_results): - if not result["success"]: - continue - rxm = result["resultant_length"] * result["vote_margin"] - if rxm > best_rxm: - best_rxm = rxm - best_idx = i - return best_idx + return polarization, mean_angle def _compute_heading_from_keypoints( @@ -603,8 +273,9 @@ def _compute_heading_from_velocity( def _select_space(data: xr.DataArray) -> xr.DataArray: - """Return data with standard dim order, preserving all spatial coords.""" - return data.transpose("time", "space", "individuals") + """Return data with standard dim order, selecting only x and y coords.""" + result = data.sel(space=["x", "y"]) + return result.transpose("time", "space", "individuals") def _validate_position_data( @@ -702,2547 +373,3 @@ def _validate_type_data_array(data: xr.DataArray) -> None: raise TypeError( f"Input data must be an xarray.DataArray, but got {type(data)}." 
) - - -# Helper functions for _validate_ap - - -def _compute_tiered_validity( - keypoints: np.ndarray, - min_valid_frac: float, -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """Compute tiered validity masks for each frame. - - Parameters - ---------- - keypoints : np.ndarray - Keypoint positions with shape (n_frames, n_keypoints, 2). - min_valid_frac : float - Minimum fraction of keypoints required for tier-1 validity. - - Returns - ------- - tier1_valid : np.ndarray - Boolean array of shape (n_frames,) indicating tier-1 valid frames. - A frame is tier-1 valid if at least min_valid_frac of keypoints - are present AND at least 2 keypoints are present. - tier2_valid : np.ndarray - Boolean array of shape (n_frames,) indicating tier-2 valid frames. - A frame is tier-2 valid if all keypoints are present. - frac_present : np.ndarray - Array of shape (n_frames,) with fraction of keypoints present. - - """ - n_frames, n_keypoints, _ = keypoints.shape - - # A keypoint is present if neither x nor y is NaN - # Shape: (n_frames, n_keypoints) - keypoint_present = ~np.any(np.isnan(keypoints), axis=2) - - # Count present keypoints per frame - n_present = np.sum(keypoint_present, axis=1) - frac_present = n_present / n_keypoints - - # Tier-2: all keypoints present - tier2_valid = n_present == n_keypoints - - # Tier-1: at least min_valid_frac present AND at least 2 present - tier1_valid = (frac_present >= min_valid_frac) & (n_present >= 2) - - return tier1_valid, tier2_valid, frac_present - - -def _compute_bbox_centroid( - keypoints: np.ndarray, - tier1_valid: np.ndarray, -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """Compute bounding-box centroids for tier-1 valid frames. - - The bounding-box centroid is the midpoint of the axis-aligned bounding - box enclosing all present keypoints. This is density-invariant, unlike - the arithmetic mean. - - Parameters - ---------- - keypoints : np.ndarray - Keypoint positions with shape (n_frames, n_keypoints, 2). 
- tier1_valid : np.ndarray - Boolean array of shape (n_frames,) indicating tier-1 valid frames. - - Returns - ------- - bbox_centroids : np.ndarray - Array of shape (n_frames, 2) with bounding-box centroids. - NaN for non-tier-1-valid frames. - arith_centroids : np.ndarray - Array of shape (n_frames, 2) with arithmetic-mean centroids. - NaN for non-tier-1-valid frames. Used for diagnostic comparison. - centroid_discrepancy : np.ndarray - Array of shape (n_frames,) with normalized discrepancy between - bbox and arithmetic centroids (distance / bbox_diagonal). - NaN for non-tier-1-valid frames. - - """ - n_frames = keypoints.shape[0] - - bbox_centroids = np.full((n_frames, 2), np.nan) - arith_centroids = np.full((n_frames, 2), np.nan) - centroid_discrepancy = np.full(n_frames, np.nan) - - for f in range(n_frames): - if not tier1_valid[f]: - continue - - kp_f = keypoints[f] # (n_keypoints, 2) - - # Find present keypoints (no NaN in either coordinate) - present_mask = ~np.any(np.isnan(kp_f), axis=1) - kp_present = kp_f[present_mask] - - # Bounding-box centroid - bbox_min = np.min(kp_present, axis=0) - bbox_max = np.max(kp_present, axis=0) - bbox_centroids[f] = (bbox_min + bbox_max) / 2 - - # Arithmetic-mean centroid - arith_centroids[f] = np.mean(kp_present, axis=0) - - # Centroid discrepancy: distance normalized by bbox diagonal - bbox_diag = np.linalg.norm(bbox_max - bbox_min) - if bbox_diag > 0: - discrepancy = np.linalg.norm( - bbox_centroids[f] - arith_centroids[f] - ) - centroid_discrepancy[f] = discrepancy / bbox_diag - else: - centroid_discrepancy[f] = 0.0 - - return bbox_centroids, arith_centroids, centroid_discrepancy - - -def _compute_frame_velocities( - bbox_centroids: np.ndarray, - tier1_valid: np.ndarray, -) -> tuple[np.ndarray, np.ndarray]: - """Compute frame-to-frame centroid velocities and speeds. - - A velocity is valid only when both adjacent frames are tier-1 valid. 
- - Parameters - ---------- - bbox_centroids : np.ndarray - Array of shape (n_frames, 2) with bounding-box centroids. - tier1_valid : np.ndarray - Boolean array of shape (n_frames,) indicating tier-1 valid frames. - - Returns - ------- - velocities : np.ndarray - Array of shape (n_frames - 1, 2) with velocity vectors. - Invalid velocities are NaN. - speeds : np.ndarray - Array of shape (n_frames - 1,) with speed scalars. - Invalid speeds are NaN. - - """ - velocities = np.diff(bbox_centroids, axis=0) # (n_frames - 1, 2) - - # Velocity valid only if both adjacent frames are tier-1 valid - speed_valid = tier1_valid[:-1] & tier1_valid[1:] - - # Mask invalid velocities - velocities[~speed_valid] = np.nan - - speeds = np.linalg.norm(velocities, axis=1) - - return velocities, speeds - - -def _compute_sliding_window_medians( - speeds: np.ndarray, - window_len: int, - stride: int, -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """Compute median speeds for sliding windows. - - A window is valid only when every speed sample in that window is valid - (non-NaN), ensuring strict NaN-free content. - - Parameters - ---------- - speeds : np.ndarray - Array of shape (n_speed_samples,) with speed values. - window_len : int - Number of speed samples per sliding window. - stride : int - Step size between consecutive window start positions. - - Returns - ------- - window_starts : np.ndarray - Array of window start indices (0-indexed). - window_medians : np.ndarray - Median speed for each window. NaN for invalid windows. - window_all_valid : np.ndarray - Boolean array indicating which windows are fully valid. 
- - """ - num_speed = len(speeds) - window_starts = np.arange(0, num_speed - window_len + 1, stride) - num_windows = len(window_starts) - - window_medians = np.full(num_windows, np.nan) - window_all_valid = np.zeros(num_windows, dtype=bool) - - for k in range(num_windows): - s = window_starts[k] - e = s + window_len - w = speeds[s:e] - - # Window valid only if all samples are non-NaN - if np.all(~np.isnan(w)): - window_all_valid[k] = True - window_medians[k] = np.median(w) - - return window_starts, window_medians, window_all_valid - - -def _detect_high_motion_windows( - window_medians: np.ndarray, - window_all_valid: np.ndarray, - pct_thresh: float, -) -> np.ndarray: - """Identify high-motion windows based on percentile threshold. - - Parameters - ---------- - window_medians : np.ndarray - Median speed for each window. - window_all_valid : np.ndarray - Boolean array indicating which windows are fully valid. - pct_thresh : float - Percentile threshold (0-100) for high-motion classification. - - Returns - ------- - high_motion : np.ndarray - Boolean array indicating high-motion windows. - - """ - valid_medians = window_medians[window_all_valid] - if len(valid_medians) == 0: - return np.zeros(len(window_medians), dtype=bool) - - thresh = np.percentile(valid_medians, pct_thresh) - high_motion = window_all_valid & (window_medians >= thresh) - - return high_motion - - -def _detect_runs( - high_motion: np.ndarray, - min_run_len: int, -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """Detect runs of consecutive high-motion windows. - - A run is a maximal sequence of consecutively indexed qualifying windows. - - Parameters - ---------- - high_motion : np.ndarray - Boolean array indicating high-motion windows. - min_run_len : int - Minimum number of consecutive qualifying windows for a valid run. - - Returns - ------- - run_starts : np.ndarray - Start indices of valid runs. - run_ends : np.ndarray - End indices (inclusive) of valid runs. 
- run_lengths : np.ndarray - Length of each valid run. - - """ - padded = np.concatenate([[False], high_motion, [False]]) - d = np.diff(padded.astype(int)) - - # Find run boundaries - run_starts_all = np.where(d == 1)[0] - run_ends_all = np.where(d == -1)[0] - 1 - run_lengths_all = run_ends_all - run_starts_all + 1 - - # Filter by minimum run length - valid_mask = run_lengths_all >= min_run_len - run_starts = run_starts_all[valid_mask] - run_ends = run_ends_all[valid_mask] - run_lengths = run_lengths_all[valid_mask] - - return run_starts, run_ends, run_lengths - - -def _convert_runs_to_segments( - run_starts: np.ndarray, - run_ends: np.ndarray, - window_starts: np.ndarray, - window_len: int, -) -> np.ndarray: - """Convert window runs to frame segments. - - Each run is converted to a frame interval spanning from the start frame - of the first window to the end frame of the last window. - - Parameters - ---------- - run_starts : np.ndarray - Start indices of valid runs (indices into window arrays). - run_ends : np.ndarray - End indices (inclusive) of valid runs. - window_starts : np.ndarray - Start frame indices for each window. - window_len : int - Length of each window in frames. - - Returns - ------- - segments_raw : np.ndarray - Array of shape (n_runs, 2) with [frame_start, frame_end] for each run. - - """ - n_runs = len(run_starts) - segments_raw = np.zeros((n_runs, 2), dtype=int) - - for j in range(n_runs): - s_idx = run_starts[j] - e_idx = run_ends[j] - frame_start = window_starts[s_idx] - frame_end = window_starts[e_idx] + window_len - segments_raw[j] = [frame_start, frame_end] - - return segments_raw - - -def _merge_segments(segments_raw: np.ndarray) -> np.ndarray: - """Merge overlapping or abutting frame segments. - - Segments are first sorted by start frame, then merged if they overlap - or abut (next start <= current end + 1). - - Parameters - ---------- - segments_raw : np.ndarray - Array of shape (n_segments, 2) with [frame_start, frame_end]. 
- - Returns - ------- - segments : np.ndarray - Array of merged non-overlapping segments. - - """ - if len(segments_raw) == 0: - return segments_raw - - # Sort by start frame - sorted_idx = np.argsort(segments_raw[:, 0]) - segments_sorted = segments_raw[sorted_idx] - - merged = [segments_sorted[0].tolist()] - - for j in range(1, len(segments_sorted)): - next_seg = segments_sorted[j] - curr_seg = merged[-1] - - # Merge if overlapping or abutting - if next_seg[0] <= curr_seg[1] + 1: - merged[-1][1] = max(curr_seg[1], next_seg[1]) - else: - merged.append(next_seg.tolist()) - - return np.array(merged, dtype=int) - - -def _filter_segments_tier2( - segments: np.ndarray, - tier2_valid: np.ndarray, -) -> tuple[np.ndarray, np.ndarray]: - """Filter segment frames to retain only tier-2 valid frames. - - Parameters - ---------- - segments : np.ndarray - Array of shape (n_segments, 2) with [frame_start, frame_end]. - tier2_valid : np.ndarray - Boolean array of shape (n_frames,) indicating tier-2 valid frames. - - Returns - ------- - selected_frames : np.ndarray - Array of tier-2 valid frame indices within segments. - selected_seg_id : np.ndarray - Segment ID (0-indexed) for each selected frame. 
- - """ - # Collect all unique frames from all segments - all_segment_frames: list[int] = [] - for k in range(len(segments)): - frame_start, frame_end = segments[k] - seg_frames = np.arange(frame_start, frame_end + 1) - all_segment_frames.extend(seg_frames) - - segment_frames_all = np.unique(all_segment_frames) - - tier2_mask = tier2_valid[segment_frames_all] - selected_frames = segment_frames_all[tier2_mask] - - # Assign each selected frame to its segment - num_selected = len(selected_frames) - selected_seg_id = np.zeros(num_selected, dtype=int) - - for j in range(num_selected): - f = selected_frames[j] - for k in range(len(segments)): - if segments[k, 0] <= f <= segments[k, 1]: - selected_seg_id[j] = k - break - - return selected_frames, selected_seg_id - - -def _build_centered_skeletons( - keypoints: np.ndarray, - selected_frames: np.ndarray, -) -> tuple[np.ndarray, np.ndarray]: - """Build centroid-centered skeletons for selected frames. - - Uses bounding-box centroid for centering, consistent with the - segmentation centroid. - - Parameters - ---------- - keypoints : np.ndarray - Keypoint positions with shape (n_frames, n_keypoints, 2). - selected_frames : np.ndarray - Array of selected frame indices. - - Returns - ------- - selected_centroids : np.ndarray - Array of shape (num_selected, 2) with bounding-box centroids. - centered_skeletons : np.ndarray - Array of shape (num_selected, n_keypoints, 2) with - centroid-centered skeleton coordinates. 
- - """ - num_selected = len(selected_frames) - n_keypoints = keypoints.shape[1] - - selected_centroids = np.zeros((num_selected, 2)) - centered_skeletons = np.zeros((num_selected, n_keypoints, 2)) - - for j in range(num_selected): - f = selected_frames[j] - kp_f = keypoints[f] # (n_keypoints, 2) - all present for tier-2 - - # Bounding-box centroid - bbox_min = np.min(kp_f, axis=0) - bbox_max = np.max(kp_f, axis=0) - centroid_f = (bbox_min + bbox_max) / 2 - - selected_centroids[j] = centroid_f - centered_skeletons[j] = kp_f - centroid_f - - return selected_centroids, centered_skeletons - - -def _compute_pairwise_rmsd(centered_skeletons: np.ndarray) -> np.ndarray: - """Compute pairwise RMSD between all centered skeletons. - - RMSD is computed as the square root of the mean of squared entry-wise - differences between flattened skeleton vectors. - - Parameters - ---------- - centered_skeletons : np.ndarray - Array of shape (num_selected, n_keypoints, 2). - - Returns - ------- - rmsd_matrix : np.ndarray - Symmetric matrix of shape (num_selected, num_selected) with - pairwise RMSD values. Diagonal is zero. - - """ - num_selected = len(centered_skeletons) - - skel_flat = centered_skeletons.reshape(num_selected, -1) - - rmsd_matrix = np.zeros((num_selected, num_selected)) - - for i in range(num_selected): - for j in range(i + 1, num_selected): - d = skel_flat[i] - skel_flat[j] - rmsd_val = np.sqrt(np.mean(d**2)) - rmsd_matrix[i, j] = rmsd_val - rmsd_matrix[j, i] = rmsd_val - - return rmsd_matrix - - -def _compute_postural_variance_ratio( - rmsd_matrix: np.ndarray, - selected_seg_id: np.ndarray, -) -> tuple[float, np.ndarray, np.ndarray, bool]: - """Compute the between/within segment RMSD variance ratio. - - Parameters - ---------- - rmsd_matrix : np.ndarray - Pairwise RMSD matrix of shape (num_selected, num_selected). - selected_seg_id : np.ndarray - Segment ID for each selected frame. 
- - Returns - ------- - var_ratio : float - Ratio of between-segment to within-segment RMSD variance. - Returns 0.0 if either distribution is empty or within variance is 0. - within_rmsds : np.ndarray - Array of within-segment RMSD values. - between_rmsds : np.ndarray - Array of between-segment RMSD values. - var_ratio_override : bool - True if variance ratio was set to 0 due to edge cases. - - """ - num_selected = len(selected_seg_id) - within_rmsds_list: list[float] = [] - between_rmsds_list: list[float] = [] - - for i in range(num_selected): - for j in range(i + 1, num_selected): - if selected_seg_id[i] == selected_seg_id[j]: - within_rmsds_list.append(rmsd_matrix[i, j]) - else: - between_rmsds_list.append(rmsd_matrix[i, j]) - - within_rmsds = np.array(within_rmsds_list) - between_rmsds = np.array(between_rmsds_list) - - # Compute variance ratio with edge case handling - var_ratio_override = False - if ( - len(within_rmsds) > 0 - and len(between_rmsds) > 0 - and np.var(within_rmsds) > 0 - ): - var_ratio = np.var(between_rmsds) / np.var(within_rmsds) - else: - var_ratio = 0.0 - var_ratio_override = True - - return var_ratio, within_rmsds, between_rmsds, var_ratio_override - - -def _kmedoids( - data: np.ndarray, - k: int, - max_iter: int = 100, - n_init: int = 5, - random_state: int | None = None, -) -> tuple[np.ndarray, np.ndarray, float]: - """Perform k-medoids clustering. - - Parameters - ---------- - data : np.ndarray - Array of shape (n_samples, n_features). - k : int - Number of clusters. - max_iter : int, default=100 - Maximum number of iterations. - n_init : int, default=5 - Number of random initializations. - random_state : int, optional - Random seed for reproducibility. - - Returns - ------- - labels : np.ndarray - Cluster labels for each sample (0-indexed). - medoid_indices : np.ndarray - Indices of medoid samples. - inertia : float - Sum of distances from samples to their medoids. 
- - """ - from scipy.spatial.distance import cdist - - rng = np.random.default_rng(random_state) - n_samples = len(data) - - dist_matrix = cdist(data, data, metric="euclidean") - - best_labels: np.ndarray | None = None - best_medoids: np.ndarray | None = None - best_inertia = np.inf - - for _ in range(n_init): - medoids = rng.choice(n_samples, size=k, replace=False) - - for _ in range(max_iter): - distances_to_medoids = dist_matrix[:, medoids] - labels = np.argmin(distances_to_medoids, axis=1) - - # Update medoids - new_medoids = np.zeros(k, dtype=int) - for cluster in range(k): - cluster_mask = labels == cluster - if not np.any(cluster_mask): - # Empty cluster - keep old medoid - new_medoids[cluster] = medoids[cluster] - continue - - cluster_indices = np.where(cluster_mask)[0] - # Find point that minimizes sum of distances within cluster - cluster_dists = dist_matrix[ - np.ix_(cluster_indices, cluster_indices) - ] - total_dists = np.sum(cluster_dists, axis=1) - best_idx = np.argmin(total_dists) - new_medoids[cluster] = cluster_indices[best_idx] - - if np.array_equal(np.sort(medoids), np.sort(new_medoids)): - break - medoids = new_medoids - - distances_to_medoids = dist_matrix[:, medoids] - labels = np.argmin(distances_to_medoids, axis=1) - inertia = np.sum(distances_to_medoids[np.arange(n_samples), labels]) - - if inertia < best_inertia: - best_inertia = inertia - best_labels = labels.copy() - best_medoids = medoids.copy() - - # These are guaranteed to be set after at least one iteration - assert best_labels is not None and best_medoids is not None - return best_labels, best_medoids, best_inertia - - -def _silhouette_score(data: np.ndarray, labels: np.ndarray) -> float: - """Compute mean silhouette score. - - Parameters - ---------- - data : np.ndarray - Array of shape (n_samples, n_features). - labels : np.ndarray - Cluster labels for each sample. - - Returns - ------- - score : float - Mean silhouette score across all samples. 
- Returns 0.0 if clustering is degenerate. - - """ - from scipy.spatial.distance import cdist - - n_samples = len(data) - unique_labels = np.unique(labels) - n_clusters = len(unique_labels) - - if n_clusters <= 1 or n_clusters >= n_samples: - return 0.0 - - dist_matrix = cdist(data, data, metric="euclidean") - - silhouette_vals = np.zeros(n_samples) - - for i in range(n_samples): - own_cluster = labels[i] - own_mask = labels == own_cluster - - # a(i) = mean distance to points in same cluster - if np.sum(own_mask) > 1: - a_i = np.mean( - dist_matrix[i, own_mask & (np.arange(n_samples) != i)] - ) - else: - a_i = 0.0 - - # b(i) = min over other clusters of mean distance to that cluster - b_i = np.inf - for cluster in unique_labels: - if cluster == own_cluster: - continue - cluster_mask = labels == cluster - if np.any(cluster_mask): - mean_dist = np.mean(dist_matrix[i, cluster_mask]) - b_i = min(b_i, mean_dist) - - if b_i == np.inf: - silhouette_vals[i] = 0.0 - else: - silhouette_vals[i] = ( - (b_i - a_i) / max(a_i, b_i) if max(a_i, b_i) > 0 else 0.0 - ) - - return float(np.mean(silhouette_vals)) - - -def _perform_postural_clustering( - centered_skeletons: np.ndarray, - max_clusters: int, - min_silhouette: float = 0.2, -) -> tuple[np.ndarray, int, int, float, list[tuple[int, float]]]: - """Perform postural clustering using k-medoids with silhouette selection. - - Parameters - ---------- - centered_skeletons : np.ndarray - Array of shape (num_selected, n_keypoints, 2). - max_clusters : int - Maximum number of clusters to evaluate. - min_silhouette : float, default=0.2 - Minimum silhouette score to accept clustering. - - Returns - ------- - cluster_labels : np.ndarray - Cluster labels for each frame (0-indexed). - num_clusters : int - Number of clusters (1 if clustering not accepted). - primary_cluster : int - Index of largest cluster (0-indexed). - best_silhouette : float - Best silhouette score achieved. 
- silhouette_scores : list of (k, score) - Silhouette scores for each k evaluated. - - """ - num_selected = len(centered_skeletons) - skel_flat = centered_skeletons.reshape(num_selected, -1) - - best_k = 1 - best_sil = -np.inf - silhouette_scores = [] - - # Evaluate k from 2 to max_clusters (capped at num_selected // 2) - max_k = min(max_clusters, num_selected // 2) - - for k in range(2, max_k + 1): - try: - labels, _, _ = _kmedoids(skel_flat, k, n_init=5) - sil = _silhouette_score(skel_flat, labels) - silhouette_scores.append((k, sil)) - - if sil > best_sil: - best_sil = sil - best_k = k - except Exception: - silhouette_scores.append((k, np.nan)) - - # Accept clustering only if best_sil > min_silhouette - if best_k > 1 and best_sil > min_silhouette: - # Re-run with more initializations for final result - cluster_labels, _, _ = _kmedoids(skel_flat, best_k, n_init=10) - num_clusters = best_k - - # Primary cluster = largest - cluster_counts = np.bincount(cluster_labels, minlength=num_clusters) - primary_cluster = int(np.argmax(cluster_counts)) - else: - cluster_labels = np.zeros(num_selected, dtype=int) - num_clusters = 1 - primary_cluster = 0 - - return ( - cluster_labels, - num_clusters, - primary_cluster, - best_sil, - silhouette_scores, - ) - - -def _compute_cluster_velocities( - selected_frames: np.ndarray, - selected_seg_id: np.ndarray, - cluster_mask: np.ndarray, - segments: np.ndarray, - bbox_centroids: np.ndarray, -) -> np.ndarray: - """Compute velocities between adjacent consecutive frames. - - Only considers frames in the same segment and cluster. Frame pairs - where both frames are consecutive (frame[i] == frame[i-1] + 1), - in the same segment, and in the same cluster contribute a velocity - vector. - This prevents spanning temporal gaps or mixing postures across clusters. - - Returns - ------- - np.ndarray - Array of shape (n_velocities, 2). Empty (0, 2) if no valid pairs. 
- - """ - frames_c = selected_frames[cluster_mask] - seg_ids_c = selected_seg_id[cluster_mask] - velocities_list: list[np.ndarray] = [] - - for seg_k in range(len(segments)): - seg_mask = seg_ids_c == seg_k - seg_frames = np.sort(frames_c[seg_mask]) - for fi in range(1, len(seg_frames)): - if seg_frames[fi] != seg_frames[fi - 1] + 1: - continue - curr_frame = seg_frames[fi] - prev_frame = seg_frames[fi - 1] - v = bbox_centroids[curr_frame] - bbox_centroids[prev_frame] - if np.all(~np.isnan(v)): - velocities_list.append(v) - - return np.array(velocities_list) if velocities_list else np.zeros((0, 2)) - - -def _infer_anterior_from_velocities( - velocities: np.ndarray, - PC1: np.ndarray, -) -> dict: - """Infer anterior direction from velocity projections onto PC1. - - Uses strict majority vote on PC1 projection signs: anterior = +PC1 - if n_positive > n_negative, else −PC1 (ties default to −PC1). - - Also computes circular statistics on velocity angles: - - resultant_length R = √(C² + S²) where C = mean(cos θ), S = mean(sin θ) - - vote_margin M = |n₊ − n₋| / (n₊ + n₋) - - Returns dict with resultant_length, circ_mean_dir, vel_projs_pc1, - num_positive, num_negative, vote_margin, anterior_sign. 
- """ - result: dict = { - "resultant_length": 0.0, - "circ_mean_dir": np.nan, - "vel_projs_pc1": np.array([]), - "num_positive": 0, - "num_negative": 0, - "vote_margin": 0.0, - "anterior_sign": -1, - } - if len(velocities) == 0: - return result - - vel_angles = np.arctan2(velocities[:, 1], velocities[:, 0]) - circ_C = np.mean(np.cos(vel_angles)) - circ_S = np.mean(np.sin(vel_angles)) - result["resultant_length"] = np.sqrt(circ_C**2 + circ_S**2) - result["circ_mean_dir"] = np.arctan2(circ_S, circ_C) - - vel_projs = velocities @ PC1 - num_pos = int(np.sum(vel_projs > 0)) - num_neg = int(np.sum(vel_projs < 0)) - result["vel_projs_pc1"] = vel_projs - result["num_positive"] = num_pos - result["num_negative"] = num_neg - result["vote_margin"] = abs(num_pos - num_neg) / max(num_pos + num_neg, 1) - result["anterior_sign"] = +1 if num_pos > num_neg else -1 - return result - - -def _compute_cluster_pca_and_anterior( - centered_skeletons: np.ndarray, - cluster_mask: np.ndarray, - selected_frames: np.ndarray, - selected_seg_id: np.ndarray, - segments: np.ndarray, - bbox_centroids: np.ndarray, -) -> dict: - """Compute SVD-based PCA and velocity-based anterior inference. - - Performs inference for one cluster. - - Performs SVD on the cluster's average centered skeleton to extract PC1/PC2, - applies the geometric sign convention, then infers the anterior direction - via velocity voting on centroid displacements projected onto PC1. - - Returns - ------- - dict - Keys: valid, n_frames, avg_skeleton, valid_shape_rows, - PC1, PC2, anterior_sign, vote_margin, resultant_length, - circ_mean_dir, velocities, vel_projs_pc1, and others. 
- - """ - n_keypoints = centered_skeletons.shape[1] - n_c = int(np.sum(cluster_mask)) - - result: dict = { - "valid": False, - "n_frames": n_c, - "avg_skeleton": np.full((n_keypoints, 2), np.nan), - "valid_shape_rows": np.zeros(n_keypoints, dtype=bool), - "PC1": np.array([1.0, 0.0]), - "PC2": np.array([0.0, 1.0]), - "proj_pc1": np.full(n_keypoints, np.nan), - "proj_pc2": np.full(n_keypoints, np.nan), - "anterior_sign": -1, - "num_positive": 0, - "num_negative": 0, - "vote_margin": 0.0, - "resultant_length": 0.0, - "circ_mean_dir": np.nan, - "velocities": np.zeros((0, 2)), - "vel_projs_pc1": np.array([]), - } - - if n_c == 0: - return result - - skels_c = centered_skeletons[cluster_mask] - avg_skel_c = np.mean(skels_c, axis=0) - valid_shape_rows = ~np.any(np.isnan(avg_skel_c), axis=1) - - if np.sum(valid_shape_rows) < 2: - return result - - result["avg_skeleton"] = avg_skel_c - result["valid_shape_rows"] = valid_shape_rows - - # SVD on valid (non-NaN) rows of the average centered skeleton - D_valid = avg_skel_c[valid_shape_rows] - _U, _S, Vt = np.linalg.svd(D_valid, full_matrices=False) - PC1 = Vt[0] - PC2 = Vt[1] if len(Vt) > 1 else np.array([0.0, 1.0]) - - # Geometric sign convention (reproducible across runs, decoupled - # from anatomical AP assignment which is determined by velocity voting): - # PC1 flipped so y-component >= 0 - # PC2 flipped so x-component >= 0 - - if PC1[1] < 0: - PC1 = -PC1 - if PC2[0] < 0: - PC2 = -PC2 - - result["PC1"] = PC1 - result["PC2"] = PC2 - - proj_pc1 = np.full(n_keypoints, np.nan) - proj_pc2 = np.full(n_keypoints, np.nan) - proj_pc1[valid_shape_rows] = avg_skel_c[valid_shape_rows] @ PC1 - proj_pc2[valid_shape_rows] = avg_skel_c[valid_shape_rows] @ PC2 - result["proj_pc1"] = proj_pc1 - result["proj_pc2"] = proj_pc2 - - velocities = _compute_cluster_velocities( - selected_frames, - selected_seg_id, - cluster_mask, - segments, - bbox_centroids, - ) - result["velocities"] = velocities - 
result.update(_infer_anterior_from_velocities(velocities, PC1)) - result["valid"] = True - return result - - -def _compute_node_projections( - report: _APNodePairReport, - avg_skeleton: np.ndarray, - PC1: np.ndarray, - anterior_sign: int, - valid_shape_rows: np.ndarray, - from_node: int, - to_node: int, -) -> None: - """Compute raw PC1, AP-oriented, and lateral projections. - - Computes projections for all valid keypoints. - - Populates the report's coordinate arrays and determines: - - pc1_coords: raw projection onto PC1 (sign-convention only) - - ap_coords: projection onto anterior_sign × PC1 (positive = more - anterior) - - lateral_offsets: unsigned distance from the AP axis - - midpoint_pc1: average of min/max PC1 projections (AP reference point) - - input_pair_order_matches_inference: True if from_node's AP coord < - to_node's - """ - pc1 = PC1 / np.linalg.norm(PC1) - # AP unit vector: anterior_sign * PC1, so positive projection = anterior - e_ap = anterior_sign * pc1 - # Lateral unit vector: 90° CCW rotation of e_ap - e_lat = np.array([-e_ap[1], e_ap[0]]) - - D_valid = avg_skeleton[valid_shape_rows] - report.pc1_coords[valid_shape_rows] = D_valid @ pc1 - report.ap_coords[valid_shape_rows] = D_valid @ e_ap - report.lateral_offsets[valid_shape_rows] = np.abs(D_valid @ e_lat) - - if valid_shape_rows[from_node] and valid_shape_rows[to_node]: - report.input_pair_order_matches_inference = ( - report.ap_coords[from_node] < report.ap_coords[to_node] - ) - - proj_pc1_valid = report.pc1_coords[valid_shape_rows] - report.pc1_min = float(np.min(proj_pc1_valid)) - report.pc1_max = float(np.max(proj_pc1_valid)) - report.midpoint_pc1 = (report.pc1_min + report.pc1_max) / 2 - - -def _apply_lateral_filter( - report: _APNodePairReport, - valid_idx: np.ndarray, - lateral_thresh: float, -) -> np.ndarray | None: - """Step 1: Filter keypoints by normalized lateral offset. - - Returns sorted candidate node indices, or None on failure. 
- """ - d_valid = report.lateral_offsets[valid_idx] - d_min = float(np.min(d_valid)) - d_max = float(np.max(d_valid)) - report.lateral_offset_min = d_min - report.lateral_offset_max = d_max - - if d_max > d_min: - d_norm = (d_valid - d_min) / (d_max - d_min) - report.lateral_offsets_norm[valid_idx] = d_norm - keep_mask = d_norm <= lateral_thresh - else: - report.lateral_offsets_norm[valid_idx] = np.zeros(len(d_valid)) - keep_mask = np.ones(len(d_valid), dtype=bool) - - candidate_idx = np.where(keep_mask)[0] - C = valid_idx[candidate_idx] - sorted_order = np.argsort(d_valid[candidate_idx]) - C = C[sorted_order] - report.sorted_candidate_nodes = C.copy() - - if len(C) < 2: - report.failure_step = "Step 1: lateral alignment filter" - report.failure_reason = ( - "Fewer than 2 candidates remained after filtering." - ) - return None - return C - - -def _find_opposite_side_pairs( - report: _APNodePairReport, - C: np.ndarray, - from_node: int, - to_node: int, - valid_shape_rows: np.ndarray, -) -> tuple[np.ndarray, np.ndarray] | None: - """Step 2: Find candidate pairs on opposite sides of the AP midpoint. - - Returns (pairs, seps) arrays, or None on failure. 
- """ - m = report.midpoint_pc1 - report.input_pair_in_candidates = (from_node in C) and (to_node in C) - - pairs_list: list[list[int]] = [] - seps_list: list[float] = [] - for ii in range(len(C)): - for jj in range(ii + 1, len(C)): - i, j = C[ii], C[jj] - if (report.pc1_coords[i] - m) * (report.pc1_coords[j] - m) < 0: - pairs_list.append([i, j]) - seps_list.append( - abs(report.ap_coords[i] - report.ap_coords[j]) - ) - - pairs = ( - np.array(pairs_list, dtype=int) - if pairs_list - else np.zeros((0, 2), dtype=int) - ) - seps = np.array(seps_list) if seps_list else np.array([]) - report.valid_pairs = pairs - report.valid_pairs_internode_dist = seps - - if valid_shape_rows[from_node] and valid_shape_rows[to_node]: - report.input_pair_opposite_sides = ( - (report.pc1_coords[from_node] - m) - * (report.pc1_coords[to_node] - m) - ) < 0 - report.input_pair_separation_abs = abs( - report.ap_coords[from_node] - report.ap_coords[to_node] - ) - - if len(pairs) == 0: - report.failure_step = "Step 2: opposite-sides constraint" - report.failure_reason = ( - "No candidate pair lies on opposite sides of the midpoint." - ) - return None - return pairs, seps - - -def _order_pair_by_ap( - pair: np.ndarray, - ap_coords: np.ndarray, -) -> np.ndarray: - """Order a node pair so element 0 is posterior (lower AP coord). - - This ensures that suggested pairs always encode the - posterior→anterior direction, matching the convention used by - ``body_axis_keypoints=(from_node, to_node)`` where from_node is - posterior and to_node is anterior. - - Parameters - ---------- - pair : np.ndarray - Two-element array of node indices. - ap_coords : np.ndarray - AP coordinates for all keypoints (anterior_sign already applied). - - Returns - ------- - np.ndarray - The same two indices, ordered so that - ``ap_coords[result[0]] <= ap_coords[result[1]]``. 
- - """ - i, j = pair - if ap_coords[i] <= ap_coords[j]: - return np.array([i, j], dtype=int) - return np.array([j, i], dtype=int) - - -def _classify_distal_proximal( - report: _APNodePairReport, - pairs: np.ndarray, - seps: np.ndarray, - valid_shape_rows: np.ndarray, - edge_thresh: float, -) -> np.ndarray: - """Step 3: Classify pairs as distal or proximal. Returns pair_is_distal.""" - m = report.midpoint_pc1 - midline_dist = np.abs(report.pc1_coords - m) - d_max_midline = float(np.nanmax(midline_dist[valid_shape_rows])) - report.midline_dist_max = d_max_midline - - if d_max_midline > 0: - report.midline_dist_norm = midline_dist / d_max_midline - else: - report.midline_dist_norm = np.zeros(len(report.pc1_coords)) - - pair_is_distal = np.zeros(len(pairs), dtype=bool) - for k in range(len(pairs)): - i, j = pairs[k] - pair_is_distal[k] = ( - min(report.midline_dist_norm[i], report.midline_dist_norm[j]) - >= edge_thresh - ) - - report.distal_pairs = pairs[pair_is_distal] - report.proximal_pairs = pairs[~pair_is_distal] - - if len(seps) > 0: - idx_max = int(np.argmax(seps)) - report.max_separation_nodes = _order_pair_by_ap( - pairs[idx_max], report.ap_coords - ) - report.max_separation = seps[idx_max] - - if np.any(pair_is_distal): - distal_seps = seps[pair_is_distal] - distal_pairs_only = pairs[pair_is_distal] - idx_max_distal = int(np.argmax(distal_seps)) - report.max_separation_distal_nodes = _order_pair_by_ap( - distal_pairs_only[idx_max_distal], report.ap_coords - ) - report.max_separation_distal = distal_seps[idx_max_distal] - - return pair_is_distal - - -def _check_input_pair_in_valid( - report: _APNodePairReport, - pairs: np.ndarray, - seps: np.ndarray, - pair_is_distal: np.ndarray, - from_node: int, - to_node: int, -) -> tuple[bool, int]: - """Check whether input pair is among valid pairs. 
Returns (found, idx).""" - input_pair_sorted = tuple(sorted([from_node, to_node])) - input_in_valid = False - input_idx = -1 - - for k in range(len(pairs)): - if tuple(sorted(pairs[k])) == input_pair_sorted: - input_in_valid = True - input_idx = k - break - - if input_in_valid: - report.input_pair_is_distal = pair_is_distal[input_idx] - rank_order = np.argsort(seps)[::-1] - report.input_pair_rank = ( - int(np.where(rank_order == input_idx)[0][0]) + 1 - ) - return input_in_valid, input_idx - - -def _evaluate_ap_node_pair( - avg_skeleton: np.ndarray, - PC1: np.ndarray, - anterior_sign: int, - valid_shape_rows: np.ndarray, - from_node: int, - to_node: int, - config: _ValidateAPConfig, -) -> _APNodePairReport: - """Evaluate an AP node pair through the 3-step filter cascade. - - Parameters - ---------- - avg_skeleton : np.ndarray - Average centered skeleton of shape (n_keypoints, 2). - PC1 : np.ndarray - First principal component vector of shape (2,). - anterior_sign : int - Inferred anterior direction (+1 or -1 relative to PC1). - valid_shape_rows : np.ndarray - Boolean array indicating valid (non-NaN) keypoints. - from_node : int - Index of the input from_node (body_axis_keypoints origin, - claimed posterior). 0-indexed. - to_node : int - Index of the input to_node (body_axis_keypoints target, - claimed anterior). 0-indexed. - config : _ValidateAPConfig - Configuration with ``lateral_thresh`` and ``edge_thresh``. - - Returns - ------- - _APNodePairReport - Complete evaluation report. 
- - """ - n_keypoints = len(avg_skeleton) - report = _APNodePairReport() - report.pc1_coords = np.full(n_keypoints, np.nan) - report.ap_coords = np.full(n_keypoints, np.nan) - report.lateral_offsets = np.full(n_keypoints, np.nan) - report.lateral_offsets_norm = np.full(n_keypoints, np.nan) - report.midline_dist_norm = np.full(n_keypoints, np.nan) - - for node, label in [(from_node, "from_node"), (to_node, "to_node")]: - if node < 0 or node >= n_keypoints: - report.failure_step = "Input validation" - report.failure_reason = ( - f"{label} must be a valid index in 0..{n_keypoints - 1}." - ) - return report - - valid_idx = np.where(valid_shape_rows)[0] - if len(valid_idx) < 2: - report.failure_step = "Step 1: lateral alignment filter" - report.failure_reason = "Fewer than 2 valid nodes are available." - return report - - _compute_node_projections( - report, - avg_skeleton, - PC1, - anterior_sign, - valid_shape_rows, - from_node, - to_node, - ) - - C = _apply_lateral_filter(report, valid_idx, config.lateral_thresh) - if C is None: - return report - - step2 = _find_opposite_side_pairs( - report, - C, - from_node, - to_node, - valid_shape_rows, - ) - if step2 is None: - return report - pairs, seps = step2 - - pair_is_distal = _classify_distal_proximal( - report, - pairs, - seps, - valid_shape_rows, - config.edge_thresh, - ) - - input_in_valid, input_idx = _check_input_pair_in_valid( - report, - pairs, - seps, - pair_is_distal, - from_node, - to_node, - ) - - report = _assign_scenario( - report, pairs, seps, pair_is_distal, input_in_valid, input_idx - ) - report.success = True - return report - - -def _assign_single_pair_scenario( - report: _APNodePairReport, - pairs: np.ndarray, - pair_is_distal: np.ndarray, - input_in_valid: bool, -) -> _APNodePairReport: - """Assign scenario when exactly one valid pair exists (scenarios 1-4).""" - if input_in_valid: - if pair_is_distal[0]: - report.scenario = 1 - report.outcome = "accept" - else: - report.scenario = 2 - report.outcome = 
"warn" - report.warning_message = "Input pair has proximal node(s)." - elif pair_is_distal[0]: - report.scenario = 3 - report.outcome = "warn" - report.warning_message = ( - f"Input invalid. Suggest pair [{pairs[0, 0]}, {pairs[0, 1]}]." - ) - else: - report.scenario = 4 - report.outcome = "warn" - report.warning_message = ( - f"Input invalid. Only option " - f"[{pairs[0, 0]}, {pairs[0, 1]}] has proximal node(s)." - ) - return report - - -def _assign_multi_input_distal_scenario( - report: _APNodePairReport, - pairs: np.ndarray, - input_idx: int, -) -> _APNodePairReport: - """Assign scenario for distal input in multi-pair case (5, 6, 7).""" - input_pair_sorted = tuple( - sorted([pairs[input_idx, 0], pairs[input_idx, 1]]) - ) - max_distal_sorted = ( - tuple(sorted(report.max_separation_distal_nodes)) - if len(report.max_separation_distal_nodes) > 0 - else () - ) - - if report.input_pair_rank == 1: - report.scenario = 5 - report.outcome = "accept" - elif input_pair_sorted == max_distal_sorted: - report.scenario = 7 - report.outcome = "accept" - else: - report.scenario = 6 - report.outcome = "warn" - d = report.max_separation_distal_nodes - report.warning_message = ( - f"Distal pair with greater separation exists: [{d[0]}, {d[1]}]." - ) - return report - - -def _assign_multi_input_proximal_scenario( - report: _APNodePairReport, - pair_is_distal: np.ndarray, -) -> _APNodePairReport: - """Assign scenario for proximal input in multi-pair case (8-11).""" - has_distal = np.any(pair_is_distal) - is_max_sep = report.input_pair_rank == 1 - - if is_max_sep and has_distal: - report.scenario = 8 - d = report.max_separation_distal_nodes - report.warning_message = ( - f"Input has proximal node(s). " - f"Distal alternative: [{d[0]}, {d[1]}]." - ) - elif is_max_sep: - report.scenario = 9 - report.warning_message = ( - "Input has proximal node(s). All pairs have proximal node(s)." 
- ) - elif has_distal: - report.scenario = 10 - d = report.max_separation_distal_nodes - report.warning_message = ( - f"Input has proximal node(s). " - f"Distal pair with greater separation: [{d[0]}, {d[1]}]." - ) - else: - report.scenario = 11 - report.warning_message = ( - "Input has proximal node(s). All pairs have proximal node(s)." - ) - - report.outcome = "warn" - return report - - -def _assign_multi_input_invalid_scenario( - report: _APNodePairReport, - pair_is_distal: np.ndarray, -) -> _APNodePairReport: - """Assign scenario when input not in valid pairs (12-13).""" - has_distal = np.any(pair_is_distal) - report.outcome = "warn" - - if has_distal: - report.scenario = 12 - d = report.max_separation_distal_nodes - report.warning_message = ( - f"Input invalid. Suggest max separation distal pair: " - f"[{d[0]}, {d[1]}]." - ) - else: - report.scenario = 13 - m = report.max_separation_nodes - report.warning_message = ( - f"Input invalid. All pairs have proximal node(s). " - f"Max separation: [{m[0]}, {m[1]}]." - ) - return report - - -def _assign_scenario( - report: _APNodePairReport, - pairs: np.ndarray, - seps: np.ndarray, - pair_is_distal: np.ndarray, - input_in_valid: bool, - input_idx: int, -) -> _APNodePairReport: - """Assign one of 13 mutually exclusive scenarios. - - Parameters - ---------- - report : _APNodePairReport - The report to update with scenario information. - pairs : np.ndarray - Valid pairs array of shape (n_pairs, 2). - seps : np.ndarray - Internode separations for each pair. - pair_is_distal : np.ndarray - Boolean array indicating distal pairs. - input_in_valid : bool - Whether input pair is among valid pairs. - input_idx : int - Index of input pair in valid pairs (-1 if not present). - - Returns - ------- - _APNodePairReport - Updated report with scenario, outcome, and warning_message. 
- - """ - if len(pairs) == 1: - return _assign_single_pair_scenario( - report, - pairs, - pair_is_distal, - input_in_valid, - ) - - if not input_in_valid: - return _assign_multi_input_invalid_scenario(report, pair_is_distal) - - if report.input_pair_is_distal: - return _assign_multi_input_distal_scenario( - report, - pairs, - input_idx, - ) - - return _assign_multi_input_proximal_scenario(report, pair_is_distal) - - -# ── _validate_ap helper functions ──────────────────────────────────────── - - -def _resolve_node_index(node: Hashable, names: list) -> int: - """Resolve a node identifier to an integer index.""" - if isinstance(node, str): - if node in names: - return names.index(node) - raise ValueError(f"Keypoint '{node}' not found in {names}.") - if isinstance(node, int): - return node - return int(node) # type: ignore[call-overload] - - -def _prepare_validation_inputs( - data: xr.DataArray, - from_node: Hashable, - to_node: Hashable, -) -> tuple[np.ndarray, int, int, str, str, list[str], int]: - """Validate inputs and extract numpy arrays for AP validation. - - Returns - ------- - tuple - (keypoints, from_idx, to_idx, from_name, to_name, - keypoint_names, num_frames) - - Raises - ------ - TypeError - If data is not an xarray.DataArray. - ValueError - If dimensions or indices are invalid. - - """ - _validate_type_data_array(data) - - required_dims = {"time", "space", "keypoints"} - if not required_dims.issubset(set(data.dims)): - raise ValueError( - f"data must have dimensions {required_dims}, " - f"but has {set(data.dims)}." - ) - - if "individuals" in data.dims: - if data.sizes["individuals"] != 1: - raise ValueError( - "data must be for a single individual. " - "Use data.sel(individuals='name') to select one." 
- ) - data = data.squeeze("individuals", drop=True) - - if "keypoints" in data.coords: - keypoint_names = list(data.coords["keypoints"].values) - else: - keypoint_names = [f"node_{i}" for i in range(data.sizes["keypoints"])] - - n_keypoints = data.sizes["keypoints"] - from_idx = _resolve_node_index(from_node, keypoint_names) - to_idx = _resolve_node_index(to_node, keypoint_names) - - if from_idx < 0 or from_idx >= n_keypoints: - raise ValueError( - f"from_node index {from_idx} out of range [0, {n_keypoints - 1}]." - ) - if to_idx < 0 or to_idx >= n_keypoints: - raise ValueError( - f"to_node index {to_idx} out of range [0, {n_keypoints - 1}]." - ) - - data_xy = data.sel(space=["x", "y"]) - keypoints = data_xy.transpose("time", "keypoints", "space").values - - from_name = keypoint_names[from_idx] - to_name = keypoint_names[to_idx] - num_frames = keypoints.shape[0] - - return ( - keypoints, - from_idx, - to_idx, - from_name, - to_name, - keypoint_names, - num_frames, - ) - - -def _run_motion_segmentation( - keypoints: np.ndarray, - num_frames: int, - config: _ValidateAPConfig, - log_info, - log_warning, -) -> dict | None: - """Run tiered validity through segment detection. - - Returns a dict with tier1_valid, tier2_valid, bbox_centroids, - segments, or None on failure (error logged). 
- """ - tier1_valid, tier2_valid, _frac = _compute_tiered_validity( - keypoints, config.min_valid_frac - ) - num_tier1 = int(np.sum(tier1_valid)) - num_tier2 = int(np.sum(tier2_valid)) - - log_info("────────────────────────────────────────────────────────────") - log_info("Tiered Validity Report") - log_info("────────────────────────────────────────────────────────────") - log_info( - "Tier 1 (>= %.0f%% keypoints): %d / %d frames (%.2f%%)", - config.min_valid_frac * 100, - num_tier1, - num_frames, - 100 * num_tier1 / num_frames, - ) - log_info( - "Tier 2 (100%% keypoints): %d / %d frames (%.2f%%)", - num_tier2, - num_frames, - 100 * num_tier2 / num_frames, - ) - - if num_tier1 < 2: - logger.error("Not enough tier-1 valid frames.") - return None - - bbox_centroids, _arith, centroid_disc = _compute_bbox_centroid( - keypoints, tier1_valid - ) - valid_disc = centroid_disc[tier1_valid & ~np.isnan(centroid_disc)] - if len(valid_disc) > 0: - log_info("") - log_info( - "────────────────────────────────────────────────────────────" - ) - log_info("Centroid Discrepancy Diagnostic") - log_info( - "────────────────────────────────────────────────────────────" - ) - log_info("BBox vs arithmetic centroid (normalized by bbox diagonal):") - log_info( - " Median: %.4f | Mean: %.4f | Max: %.4f", - np.median(valid_disc), - np.mean(valid_disc), - np.max(valid_disc), - ) - if np.median(valid_disc) > 0.05: - log_warning( - "Median discrepancy > 5%% - annotation density " - "is likely asymmetric." 
- ) - - segments = _detect_motion_segments( - bbox_centroids, tier1_valid, config, log_info, log_warning - ) - if segments is None: - return None - - return { - "tier1_valid": tier1_valid, - "tier2_valid": tier2_valid, - "bbox_centroids": bbox_centroids, - "segments": segments, - } - - -def _detect_motion_segments( - bbox_centroids: np.ndarray, - tier1_valid: np.ndarray, - config: _ValidateAPConfig, - log_info, - log_warning, -) -> np.ndarray | None: - """Detect high-motion segments from centroid velocities. - - Returns merged segments array, or None on failure. - """ - velocities, speeds = _compute_frame_velocities(bbox_centroids, tier1_valid) - num_speed = len(speeds) - - if num_speed < config.window_len: - logger.error( - "window_len=%d exceeds available speed samples=%d.", - config.window_len, - num_speed, - ) - return None - - window_starts, window_medians, window_all_valid = ( - _compute_sliding_window_medians( - speeds, config.window_len, config.stride - ) - ) - num_valid_windows = int(np.sum(window_all_valid)) - if num_valid_windows == 0: - logger.error("No fully valid sliding windows found.") - return None - - high_motion = _detect_high_motion_windows( - window_medians, window_all_valid, config.pct_thresh - ) - num_high_motion = int(np.sum(high_motion)) - - log_info("") - log_info("────────────────────────────────────────────────────────────") - log_info("High-Motion Window Detection") - log_info("────────────────────────────────────────────────────────────") - log_info( - "Sliding windows (len=%d, stride=%d): " - "%d total, %d fully valid (NaN-free), " - "%d high-motion (median speed >= %dth percentile)", - config.window_len, - config.stride, - len(window_starts), - num_valid_windows, - num_high_motion, - int(config.pct_thresh), - ) - - if num_high_motion == 0: - logger.error("No high-motion windows found.") - return None - - run_starts, run_ends, _run_lengths = _detect_runs( - high_motion, config.min_run_len - ) - if len(run_starts) == 0: - 
logger.error("No runs met min_run_len=%d.", config.min_run_len) - return None - - segments_raw = _convert_runs_to_segments( - run_starts, run_ends, window_starts, config.window_len - ) - segments = _merge_segments(segments_raw) - - log_info("Detected %d merged high-motion segment(s):", len(segments)) - for i, (start, end) in enumerate(segments): - log_info(" Segment %d: frames %d - %d", i + 1, start, end) - - return segments - - -def _select_tier2_frames( - segments: np.ndarray, - tier2_valid: np.ndarray, - num_frames: int, - log_info, - log_warning, -) -> tuple[np.ndarray, np.ndarray, int] | None: - """Filter segment frames to tier-2 valid only. - - Returns (selected_frames, selected_seg_id, num_selected) or None. - """ - selected_frames, selected_seg_id = _filter_segments_tier2( - segments, tier2_valid - ) - - num_tier1_in_segs = sum( - np.sum( - (np.arange(num_frames) >= s[0]) & (np.arange(num_frames) <= s[1]) - ) - for s in segments - ) - num_selected = len(selected_frames) - - log_info("") - log_info("────────────────────────────────────────────────────────────") - log_info("Tier-2 Filtering on High-Motion Segments") - log_info("────────────────────────────────────────────────────────────") - log_info( - "Frames in high-motion segments (any tier): %d", num_tier1_in_segs - ) - log_info( - "Tier-2 valid frames retained (all keypoints present): " - "%d (%.1f%% of segment frames)", - num_selected, - 100 * num_selected / max(num_tier1_in_segs, 1), - ) - - retention = num_selected / max(num_tier1_in_segs, 1) - if retention < 0.3: - log_warning( - "Tier 2 discards > 70%% of segment frames - " - "body model may be unrepresentative." 
- ) - - if num_selected < 2: - logger.error("Not enough tier-2 valid frames in selected segments.") - return None - - return selected_frames, selected_seg_id, num_selected - - -def _run_clustering_and_pca( - centered_skeletons: np.ndarray, - frame_sel: _FrameSelection, - config: _ValidateAPConfig, - log_info, - log_warning, -) -> dict | None: - """Run postural analysis, clustering, and per-cluster PCA. - - Returns dict with primary_result, cluster_results, - num_clusters, primary_cluster, or None on failure. - """ - rmsd_matrix = _compute_pairwise_rmsd(centered_skeletons) - var_ratio, within_rmsds, between_rmsds, var_ratio_override = ( - _compute_postural_variance_ratio(rmsd_matrix, frame_sel.seg_ids) - ) - - rmsd_stats = { - "within": within_rmsds, - "between": between_rmsds, - "var_ratio": var_ratio, - "override": var_ratio_override, - } - _log_postural_consistency( - rmsd_stats, - config, - frame_sel.count, - log_info, - log_warning, - ) - - cluster_labels, num_clusters, primary_cluster = _decide_and_run_clustering( - centered_skeletons, - var_ratio, - frame_sel.count, - config, - log_info, - ) - - cluster_results = [] - for c in range(num_clusters): - cluster_mask = cluster_labels == c - cr = _compute_cluster_pca_and_anterior( - centered_skeletons, - cluster_mask, - frame_sel.frames, - frame_sel.seg_ids, - frame_sel.segments, - frame_sel.bbox_centroids, - ) - cluster_results.append(cr) - - pr = cluster_results[primary_cluster] - if not pr["valid"]: - logger.error("Primary cluster has invalid PCA result.") - return None - - return { - "primary_result": pr, - "cluster_results": cluster_results, - "num_clusters": num_clusters, - "primary_cluster": primary_cluster, - } - - -def _log_postural_consistency( - rmsd_stats, - config, - num_selected, - log_info, - log_warning, -): - """Log postural consistency check results.""" - within_rmsds = rmsd_stats["within"] - between_rmsds = rmsd_stats["between"] - var_ratio = rmsd_stats["var_ratio"] - var_ratio_override = 
rmsd_stats["override"] - - log_info("") - log_info("────────────────────────────────────────────────────────────") - log_info("Postural Consistency Check") - log_info("────────────────────────────────────────────────────────────") - - if len(within_rmsds) > 0: - log_info( - "Within-segment RMSD: mean=%.4f, std=%.4f (n=%d pairs)", - np.mean(within_rmsds), - np.std(within_rmsds), - len(within_rmsds), - ) - else: - log_info("Within-segment RMSD: N/A (no within-segment pairs)") - - if len(between_rmsds) > 0: - log_info( - "Between-segment RMSD: mean=%.4f, std=%.4f (n=%d pairs)", - np.mean(between_rmsds), - np.std(between_rmsds), - len(between_rmsds), - ) - log_info( - "Variance ratio (between/within): %.2f (threshold=%.2f)", - var_ratio, - config.postural_var_ratio_thresh, - ) - if var_ratio_override: - log_info( - " (Conservative override to zero: within-segment variance " - "is zero or no within-segment pairs)" - ) - else: - log_info("Between-segment RMSD: N/A (single segment)") - log_info("Variance ratio: N/A") - - do_clustering = ( - var_ratio > config.postural_var_ratio_thresh and num_selected >= 6 - ) - if do_clustering: - log_info(" -> Variance ratio exceeds threshold. Running clustering.") - elif var_ratio > config.postural_var_ratio_thresh and num_selected < 6: - log_info( - " -> Variance ratio exceeds threshold but too few frames (%d) " - "for clustering.", - num_selected, - ) - else: - log_info(" -> Postural consistency acceptable. 
Using global average.") - - -def _decide_and_run_clustering( - centered_skeletons, - var_ratio, - num_selected, - config, - log_info, -): - """Decide whether to cluster; run k-medoids if triggered.""" - do_clustering = ( - var_ratio > config.postural_var_ratio_thresh and num_selected >= 6 - ) - - if not do_clustering: - return np.zeros(num_selected, dtype=int), 1, 0 - - ( - cluster_labels, - num_clusters, - primary_cluster, - best_silhouette, - silhouette_scores, - ) = _perform_postural_clustering(centered_skeletons, config.max_clusters) - - for k, sil in silhouette_scores: - if np.isnan(sil): - log_info(" k=%d: clustering failed.", k) - else: - log_info(" k=%d: mean silhouette = %.4f", k, sil) - - if num_clusters > 1: - cluster_counts = np.bincount(cluster_labels, minlength=num_clusters) - log_info( - " Selected k=%d clusters (silhouette=%.4f). " - "Primary cluster=%d (%d frames)", - num_clusters, - best_silhouette, - primary_cluster + 1, - cluster_counts[primary_cluster], - ) - else: - log_info( - " Clustering did not improve separation (best_sil=%.4f). 
" - "Using global average.", - best_silhouette, - ) - - return cluster_labels, num_clusters, primary_cluster - - -def _log_anterior_report( - pr, - cluster_results, - num_clusters, - primary_cluster, - config, - log_info, - log_warning, -): - """Log anterior direction detection and cluster agreement.""" - log_info("") - log_info("────────────────────────────────────────────────────────────") - log_info("Anterior Direction Inference (Velocity Voting)") - log_info("────────────────────────────────────────────────────────────") - log_info( - "Centroid velocity projections onto PC1: " - "%d positive (+PC1), %d negative (−PC1)", - pr["num_positive"], - pr["num_negative"], - ) - - vote_margin_str = f"Vote margin M: {pr['vote_margin']:.4f}" - if pr["vote_margin"] < config.confidence_floor: - vote_margin_str += ( - f" ** BELOW CONFIDENCE FLOOR ({config.confidence_floor:.2f}) " - "— anterior assignment is unreliable **" - ) - log_warning( - "Vote margin M = %.4f is below confidence floor %.2f — " - "anterior assignment is unreliable.", - pr["vote_margin"], - config.confidence_floor, - ) - log_info(vote_margin_str) - log_info( - "Resultant length R: %.4f (0 = omnidirectional, 1 = unidirectional)", - pr["resultant_length"], - ) - log_info( - "Inferred anterior direction: %sPC1 " - "(strict majority; ties default to −PC1)", - "+" if pr["anterior_sign"] > 0 else "−", - ) - - if num_clusters <= 1: - return - - log_info("") - log_info("────────────────────────────────────────────────────────────") - log_info("Inter-Cluster Anterior Polarity Agreement") - log_info("────────────────────────────────────────────────────────────") - signs = [cr["anterior_sign"] for cr in cluster_results if cr["valid"]] - if len(set(signs)) == 1: - log_info( - "All %d clusters AGREE on anterior polarity.", - num_clusters, - ) - else: - log_info( - "DISAGREEMENT: clusters assign different anterior polarities." 
- ) - for c, cr in enumerate(cluster_results): - if cr["valid"]: - log_info( - " Cluster %d (%d frames): anterior = %sPC1, " - "vote_margin M = %.4f, resultant_length R = %.4f", - c + 1, - cr["n_frames"], - "+" if cr["anterior_sign"] > 0 else "−", - cr["vote_margin"], - cr["resultant_length"], - ) - log_info( - " Primary result from cluster %d (largest).", - primary_cluster + 1, - ) - - -def _log_step1_report(pair_report, config, valid_nodes, log_info): - """Log Step 1 lateral filter results.""" - num_valid = len(valid_nodes) - num_candidates = len(pair_report.sorted_candidate_nodes) - step1_loss = 1 - num_candidates / max(num_valid, 1) - - pass_strs = [] - fail_strs = [] - for node_i in valid_nodes: - lat_norm = pair_report.lateral_offsets_norm[node_i] - if lat_norm <= config.lateral_thresh: - pass_strs.append(f"{node_i}({lat_norm:.2f})") - else: - fail_strs.append(f"{node_i}({lat_norm:.2f})") - - log_info("") - log_info( - "Step 1 — Lateral Alignment Filter (lateral_thresh=%.2f): " - "%d of %d valid nodes pass [loss=%.0f%%]", - config.lateral_thresh, - num_candidates, - num_valid, - 100 * step1_loss, - ) - log_info( - " Scale: 0.00 = nearest to body axis, 1.00 = farthest from body axis" - ) - if pass_strs: - log_info(" PASS: %s", ", ".join(pass_strs)) - if fail_strs: - log_info(" FAIL: %s", ", ".join(fail_strs)) - - -def _log_step2_report(pair_report, config, log_info): - """Log Step 2 opposite-sides results.""" - num_candidates = len(pair_report.sorted_candidate_nodes) - num_possible_pairs = num_candidates * (num_candidates - 1) // 2 - num_valid_pairs = len(pair_report.valid_pairs) - step2_loss = 1 - num_valid_pairs / max(num_possible_pairs, 1) - m = pair_report.midpoint_pc1 - - plus_strs = [] - minus_strs = [] - for node_i in pair_report.sorted_candidate_nodes: - pc1_rel = pair_report.pc1_coords[node_i] - m - if pc1_rel > 0: - plus_strs.append(f"{node_i}({pc1_rel:+.1f})") - else: - minus_strs.append(f"{node_i}({pc1_rel:+.1f})") - - log_info("") - log_info( - 
"Step 2 — Opposite-Sides Constraint (AP midpoint=%.2f): " - "%d of %d candidate pairs on opposite sides [loss=%.0f%%]", - m, - num_valid_pairs, - num_possible_pairs, - 100 * step2_loss, - ) - if plus_strs: - log_info(" + side (anterior of midpoint): %s", ", ".join(plus_strs)) - if minus_strs: - log_info(" - side (posterior of midpoint): %s", ", ".join(minus_strs)) - - -def _log_step3_report(pair_report, config, log_info): - """Log Step 3 distal/proximal classification results.""" - num_distal = len(pair_report.distal_pairs) - num_proximal = len(pair_report.proximal_pairs) - num_valid_pairs = len(pair_report.valid_pairs) - step3_distal_frac = num_distal / max(num_valid_pairs, 1) - - log_info("") - log_info( - "Step 3 — Distal/Proximal Classification (edge_thresh=%.2f): " - "%d distal, %d proximal [distal fraction=%.0f%%]", - config.edge_thresh, - num_distal, - num_proximal, - 100 * step3_distal_frac, - ) - - for idx in range(num_valid_pairs): - node_i, node_j = pair_report.valid_pairs[idx] - d_i = pair_report.midline_dist_norm[node_i] - d_j = pair_report.midline_dist_norm[node_j] - min_d = min(d_i, d_j) - sep = pair_report.valid_pairs_internode_dist[idx] - status = "DISTAL" if min_d >= config.edge_thresh else "PROXIMAL" - log_info( - " [%d,%d]: min_d=%.2f, sep=%.2f [%s]", - node_i, - node_j, - min_d, - sep, - status, - ) - - -def _log_loss_summary( - step1_loss, step2_loss, step3_frac, step1_failed, step2_failed, log_info -): - """Log cumulative filtering loss summary.""" - log_info("") - log_info("────────────────────────────────────────────────────────────") - log_info("Filtering Loss Summary") - log_info("────────────────────────────────────────────────────────────") - log_info( - "Step 1 (Lateral Filter): %.0f%% of valid nodes eliminated", - 100 * step1_loss, - ) - if not step1_failed: - log_info( - "Step 2 (Opposite-Sides): %.0f%% of candidate pairs eliminated", - 100 * step2_loss, - ) - if not step1_failed and not step2_failed: - log_info( - "Step 3 
(Distal/Proximal): %.0f%% of surviving pairs are distal", - 100 * step3_frac, - ) - - -def _log_order_check( - pair_report, from_idx, to_idx, from_name, to_name, log_info -): - """Log AP ordering check for the input pair.""" - log_info("") - log_info("────────────────────────────────────────────────────────────") - log_info("Order Check: is from_node posterior to to_node?") - log_info("────────────────────────────────────────────────────────────") - ap_from = pair_report.ap_coords[from_idx] - ap_to = pair_report.ap_coords[to_idx] - if np.isnan(ap_from) or np.isnan(ap_to): - log_info("Order check: cannot evaluate (invalid node coordinates)") - return - - log_info( - "AP coords: from_node %s[%d]=%.2f, to_node %s[%d]=%.2f", - from_name, - from_idx, - ap_from, - to_name, - to_idx, - ap_to, - ) - neon_green = "\033[38;5;46m" - crimson = "\033[38;5;9m" - reset = "\033[0m" - if pair_report.input_pair_order_matches_inference: - log_info( - f"{neon_green}[%d, %d]: CONSISTENT — inference agrees that " - f"from_node is posterior (lower AP coord), " - f"to_node is anterior{reset}", - from_idx, - to_idx, - ) - else: - log_info( - f"{crimson}[%d, %d]: INCONSISTENT — inference suggests " - f"from_node is anterior (higher AP coord), " - f"to_node is posterior{reset}", - from_idx, - to_idx, - ) - log_info( - " -> Inferred posterior→anterior order would be [%d, %d]", - to_idx, - from_idx, - ) - - -def _log_input_node_status( - pair_report, - config, - from_idx, - to_idx, - log_info, -): - """Log whether each input node passed the lateral filter.""" - lat_from = pair_report.lateral_offsets_norm[from_idx] - lat_to = pair_report.lateral_offsets_norm[to_idx] - from_pass = not np.isnan(lat_from) and lat_from <= config.lateral_thresh - to_pass = not np.isnan(lat_to) and lat_to <= config.lateral_thresh - - if from_pass and to_pass: - return - - fail_nodes = [] - if not from_pass: - fail_nodes.append(f"{from_idx}({lat_from:.2f})") - if not to_pass: - 
fail_nodes.append(f"{to_idx}({lat_to:.2f})") - log_info( - " -> Input node(s) FAILED lateral filter: %s", - ", ".join(fail_nodes), - ) - - -def _log_step2_step3_details( - pair_report, - config, - from_idx, - to_idx, - num_candidates, - log_info, -): - """Log Step 2 and Step 3 results when Step 1 succeeded. - - Returns (step2_loss, step3_frac, step2_failed). - """ - _log_step2_report(pair_report, config, log_info) - - if ( - pair_report.input_pair_in_candidates - and not pair_report.input_pair_opposite_sides - ): - log_info(" -> Input nodes on SAME side of AP midpoint") - - num_possible = num_candidates * (num_candidates - 1) // 2 - step2_loss = 1 - len(pair_report.valid_pairs) / max(num_possible, 1) - step2_failed = pair_report.failure_step.startswith("Step 2") - - if step2_failed: - log_info("") - log_info("Step 3: not evaluated (Step 2 failed)") - return step2_loss, 0.0, True - - step3_frac = _log_step3_with_proximal_check( - pair_report, - config, - from_idx, - to_idx, - log_info, - ) - return step2_loss, step3_frac, False - - -def _log_step3_with_proximal_check( - pair_report, - config, - from_idx, - to_idx, - log_info, -): - """Log Step 3 results and check input pair proximal status. - - Returns step3_frac. 
- """ - _log_step3_report(pair_report, config, log_info) - num_distal = len(pair_report.distal_pairs) - num_valid_pairs = len(pair_report.valid_pairs) - step3_frac = num_distal / max(num_valid_pairs, 1) - - is_candidate = pair_report.input_pair_in_candidates - is_opposite = pair_report.input_pair_opposite_sides - is_proximal = not pair_report.input_pair_is_distal - if is_candidate and is_opposite and is_proximal: - d_from = pair_report.midline_dist_norm[from_idx] - d_to = pair_report.midline_dist_norm[to_idx] - log_info( - " -> Input pair is PROXIMAL (min_d=%.2f < %.2f)", - min(d_from, d_to), - config.edge_thresh, - ) - - return step3_frac - - -def _log_pair_evaluation( - pair_report, - config, - from_idx, - to_idx, - from_name, - to_name, - log_info, -): - """Log the complete AP node pair evaluation report.""" - log_info("") - log_info("────────────────────────────────────────────────────────────") - log_info("AP Node-Pair Filter Cascade (3-Step Evaluation)") - log_info("────────────────────────────────────────────────────────────") - log_info( - "Input pair: [%d, %d] (%s → %s, claimed posterior → anterior)", - from_idx, - to_idx, - from_name, - to_name, - ) - - step1_failed = pair_report.failure_step.startswith("Step 1") - - valid_nodes = np.where(~np.isnan(pair_report.lateral_offsets_norm))[0] - num_candidates = len(pair_report.sorted_candidate_nodes) - step1_loss = 1 - num_candidates / max(len(valid_nodes), 1) - - _log_step1_report(pair_report, config, valid_nodes, log_info) - _log_input_node_status(pair_report, config, from_idx, to_idx, log_info) - - step2_loss = 0.0 - step3_frac = 0.0 - step2_failed = False - - if step1_failed: - log_info("") - log_info("Step 2-3: not evaluated (Step 1 failed)") - else: - step2_loss, step3_frac, step2_failed = _log_step2_step3_details( - pair_report, - config, - from_idx, - to_idx, - num_candidates, - log_info, - ) - - _log_loss_summary( - step1_loss, - step2_loss, - step3_frac, - step1_failed, - step2_failed, - log_info, - ) 
- _log_order_check( - pair_report, - from_idx, - to_idx, - from_name, - to_name, - log_info, - ) - - -# ── Main validation function ──────────────────────────────────────────── - - -def _validate_ap( - data: xr.DataArray, - from_node: Hashable, - to_node: Hashable, - config: _ValidateAPConfig | None = None, - verbose: bool = True, -) -> dict: - """Validate an anterior-posterior keypoint pair using body-axis inference. - - This function implements a prior-free body-axis inference pipeline that: - 1. Identifies high-motion segments using tiered validity and sliding - windows - 2. Optionally performs postural clustering via k-medoids - 3. Infers the anterior direction using velocity projection voting - 4. Evaluates the candidate AP keypoint pair through a 3-step filter - cascade - - Parameters - ---------- - data : xarray.DataArray - Position data for a single individual. - from_node : int or str - Index or name of the posterior keypoint. - to_node : int or str - Index or name of the anterior keypoint. - config : _ValidateAPConfig, optional - Configuration parameters. If None, uses defaults. - verbose : bool, default=True - If True, log detailed validation output to console. - - Returns - ------- - dict - Validation results including success, anterior_sign, - vote_margin, resultant_length, pair_report, etc. 
- - """ - if config is None: - config = _ValidateAPConfig() - - log_lines: list[str] = [] - - def _log_info(msg, *args): - """Log an informational message.""" - line = msg % args if args else msg - log_lines.append(line) - if verbose: - print(line) - - def _log_warning(msg, *args): - """Log a warning message with ANSI coloring.""" - orange = "\033[38;5;214m" - reset = "\033[0m" - line = f"{orange}WARNING: {msg % args if args else msg}{reset}" - log_lines.append(line) - if verbose: - print(line) - - # Prepare inputs - ( - keypoints, - from_idx, - to_idx, - from_name, - to_name, - _keypoint_names, - num_frames, - ) = _prepare_validation_inputs(data, from_node, to_node) - - n_keypoints = keypoints.shape[1] - result: dict = { - "success": False, - "anterior_sign": 0, - "vote_margin": 0.0, - "resultant_length": 0.0, - "num_selected_frames": 0, - "num_clusters": 1, - "primary_cluster": 0, - "pair_report": _APNodePairReport(), - "PC1": np.array([1.0, 0.0]), - "PC2": np.array([0.0, 1.0]), - "avg_skeleton": np.full((n_keypoints, 2), np.nan), - "error_msg": "", - "log_lines": log_lines, - } - - # Motion segmentation - seg = _run_motion_segmentation( - keypoints, - num_frames, - config, - _log_info, - _log_warning, - ) - if seg is None: - result["error_msg"] = "Motion segmentation failed." - return result - - # Tier-2 frame selection - t2 = _select_tier2_frames( - seg["segments"], - seg["tier2_valid"], - num_frames, - _log_info, - _log_warning, - ) - if t2 is None: - result["error_msg"] = "Not enough tier-2 valid frames." 
- return result - selected_frames, selected_seg_id, num_selected = t2 - result["num_selected_frames"] = num_selected - - # Build centered skeletons - _selected_centroids, centered_skeletons = _build_centered_skeletons( - keypoints, selected_frames - ) - - # Bundle frame selection data - frame_sel = _FrameSelection( - frames=selected_frames, - seg_ids=selected_seg_id, - segments=seg["segments"], - bbox_centroids=seg["bbox_centroids"], - count=num_selected, - ) - - # Postural clustering + PCA + anterior inference - pca = _run_clustering_and_pca( - centered_skeletons, - frame_sel, - config, - _log_info, - _log_warning, - ) - if pca is None: - result["error_msg"] = "Primary cluster PCA failed." - return result - - pr = pca["primary_result"] - result["anterior_sign"] = pr["anterior_sign"] - result["vote_margin"] = pr["vote_margin"] - result["resultant_length"] = pr["resultant_length"] - result["circ_mean_dir"] = pr["circ_mean_dir"] - result["vel_projs_pc1"] = pr["vel_projs_pc1"] - result["PC1"] = pr["PC1"] - result["PC2"] = pr["PC2"] - result["avg_skeleton"] = pr["avg_skeleton"] - result["num_clusters"] = pca["num_clusters"] - result["primary_cluster"] = pca["primary_cluster"] - - # Log anterior inference - _log_anterior_report( - pr, - pca["cluster_results"], - pca["num_clusters"], - pca["primary_cluster"], - config, - _log_info, - _log_warning, - ) - - # AP node-pair evaluation - pair_report = _evaluate_ap_node_pair( - pr["avg_skeleton"], - pr["PC1"], - pr["anterior_sign"], - pr["valid_shape_rows"], - from_idx, - to_idx, - config, - ) - result["pair_report"] = pair_report - - _log_pair_evaluation( - pair_report, - config, - from_idx, - to_idx, - from_name, - to_name, - _log_info, - ) - - result["success"] = True - return result diff --git a/tests/test_unit/test_kinematics/test_body_axis.py b/tests/test_unit/test_kinematics/test_body_axis.py new file mode 100644 index 000000000..9b316b6c1 --- /dev/null +++ b/tests/test_unit/test_kinematics/test_body_axis.py @@ -0,0 
+1,62 @@ +# test_body_axis.py +"""Tests for the body axis validation module.""" + +from typing import Any + +import pytest + +from movement.kinematics.body_axis import ValidateAPConfig + + +class TestValidateAPConfig: + """Tests for the ValidateAPConfig dataclass parameter validation.""" + + @pytest.mark.parametrize( + ("field", "value"), + [ + ("min_valid_frac", -0.1), + ("min_valid_frac", 1.1), + ("window_len", 0), + ("window_len", -5), + ("window_len", 2.5), + ("stride", 0), + ("stride", -1), + ("stride", 1.5), + ("pct_thresh", -1), + ("pct_thresh", 101), + ("min_run_len", 0), + ("min_run_len", -1), + ("min_run_len", 1.5), + ("postural_var_ratio_thresh", 0), + ("postural_var_ratio_thresh", -1), + ("max_clusters", 0), + ("max_clusters", 2.5), + ("confidence_floor", -0.1), + ("confidence_floor", 1.1), + ("lateral_thresh", -0.1), + ("lateral_thresh", 1.1), + ("edge_thresh", -0.1), + ("edge_thresh", 1.1), + ], + ) + def test_invalid_config_values_raise(self, field: str, value: Any) -> None: + """Invalid config values should raise ValueError.""" + kwargs = {field: value} + with pytest.raises(ValueError, match="must be"): + ValidateAPConfig(**kwargs) + + def test_valid_config_does_not_raise(self) -> None: + """Valid config values should not raise any error.""" + # Should not raise + ValidateAPConfig( + min_valid_frac=0.5, + window_len=10, + stride=2, + pct_thresh=50.0, + min_run_len=2, + postural_var_ratio_thresh=1.5, + max_clusters=3, + confidence_floor=0.2, + lateral_thresh=0.3, + edge_thresh=0.2, + ) diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py index a102e1b5d..809b40bdf 100644 --- a/tests/test_unit/test_kinematics/test_collective.py +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -1,8 +1,6 @@ # test_collective.py """Tests for the collective behavior metrics module.""" -from typing import Any - import numpy as np import pytest import xarray as xr @@ -339,64 +337,6 @@ def 
test_empty_keypoints_dimension_raises_in_displacement_mode(self): kinematics.compute_polarization(data) -class TestValidateAPConfig: - """Tests for the _ValidateAPConfig dataclass parameter validation.""" - - @pytest.mark.parametrize( - ("field", "value"), - [ - ("min_valid_frac", -0.1), - ("min_valid_frac", 1.1), - ("window_len", 0), - ("window_len", -5), - ("window_len", 2.5), - ("stride", 0), - ("stride", -1), - ("stride", 1.5), - ("pct_thresh", -1), - ("pct_thresh", 101), - ("min_run_len", 0), - ("min_run_len", -1), - ("min_run_len", 1.5), - ("postural_var_ratio_thresh", 0), - ("postural_var_ratio_thresh", -1), - ("max_clusters", 0), - ("max_clusters", 2.5), - ("confidence_floor", -0.1), - ("confidence_floor", 1.1), - ("lateral_thresh", -0.1), - ("lateral_thresh", 1.1), - ("edge_thresh", -0.1), - ("edge_thresh", 1.1), - ], - ) - def test_invalid_config_values_raise(self, field: str, value: Any) -> None: - """Invalid config values should raise ValueError.""" - from movement.kinematics.collective import _ValidateAPConfig - - kwargs = {field: value} - with pytest.raises(ValueError, match="must be"): - _ValidateAPConfig(**kwargs) - - def test_valid_config_does_not_raise(self) -> None: - """Valid config values should not raise any error.""" - from movement.kinematics.collective import _ValidateAPConfig - - # Should not raise - _ValidateAPConfig( - min_valid_frac=0.5, - window_len=10, - stride=2, - pct_thresh=50.0, - min_run_len=2, - postural_var_ratio_thresh=1.5, - max_clusters=3, - confidence_floor=0.2, - lateral_thresh=0.3, - edge_thresh=0.2, - ) - - class TestComputePolarizationBehavior: """Tests for polarization computation behavior.""" From b1df3b96f09f934022c452ce208f2bfaea37036e Mon Sep 17 00:00:00 2001 From: khan-u Date: Sun, 5 Apr 2026 02:55:08 -0700 Subject: [PATCH 19/21] update(body_axis) new config params optimized via grid search --- movement/kinematics/body_axis.py | 332 ++++++++++++++---- .../test_kinematics/test_body_axis.py | 16 +- 2 files changed, 
282 insertions(+), 66 deletions(-) diff --git a/movement/kinematics/body_axis.py b/movement/kinematics/body_axis.py index 054af965e..2b51ba4a2 100644 --- a/movement/kinematics/body_axis.py +++ b/movement/kinematics/body_axis.py @@ -56,12 +56,19 @@ class ValidateAPConfig: confidence_floor : float, default=0.1 Vote margin below which the anterior inference is flagged as unreliable. - lateral_thresh : float, default=0.4 - Normalized lateral offset ceiling for the Step 1 lateral alignment - filter. - edge_thresh : float, default=0.3 - Normalized midpoint distance floor for the Step 3 distal/proximal - classification. + lateral_thresh_pct : float, default=50.0 + Percentile threshold for Step 1 lateral alignment filter. Keypoints + with effective lateral score above this percentile are eliminated. + edge_thresh_pct : float, default=70.0 + Percentile threshold for Step 3 distal/proximal classification. + Pairs where both nodes have normalized midpoint distance above this + percentile are classified as "distal". + lateral_var_weight : float, default=1.0 + Weight for lateral (PC2) position variance penalty in Step 1 filter. + Higher values penalize keypoints that swing side-to-side. + longitudinal_var_weight : float, default=0.5 + Weight for longitudinal (PC1) position variance penalty in Step 1 + filter. Higher values penalize keypoints that move along the AP axis. 
""" @@ -73,16 +80,16 @@ class ValidateAPConfig: postural_var_ratio_thresh: float = 2.0 max_clusters: int = 4 confidence_floor: float = 0.1 - lateral_thresh: float = 0.4 - edge_thresh: float = 0.3 + lateral_thresh_pct: float = 50.0 + edge_thresh_pct: float = 70.0 + lateral_var_weight: float = 1.0 + longitudinal_var_weight: float = 0.5 def __post_init__(self) -> None: """Validate configuration parameters.""" for name in ( "min_valid_frac", "confidence_floor", - "lateral_thresh", - "edge_thresh", ): value = getattr(self, name) if not (0 <= value <= 1): @@ -97,10 +104,12 @@ def __post_init__(self) -> None: f"{name} must be a positive integer, got {value}" ) - if not (0 <= self.pct_thresh <= 100): - raise ValueError( - f"pct_thresh must be between 0 and 100, got {self.pct_thresh}" - ) + for name in ("pct_thresh", "lateral_thresh_pct", "edge_thresh_pct"): + value = getattr(self, name) + if not (0 <= value <= 100): + raise ValueError( + f"{name} must be between 0 and 100, got {value}" + ) if self.postural_var_ratio_thresh <= 0: raise ValueError( @@ -108,6 +117,11 @@ def __post_init__(self) -> None: f"got {self.postural_var_ratio_thresh}" ) + for name in ("lateral_var_weight", "longitudinal_var_weight"): + value = getattr(self, name) + if value < 0: + raise ValueError(f"{name} must be non-negative, got {value}") + @dataclass class FrameSelection: @@ -194,6 +208,16 @@ class APNodePairReport: Minimum lateral offset among valid keypoints. lateral_offset_max : float Maximum lateral offset among valid keypoints. + lateral_std : np.ndarray + Per-keypoint standard deviation of lateral (PC2) position across + selected frames. Higher values indicate more swing/instability. + lateral_std_norm : np.ndarray + Normalized lateral std (0 = most stable, 1 = most variable). + longitudinal_std : np.ndarray + Per-keypoint standard deviation of longitudinal (PC1) position + across selected frames. Higher values indicate more AP movement. 
+ longitudinal_std_norm : np.ndarray + Normalized longitudinal std (0 = most stable, 1 = most variable). midpoint_pc1 : float AP reference midpoint (average of min and max PC1 projections). pc1_min : float @@ -205,9 +229,11 @@ class APNodePairReport: midline_dist_max : float Maximum absolute distance from midpoint. distal_pairs : np.ndarray - Array of distal pairs (both nodes at or above edge_thresh). + Array of distal pairs (both nodes at or above + edge_thresh_pct percentile). proximal_pairs : np.ndarray - Array of proximal pairs (at least one node below edge_thresh). + Array of proximal pairs (at least one node below + edge_thresh_pct percentile). max_separation_distal_nodes : np.ndarray Node indices of the maximum-separation distal pair, ordered so that element 0 is posterior (lower AP coord) and element 1 @@ -255,6 +281,12 @@ class APNodePairReport: ) lateral_offset_min: float = np.nan lateral_offset_max: float = np.nan + lateral_std: np.ndarray = field(default_factory=lambda: np.array([])) + lateral_std_norm: np.ndarray = field(default_factory=lambda: np.array([])) + longitudinal_std: np.ndarray = field(default_factory=lambda: np.array([])) + longitudinal_std_norm: np.ndarray = field( + default_factory=lambda: np.array([]) + ) midpoint_pc1: float = np.nan pc1_min: float = np.nan pc1_max: float = np.nan @@ -801,10 +833,6 @@ def compute_postural_variance_ratio( return var_ratio, within_rmsds, between_rmsds, var_ratio_override -# Clustering (k-medoids) -# ─────────────────────── - - def _update_medoid_for_cluster( cluster: int, labels: np.ndarray, @@ -823,6 +851,10 @@ def _update_medoid_for_cluster( return cluster_indices[best_idx] +# K-Medoids Clustering +# ───────────────────── + + def kmedoids( data: np.ndarray, k: int, @@ -1218,6 +1250,21 @@ def compute_cluster_pca_and_anterior( result["proj_pc1"] = proj_pc1 result["proj_pc2"] = proj_pc2 + # Compute per-keypoint position std across frames (for variance penalty) + lateral_per_frame = skels_c @ PC2 # 
(n_frames, n_keypoints) + longitudinal_per_frame = skels_c @ PC1 # (n_frames, n_keypoints) + + lateral_std = np.full(n_keypoints, np.nan) + longitudinal_std = np.full(n_keypoints, np.nan) + lateral_std[valid_shape_rows] = np.nanstd( + lateral_per_frame[:, valid_shape_rows], axis=0 + ) + longitudinal_std[valid_shape_rows] = np.nanstd( + longitudinal_per_frame[:, valid_shape_rows], axis=0 + ) + result["lateral_std"] = lateral_std + result["longitudinal_std"] = longitudinal_std + velocities = compute_cluster_velocities( selected_frames, selected_seg_id, @@ -1275,15 +1322,22 @@ def compute_node_projections( proj_pc1_valid = report.pc1_coords[valid_shape_rows] report.pc1_min = float(np.min(proj_pc1_valid)) report.pc1_max = float(np.max(proj_pc1_valid)) - report.midpoint_pc1 = (report.pc1_min + report.pc1_max) / 2 + # Use centroid (mean) instead of geometric center for better robustness + report.midpoint_pc1 = float(np.mean(proj_pc1_valid)) def apply_lateral_filter( report: APNodePairReport, valid_idx: np.ndarray, - lateral_thresh: float, + lateral_std: np.ndarray, + longitudinal_std: np.ndarray, + config: ValidateAPConfig, ) -> np.ndarray | None: - """Step 1: Filter keypoints by normalized lateral offset. + """Step 1: Filter keypoints by normalized lateral offset + variance. + + Filters keypoints based on their mean lateral offset from the body axis, + optionally penalized by lateral and longitudinal position variance. + Higher variance indicates less stable keypoints (e.g., swinging tail tip). Returns sorted candidate node indices, or None on failure. 
@@ -1294,17 +1348,54 @@ def apply_lateral_filter( report.lateral_offset_min = d_min report.lateral_offset_max = d_max + # Normalize using min-max scaling for better discrimination + # d_norm=0 means minimum lateral offset, d_norm=1 means maximum if d_max > d_min: d_norm = (d_valid - d_min) / (d_max - d_min) - report.lateral_offsets_norm[valid_idx] = d_norm - keep_mask = d_norm <= lateral_thresh else: - report.lateral_offsets_norm[valid_idx] = np.zeros(len(d_valid)) - keep_mask = np.ones(len(d_valid), dtype=bool) + d_norm = np.zeros(len(d_valid)) + report.lateral_offsets_norm[valid_idx] = d_norm + + # Normalize lateral std to [0, 1] + lat_std_valid = lateral_std[valid_idx] + lat_std_max = float(np.nanmax(lat_std_valid)) + if lat_std_max > 0: + lat_std_norm = lat_std_valid / lat_std_max + else: + lat_std_norm = np.zeros(len(lat_std_valid)) + report.lateral_std[valid_idx] = lat_std_valid + report.lateral_std_norm[valid_idx] = lat_std_norm + + # Normalize longitudinal std to [0, 1] + long_std_valid = longitudinal_std[valid_idx] + long_std_max = float(np.nanmax(long_std_valid)) + if long_std_max > 0: + long_std_norm = long_std_valid / long_std_max + else: + long_std_norm = np.zeros(len(long_std_valid)) + report.longitudinal_std[valid_idx] = long_std_valid + report.longitudinal_std_norm[valid_idx] = long_std_norm + + # Combined effective lateral score: + # mean_offset + lateral_var_weight * lateral_std + # + long_var_weight * long_std + effective_lateral = ( + d_norm + + config.lateral_var_weight * lat_std_norm + + config.longitudinal_var_weight * long_std_norm + ) + + # Use percentile threshold for robust filtering + # This adapts to the distribution of scores in each dataset + percentile_thresh = float( + np.percentile(effective_lateral, config.lateral_thresh_pct) + ) + keep_mask = effective_lateral <= percentile_thresh candidate_idx = np.nonzero(keep_mask)[0] candidates = valid_idx[candidate_idx] - sorted_order = np.argsort(d_valid[candidate_idx]) + # Sort by 
effective lateral score (lowest = closest to axis + most stable) + sorted_order = np.argsort(effective_lateral[candidate_idx]) candidates = candidates[sorted_order] report.sorted_candidate_nodes = candidates.copy() @@ -1395,7 +1486,7 @@ def classify_distal_proximal( pairs: np.ndarray, seps: np.ndarray, valid_shape_rows: np.ndarray, - edge_thresh: float, + edge_thresh_pct: float, ) -> np.ndarray: """Step 3: Classify pairs as distal or proximal. Returns pair_is_distal.""" m = report.midpoint_pc1 @@ -1408,32 +1499,58 @@ def classify_distal_proximal( else: report.midline_dist_norm = np.zeros(len(report.pc1_coords)) + # Collect midline distances for candidate nodes in pairs + candidate_nodes = np.unique(pairs.flatten()) + candidate_dists = report.midline_dist_norm[candidate_nodes] + candidate_dists = candidate_dists[~np.isnan(candidate_dists)] + + # Use percentile threshold for robust distal/proximal classification + if len(candidate_dists) > 0: + percentile_thresh = float( + np.percentile(candidate_dists, edge_thresh_pct) + ) + else: + percentile_thresh = 0.5 # Fallback when no candidates + pair_is_distal = np.zeros(len(pairs), dtype=bool) for k in range(len(pairs)): i, j = pairs[k] pair_is_distal[k] = ( min(report.midline_dist_norm[i], report.midline_dist_norm[j]) - >= edge_thresh + >= percentile_thresh ) report.distal_pairs = pairs[pair_is_distal] report.proximal_pairs = pairs[~pair_is_distal] + # Compute weighted separations that penalize + # high-variance (unstable) nodes. + # This favors stable body-core keypoints over swinging extremities. 
+ # weighted_sep = sep * (1 - avg_variance_of_pair) + lateral_std_norm = report.lateral_std_norm + weighted_seps = np.zeros(len(seps)) + for k in range(len(pairs)): + i, j = pairs[k] + std_i = lateral_std_norm[i] if not np.isnan(lateral_std_norm[i]) else 0 + std_j = lateral_std_norm[j] if not np.isnan(lateral_std_norm[j]) else 0 + avg_std = (std_i + std_j) / 2 + weighted_seps[k] = seps[k] * (1 - avg_std) + if len(seps) > 0: - idx_max = int(np.argmax(seps)) + idx_max = int(np.argmax(weighted_seps)) report.max_separation_nodes = order_pair_by_ap( pairs[idx_max], report.ap_coords ) report.max_separation = seps[idx_max] if np.any(pair_is_distal): - distal_seps = seps[pair_is_distal] + distal_weighted_seps = weighted_seps[pair_is_distal] distal_pairs_only = pairs[pair_is_distal] - idx_max_distal = int(np.argmax(distal_seps)) + idx_max_distal = int(np.argmax(distal_weighted_seps)) report.max_separation_distal_nodes = order_pair_by_ap( distal_pairs_only[idx_max_distal], report.ap_coords ) - report.max_separation_distal = distal_seps[idx_max_distal] + report.max_separation_distal = seps[pair_is_distal][idx_max_distal] return pair_is_distal @@ -1651,6 +1768,8 @@ def evaluate_ap_node_pair( pc1_vec: np.ndarray, anterior_sign: int, valid_shape_rows: np.ndarray, + lateral_std: np.ndarray, + longitudinal_std: np.ndarray, from_node: int, to_node: int, config: ValidateAPConfig, @@ -1667,6 +1786,13 @@ def evaluate_ap_node_pair( Inferred anterior direction (+1 or -1 relative to PC1). valid_shape_rows : np.ndarray Boolean array indicating valid (non-NaN) keypoints. + lateral_std : np.ndarray + Per-keypoint standard deviation of lateral (PC2) position across + selected frames. Used to penalize high-swing keypoints. + longitudinal_std : np.ndarray + Per-keypoint standard deviation of longitudinal (PC1) position + across selected frames. Used to penalize keypoints with high + AP movement variance. 
from_node : int Index of the input from_node (body_axis_keypoints origin, claimed posterior). 0-indexed. @@ -1674,7 +1800,8 @@ def evaluate_ap_node_pair( Index of the input to_node (body_axis_keypoints target, claimed anterior). 0-indexed. config : ValidateAPConfig - Configuration with ``lateral_thresh`` and ``edge_thresh``. + Configuration with ``lateral_thresh_pct``, ``edge_thresh_pct``, and + variance weight parameters. Returns ------- @@ -1688,6 +1815,10 @@ def evaluate_ap_node_pair( report.ap_coords = np.full(n_keypoints, np.nan) report.lateral_offsets = np.full(n_keypoints, np.nan) report.lateral_offsets_norm = np.full(n_keypoints, np.nan) + report.lateral_std = np.full(n_keypoints, np.nan) + report.lateral_std_norm = np.full(n_keypoints, np.nan) + report.longitudinal_std = np.full(n_keypoints, np.nan) + report.longitudinal_std_norm = np.full(n_keypoints, np.nan) report.midline_dist_norm = np.full(n_keypoints, np.nan) for node, label in [(from_node, "from_node"), (to_node, "to_node")]: @@ -1714,7 +1845,9 @@ def evaluate_ap_node_pair( to_node, ) - candidates = apply_lateral_filter(report, valid_idx, config.lateral_thresh) + candidates = apply_lateral_filter( + report, valid_idx, lateral_std, longitudinal_std, config + ) if candidates is None: return report @@ -1734,7 +1867,7 @@ def evaluate_ap_node_pair( pairs, seps, valid_shape_rows, - config.edge_thresh, + config.edge_thresh_pct, ) input_in_valid, input_idx = check_input_pair_in_valid( @@ -2058,7 +2191,6 @@ def run_clustering_and_pca( frame_sel: FrameSelection, config: ValidateAPConfig, log_info, - log_warning, ) -> dict | None: """Run postural analysis, clustering, and per-cluster PCA. 
@@ -2380,31 +2512,80 @@ def log_step1_report(pair_report, config, valid_nodes, log_info): num_candidates = len(pair_report.sorted_candidate_nodes) step1_loss = 1 - num_candidates / max(num_valid, 1) + # Compute effective lateral score for each node + effective_scores = [] + node_details = [] + for node_i in valid_nodes: + d_norm = pair_report.lateral_offsets_norm[node_i] + lat_std_norm = pair_report.lateral_std_norm[node_i] + long_std_norm = pair_report.longitudinal_std_norm[node_i] + + # Handle NaN values + lat_std_norm = 0.0 if np.isnan(lat_std_norm) else lat_std_norm + long_std_norm = 0.0 if np.isnan(long_std_norm) else long_std_norm + + effective = ( + d_norm + + config.lateral_var_weight * lat_std_norm + + config.longitudinal_var_weight * long_std_norm + ) + effective_scores.append(effective) + node_details.append((node_i, effective)) + + # Compute percentile threshold from distribution + if len(effective_scores) > 0: + percentile_thresh = float( + np.percentile(effective_scores, config.lateral_thresh_pct) + ) + else: + percentile_thresh = 0.0 + pass_strs = [] fail_strs = [] - for node_i in valid_nodes: - lat_norm = pair_report.lateral_offsets_norm[node_i] - if lat_norm <= config.lateral_thresh: - pass_strs.append(f"{node_i}({lat_norm:.2f})") + for node_i, effective in node_details: + detail = f"{node_i}(eff={effective:.2f})" + if effective <= percentile_thresh: + pass_strs.append(detail) else: - fail_strs.append(f"{node_i}({lat_norm:.2f})") + fail_strs.append(detail) log_info("") log_info( - "Step 1 - Lateral Alignment Filter (lateral_thresh=%.2f): " - "%d of %d valid nodes pass [loss=%.0f%%]", - config.lateral_thresh, + "Step 1 - Lateral Alignment Filter: %d of %d valid nodes pass " + "[loss=%.0f%%]", num_candidates, num_valid, 100 * step1_loss, ) log_info( - " Scale: 0.00 = nearest to body axis, 1.00 = farthest from body axis" + " Config: lateral_thresh_pct=%.0f, lateral_var_weight=%.2f, " + "longitudinal_var_weight=%.2f", + config.lateral_thresh_pct, + 
config.lateral_var_weight, + config.longitudinal_var_weight, + ) + log_info( + " Percentile threshold: %.0fth pct = %.2f", + config.lateral_thresh_pct, + percentile_thresh, + ) + log_info( + " Score = d_norm + %.2f×lat_std_norm + %.2f×long_std_norm", + config.lateral_var_weight, + config.longitudinal_var_weight, ) if pass_strs: - log_info(" PASS: %s", ", ".join(pass_strs)) + log_info( + " PASS (score <= %.2f): %s", + percentile_thresh, + ", ".join(pass_strs), + ) if fail_strs: - log_info(" FAIL: %s", ", ".join(fail_strs)) + log_info( + " FAIL (score > %.2f): %s", + percentile_thresh, + ", ".join(fail_strs), + ) def log_step2_report(pair_report, _config, log_info): @@ -2446,15 +2627,31 @@ def log_step3_report(pair_report, config, log_info): num_valid_pairs = len(pair_report.valid_pairs) step3_distal_frac = num_distal / max(num_valid_pairs, 1) + # Compute percentile threshold from candidate distances + candidate_nodes = np.unique(pair_report.valid_pairs.flatten()) + candidate_dists = pair_report.midline_dist_norm[candidate_nodes] + candidate_dists = candidate_dists[~np.isnan(candidate_dists)] + if len(candidate_dists) > 0: + percentile_thresh = float( + np.percentile(candidate_dists, config.edge_thresh_pct) + ) + else: + percentile_thresh = 0.5 + log_info("") log_info( - "Step 3 - Distal/Proximal Classification (edge_thresh=%.2f): " + "Step 3 - Distal/Proximal Classification (edge_thresh_pct=%.0f): " "%d distal, %d proximal [distal fraction=%.0f%%]", - config.edge_thresh, + config.edge_thresh_pct, num_distal, num_proximal, 100 * step3_distal_frac, ) + log_info( + " Percentile threshold: %.0fth pct = %.2f", + config.edge_thresh_pct, + percentile_thresh, + ) for idx in range(num_valid_pairs): node_i, node_j = pair_report.valid_pairs[idx] @@ -2462,7 +2659,7 @@ def log_step3_report(pair_report, config, log_info): d_j = pair_report.midline_dist_norm[node_j] min_d = min(d_i, d_j) sep = pair_report.valid_pairs_internode_dist[idx] - status = "DISTAL" if min_d >= 
config.edge_thresh else "PROXIMAL" + status = "DISTAL" if min_d >= percentile_thresh else "PROXIMAL" log_info( " [%d,%d]: min_d=%.2f, sep=%.2f [%s]", node_i, @@ -2551,19 +2748,21 @@ def log_input_node_status( log_info, ): """Log whether each input node passed the lateral filter.""" - lat_from = pair_report.lateral_offsets_norm[from_idx] - lat_to = pair_report.lateral_offsets_norm[to_idx] - from_pass = not np.isnan(lat_from) and lat_from <= config.lateral_thresh - to_pass = not np.isnan(lat_to) and lat_to <= config.lateral_thresh + # Check if input nodes are in the candidates that passed Step 1 + candidates = pair_report.sorted_candidate_nodes + from_pass = from_idx in candidates + to_pass = to_idx in candidates if from_pass and to_pass: return fail_nodes = [] if not from_pass: - fail_nodes.append(f"{from_idx}({lat_from:.2f})") + lat_from = pair_report.lateral_offsets_norm[from_idx] + fail_nodes.append(f"{from_idx}(lat={lat_from:.2f})") if not to_pass: - fail_nodes.append(f"{to_idx}({lat_to:.2f})") + lat_to = pair_report.lateral_offsets_norm[to_idx] + fail_nodes.append(f"{to_idx}(lat={lat_to:.2f})") log_info( " -> Input node(s) FAILED lateral filter: %s", ", ".join(fail_nodes), @@ -2633,10 +2832,20 @@ def log_step3_with_proximal_check( if is_candidate and is_opposite and is_proximal: d_from = pair_report.midline_dist_norm[from_idx] d_to = pair_report.midline_dist_norm[to_idx] + # Compute percentile threshold for context + candidate_nodes = np.unique(pair_report.valid_pairs.flatten()) + candidate_dists = pair_report.midline_dist_norm[candidate_nodes] + candidate_dists = candidate_dists[~np.isnan(candidate_dists)] + if len(candidate_dists) > 0: + pct_thresh = float( + np.percentile(candidate_dists, config.edge_thresh_pct) + ) + else: + pct_thresh = 0.5 log_info( " -> Input pair is PROXIMAL (min_d=%.2f < %.2f)", min(d_from, d_to), - config.edge_thresh, + pct_thresh, ) return step3_frac @@ -2776,7 +2985,6 @@ def _log_warning(msg, *args): frame_sel, config, _log_info, - 
_log_warning, ) if pca is None: result["error_msg"] = "Primary cluster PCA failed." @@ -2791,6 +2999,8 @@ def _log_warning(msg, *args): result["PC1"] = pr["PC1"] result["PC2"] = pr["PC2"] result["avg_skeleton"] = pr["avg_skeleton"] + result["lateral_std"] = pr["lateral_std"] + result["longitudinal_std"] = pr["longitudinal_std"] result["num_clusters"] = pca["num_clusters"] result["primary_cluster"] = pca["primary_cluster"] @@ -2811,6 +3021,8 @@ def _log_warning(msg, *args): pr["PC1"], pr["anterior_sign"], pr["valid_shape_rows"], + pr["lateral_std"], + pr["longitudinal_std"], from_idx, to_idx, config, diff --git a/tests/test_unit/test_kinematics/test_body_axis.py b/tests/test_unit/test_kinematics/test_body_axis.py index 9b316b6c1..52657a528 100644 --- a/tests/test_unit/test_kinematics/test_body_axis.py +++ b/tests/test_unit/test_kinematics/test_body_axis.py @@ -33,10 +33,12 @@ class TestValidateAPConfig: ("max_clusters", 2.5), ("confidence_floor", -0.1), ("confidence_floor", 1.1), - ("lateral_thresh", -0.1), - ("lateral_thresh", 1.1), - ("edge_thresh", -0.1), - ("edge_thresh", 1.1), + ("lateral_thresh_pct", -1), + ("lateral_thresh_pct", 101), + ("edge_thresh_pct", -1), + ("edge_thresh_pct", 101), + ("lateral_var_weight", -0.1), + ("longitudinal_var_weight", -0.1), ], ) def test_invalid_config_values_raise(self, field: str, value: Any) -> None: @@ -57,6 +59,8 @@ def test_valid_config_does_not_raise(self) -> None: postural_var_ratio_thresh=1.5, max_clusters=3, confidence_floor=0.2, - lateral_thresh=0.3, - edge_thresh=0.2, + lateral_thresh_pct=50.0, + edge_thresh_pct=70.0, + lateral_var_weight=1.0, + longitudinal_var_weight=0.0, ) From 15fd84dc8eeb69e8fb7c4beb2ab66504e684383a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Apr 2026 09:57:21 +0000 Subject: [PATCH 20/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
movement/kinematics/body_axis.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/movement/kinematics/body_axis.py b/movement/kinematics/body_axis.py
index 2b51ba4a2..925d89226 100644
--- a/movement/kinematics/body_axis.py
+++ b/movement/kinematics/body_axis.py
@@ -1377,7 +1377,7 @@ def apply_lateral_filter(
     report.longitudinal_std_norm[valid_idx] = long_std_norm
 
     # Combined effective lateral score:
-    # mean_offset + lateral_var_weight * lateral_std
+    # mean_offset + lateral_var_weight * lateral_std
     #   + long_var_weight * long_std
     effective_lateral = (
         d_norm

From 2e828f4aa426e5e179fd38992e0f53e906ba799a Mon Sep 17 00:00:00 2001
From: khan-u
Date: Tue, 7 Apr 2026 20:31:06 -0700
Subject: [PATCH 21/21] fix(test_collective): nomenclature perpendicular ->
 cardinal directions to resolve potential confusion

---
 .../test_kinematics/test_collective.py | 20 ++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py
index 809b40bdf..2fc20ff31 100644
--- a/tests/test_unit/test_kinematics/test_collective.py
+++ b/tests/test_unit/test_kinematics/test_collective.py
@@ -119,7 +119,7 @@ def partial_alignment_positions() -> xr.DataArray:
 
 
 @pytest.fixture
-def perpendicular_positions() -> xr.DataArray:
+def cardinal_directions_positions() -> xr.DataArray:
     """Four individuals moving in cardinal directions (+x, -x, +y, -y)."""
     data = np.array(
         [
@@ -351,11 +351,13 @@
         polarization = kinematics.compute_polarization(opposite_positions)
         assert np.allclose(polarization.values[1:], 0.0, atol=1e-10)
 
-    def test_perpendicular_cardinal_directions_give_zero(
-        self, perpendicular_positions
+    def test_four_cardinal_directions_cancel_to_zero(
+        self, cardinal_directions_positions
     ):
         """Polarization is 0.0 when four individuals move in cardinal dirs."""
-        polarization = 
kinematics.compute_polarization(perpendicular_positions) + polarization = kinematics.compute_polarization( + cardinal_directions_positions + ) assert np.allclose(polarization.values[1:], 0.0, atol=1e-10) def test_partial_alignment_matches_expected_magnitude( @@ -1017,21 +1019,21 @@ def test_mean_angle_partial_alignment_matches_vector_average( def test_mean_angle_is_nan_when_net_vector_cancels( self, opposite_positions, - perpendicular_positions, + cardinal_directions_positions, ): """Mean angle is NaN when heading vectors cancel out.""" pol_opposite, angle_opposite = kinematics.compute_polarization( opposite_positions, return_angle=True, ) - pol_perp, angle_perp = kinematics.compute_polarization( - perpendicular_positions, + pol_cardinal, angle_cardinal = kinematics.compute_polarization( + cardinal_directions_positions, return_angle=True, ) assert np.allclose(pol_opposite.values[1:], 0.0, atol=1e-10) - assert np.allclose(pol_perp.values[1:], 0.0, atol=1e-10) + assert np.allclose(pol_cardinal.values[1:], 0.0, atol=1e-10) assert np.all(np.isnan(angle_opposite.values[1:])) - assert np.all(np.isnan(angle_perp.values[1:])) + assert np.all(np.isnan(angle_cardinal.values[1:])) def test_mean_angle_rotates_with_global_rotation( self,