Skip to content

Commit 9bc1f17

Browse files
Optimization infrastructure: pluggable backends, losses, batched GPU processing (#531)
* add pluggable optimizer backends (lbfgs, nelder_mead, grid_search), convergence criteria, and `use_gradients` flag * rename num_iterations to max_iterations * add CacheSpec and disk-backed TransferFunctionCache for multi-tile TF reuse * add benchmark_optimizers for comparing optimization strategies" * add optimizer and TF cache benchmark examples" * add pluggable, normalized loss functions with 3D support * fix widget tests: use model_validator for loss config instead of discriminated union * Fix: import PrintLogger in optimize.py for nelder_mead/grid_search defaults * add batched (B,Z,Y,X) reconstruction support to all 4 models Accept both (Z,Y,X) and (B,Z,Y,X) inputs following PyTorch convention. Internal code unsqueezes/squeezes at boundaries so all operations use (B,Z,Y,X). Unbatched path is bit-exact with prior behavior. - util: pad_zyx_along_z, inten_normalization, inten_normalization_3D - isotropic_thin_3d: calculate_transfer_function (batched tilt angles), calculate_singular_system, apply_inverse_transfer_function, reconstruct - isotropic_fluorescent_thin_3d: apply_inverse_transfer_function, reconstruct - phase_thick_3d: apply_inverse_transfer_function, reconstruct - isotropic_fluorescent_thick_3d: apply_inverse_transfer_function, reconstruct - optics: compute_weak_object_transfer_function_2d now supports broadcasting - optics: generate_tilted_pupil documents batched output shape * add batched API input and per-tile optimization support API layer: phase and fluorescence apply_inverse_transfer_function accept list[xr.DataArray] for batched processing. Adds _wrap_output_tensor helper to reduce wrapping duplication. Optimizer: when data.ndim == 4, parameters become (B,) tensors for independent per-tile optimization. Standard Adam with (B,) tensors gives B independent optimizers (verified bit-exact vs sequential). Loss is summed per-tile; cross-tile gradient terms are zero since tiles are decoupled in the forward pass. * add batched reconstruction tests for all 4 models Tests B=4 batched vs sequential bit-exactness, B=1 matches unbatched, and per-tile tilt angles for isotropic_thin_3d. * restore torch.linalg.svd in calculate_singular_system The norm-based decomposition (U=identity) introduced a regression vs waveorder 3.0.0 by ignoring cross-channel coupling between absorption and phase transfer functions. Restoring the real SVD reduces the mean reconstruction error from 0.077 to 0.0005 on OPS data (verified against production phenotyping_phase_2d.zarr). torch.linalg.svd supports backpropagation through complex tensors, so gradient-based optimization still works. * add use_svd flag to calculate_singular_system for gradient compatibility torch.linalg.svd backward fails with complex tensors on some GPU types due to singular vector phase ambiguity. When gradients are needed (optimization), reconstruct() auto-selects the norm-based decomposition (use_svd=False) which supports backpropagation on all devices. When no gradients are needed (final reconstruction), full SVD is used for best accuracy. The no-grad path remains bit-exact with the previous SVD-only behavior. * catch SVD/NaN errors in optimization loop and revert to last good params * support batched parameters in calculate_transfer_function and apply_inverse_transfer_function * Remove TF caching and benchmarking harness Removes cache.py, benchmark.py, CacheSpec, and associated tests/examples. TF caching showed marginal speedup with ~24% accuracy loss (nearest-neighbor) and the benchmarking harness is better suited as standalone scripts. * Remove CacheSpec from OptimizableFloat * Add NAdam optimizer backend * Validate optimization method name * Use default lr=1.0 for L-BFGS line search step size * Eliminate L-BFGS double forward pass * Use np.linspace for grid search to avoid float-step accumulation
1 parent cadc965 commit 9bc1f17

24 files changed

+1928
-495
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,4 @@ waveorder/_version.py
155155
# example data
156156
/examples/data_temp/*
157157
/logs/*
158+
scripts/

docs/examples/optimization/optimize_fluorescence_2d.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from waveorder.api import fluorescence
1010
from waveorder.optim import OptimizableFloat
11+
from waveorder.optim.losses import MidbandPowerLossSettings
1112

1213
# Ground truth parameters
1314
gt_z_offset = 0.6
@@ -32,8 +33,8 @@
3233
optimized_settings, recon = fluorescence.optimize(
3334
data,
3435
settings=opt_settings,
35-
num_iterations=50,
36-
midband_fractions=(0.01, 0.5),
36+
max_iterations=50,
37+
loss_settings=MidbandPowerLossSettings(midband_fractions=[0.01, 0.5]),
3738
log_dir=log_dir,
3839
log_images=True,
3940
)

docs/examples/optimization/optimize_phase_2d.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from waveorder.api import phase
1313
from waveorder.optim import OptimizableFloat
14+
from waveorder.optim.losses import MidbandPowerLossSettings
1415

1516
# To use your own data instead of simulated data, create a CZYX xr.DataArray:
1617
#
@@ -53,8 +54,8 @@
5354
optimized_settings, recon = phase.optimize(
5455
data,
5556
settings=opt_settings,
56-
num_iterations=50,
57-
midband_fractions=(0.1, 0.5),
57+
max_iterations=50,
58+
loss_settings=MidbandPowerLossSettings(midband_fractions=[0.1, 0.5]),
5859
log_dir=log_dir,
5960
log_images=True,
6061
)

docs/examples/optimization/phase_2d_optimized.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ phase:
2424
reconstruction_algorithm: Tikhonov
2525
regularization_strength: 0.01
2626
optimization:
27-
num_iterations: 10
27+
max_iterations: 10
28+
method: adam # adam, lbfgs, nelder_mead, grid_search
2829
loss:
29-
type: midband_power
30-
midband_fractions: [0.125, 0.25]
30+
type: midband_power # midband_power, total_variation, laplacian_variance,
31+
midband_fractions: [0.125, 0.25] # normalized_variance, spectral_flatness
3132
log_dir: ./runs/phase_optim

tests/cli_tests/test_reconstruct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def test_optimization_cli(tmp_path):
270270
phase=settings.PhaseSettings(
271271
transfer_function=PhaseTFSettings(z_focus_offset={"init": 0, "lr": 0.1}),
272272
),
273-
optimization=OptimizationSettings(num_iterations=2),
273+
optimization=OptimizationSettings(max_iterations=2),
274274
)
275275
config_path = tmp_path / "optim.yml"
276276
utils.model_to_yaml(recon_settings, config_path)
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
"""Test batched reconstruction is bit-exact against sequential single-tile calls."""
2+
3+
import pytest
4+
import torch
5+
6+
from waveorder.models import (
7+
isotropic_fluorescent_thick_3d,
8+
isotropic_fluorescent_thin_3d,
9+
isotropic_thin_3d,
10+
phase_thick_3d,
11+
)
12+
13+
B = 4
14+
Z, Y, X = 5, 32, 32
15+
16+
17+
# --- isotropic_thin_3d ---
18+
19+
20+
class TestIsotropicThin3DBatched:
21+
"""Batched vs sequential tests for isotropic_thin_3d."""
22+
23+
@pytest.fixture
24+
def optical_params(self):
25+
return dict(
26+
yx_pixel_size=6.5 / 40,
27+
wavelength_illumination=0.532,
28+
index_of_refraction_media=1.33,
29+
numerical_aperture_illumination=0.9,
30+
numerical_aperture_detection=1.2,
31+
)
32+
33+
@pytest.fixture
34+
def z_position_list(self):
35+
return (-torch.arange(Z) + Z // 2).float() * 0.1
36+
37+
def test_batched_reconstruct_matches_sequential(self, optical_params, z_position_list):
38+
torch.manual_seed(42)
39+
bzyx = torch.randn(B, Z, Y, X)
40+
41+
# Sequential
42+
seq_abs, seq_phase = [], []
43+
for b in range(B):
44+
a, p = isotropic_thin_3d.reconstruct(bzyx[b], z_position_list=z_position_list, **optical_params)
45+
seq_abs.append(a)
46+
seq_phase.append(p)
47+
seq_abs = torch.stack(seq_abs)
48+
seq_phase = torch.stack(seq_phase)
49+
50+
# Batched
51+
bat_abs, bat_phase = isotropic_thin_3d.reconstruct(bzyx, z_position_list=z_position_list, **optical_params)
52+
53+
torch.testing.assert_close(bat_abs, seq_abs, atol=0, rtol=0)
54+
torch.testing.assert_close(bat_phase, seq_phase, atol=0, rtol=0)
55+
56+
def test_b1_matches_unbatched(self, optical_params, z_position_list):
57+
torch.manual_seed(42)
58+
zyx = torch.randn(Z, Y, X)
59+
60+
ref_abs, ref_phase = isotropic_thin_3d.reconstruct(zyx, z_position_list=z_position_list, **optical_params)
61+
62+
bat_abs, bat_phase = isotropic_thin_3d.reconstruct(
63+
zyx.unsqueeze(0), z_position_list=z_position_list, **optical_params
64+
)
65+
66+
torch.testing.assert_close(bat_abs.squeeze(0), ref_abs, atol=0, rtol=0)
67+
torch.testing.assert_close(bat_phase.squeeze(0), ref_phase, atol=0, rtol=0)
68+
69+
def test_batched_tf_with_per_tile_tilt(self, optical_params, z_position_list):
70+
torch.manual_seed(42)
71+
bzyx = torch.randn(B, Z, Y, X)
72+
73+
zeniths = torch.tensor([0.0, 0.05, 0.1, 0.15])
74+
azimuths = torch.tensor([0.0, 0.5, 1.0, 1.5])
75+
76+
# Sequential with per-tile tilt
77+
seq_abs, seq_phase = [], []
78+
for b in range(B):
79+
a, p = isotropic_thin_3d.reconstruct(
80+
bzyx[b],
81+
z_position_list=z_position_list,
82+
tilt_angle_zenith=zeniths[b].item(),
83+
tilt_angle_azimuth=azimuths[b].item(),
84+
**optical_params,
85+
)
86+
seq_abs.append(a)
87+
seq_phase.append(p)
88+
seq_abs = torch.stack(seq_abs)
89+
seq_phase = torch.stack(seq_phase)
90+
91+
# Batched TF with per-tile tilt
92+
abs_tf, phase_tf = isotropic_thin_3d.calculate_transfer_function(
93+
(Y, X),
94+
z_position_list=z_position_list,
95+
tilt_angle_zenith=zeniths,
96+
tilt_angle_azimuth=azimuths,
97+
**optical_params,
98+
)
99+
singular_system = isotropic_thin_3d.calculate_singular_system(abs_tf, phase_tf)
100+
bat_abs, bat_phase = isotropic_thin_3d.apply_inverse_transfer_function(bzyx, singular_system)
101+
102+
# Batched SVD produces slightly different singular values at
103+
# near-zero frequencies due to floating-point ordering, which gets
104+
# amplified by the Tikhonov regularized inverse. Match to within
105+
# 1% of the signal standard deviation.
106+
for b in range(B):
107+
atol_p = 0.01 * seq_phase[b].std()
108+
torch.testing.assert_close(bat_phase[b], seq_phase[b], rtol=1e-3, atol=atol_p)
109+
atol_a = 0.01 * seq_abs[b].std()
110+
torch.testing.assert_close(bat_abs[b], seq_abs[b], rtol=1e-3, atol=atol_a)
111+
112+
113+
# --- isotropic_fluorescent_thin_3d ---
114+
115+
116+
class TestFluorescentThin3DBatched:
117+
"""Batched vs sequential tests for isotropic_fluorescent_thin_3d."""
118+
119+
@pytest.fixture
120+
def optical_params(self):
121+
return dict(
122+
yx_pixel_size=6.5 / 40,
123+
wavelength_emission=0.532,
124+
index_of_refraction_media=1.33,
125+
numerical_aperture_detection=1.2,
126+
)
127+
128+
@pytest.fixture
129+
def z_position_list(self):
130+
return (-torch.arange(Z) + Z // 2).float() * 0.1
131+
132+
def test_batched_reconstruct_matches_sequential(self, optical_params, z_position_list):
133+
torch.manual_seed(42)
134+
bzyx = torch.randn(B, Z, Y, X).abs() + 1 # fluorescence is positive
135+
136+
seq = []
137+
for b in range(B):
138+
r = isotropic_fluorescent_thin_3d.reconstruct(bzyx[b], z_position_list=z_position_list, **optical_params)
139+
seq.append(r)
140+
seq = torch.stack(seq)
141+
142+
bat = isotropic_fluorescent_thin_3d.reconstruct(bzyx, z_position_list=z_position_list, **optical_params)
143+
144+
torch.testing.assert_close(bat, seq, atol=0, rtol=0)
145+
146+
def test_b1_matches_unbatched(self, optical_params, z_position_list):
147+
torch.manual_seed(42)
148+
zyx = torch.randn(Z, Y, X).abs() + 1
149+
150+
ref = isotropic_fluorescent_thin_3d.reconstruct(zyx, z_position_list=z_position_list, **optical_params)
151+
bat = isotropic_fluorescent_thin_3d.reconstruct(
152+
zyx.unsqueeze(0), z_position_list=z_position_list, **optical_params
153+
)
154+
155+
torch.testing.assert_close(bat.squeeze(0), ref, atol=0, rtol=0)
156+
157+
158+
# --- phase_thick_3d ---
159+
160+
161+
class TestPhaseThick3DBatched:
162+
"""Batched vs sequential tests for phase_thick_3d."""
163+
164+
@pytest.fixture
165+
def optical_params(self):
166+
return dict(
167+
yx_pixel_size=6.5 / 40,
168+
z_pixel_size=0.25,
169+
wavelength_illumination=0.532,
170+
z_padding=5,
171+
index_of_refraction_media=1.33,
172+
numerical_aperture_illumination=0.9,
173+
numerical_aperture_detection=1.2,
174+
)
175+
176+
def test_batched_reconstruct_matches_sequential(self, optical_params):
177+
torch.manual_seed(42)
178+
bzyx = torch.randn(B, Z, Y, X)
179+
180+
seq = []
181+
for b in range(B):
182+
r = phase_thick_3d.reconstruct(bzyx[b], **optical_params)
183+
seq.append(r)
184+
seq = torch.stack(seq)
185+
186+
bat = phase_thick_3d.reconstruct(bzyx, **optical_params)
187+
188+
torch.testing.assert_close(bat, seq, atol=1e-6, rtol=1e-6)
189+
190+
def test_b1_matches_unbatched(self, optical_params):
191+
torch.manual_seed(42)
192+
zyx = torch.randn(Z, Y, X)
193+
194+
ref = phase_thick_3d.reconstruct(zyx, **optical_params)
195+
bat = phase_thick_3d.reconstruct(zyx.unsqueeze(0), **optical_params)
196+
197+
torch.testing.assert_close(bat.squeeze(0), ref, atol=1e-6, rtol=1e-6)
198+
199+
200+
# --- isotropic_fluorescent_thick_3d ---
201+
202+
203+
class TestFluorescentThick3DBatched:
204+
"""Batched vs sequential tests for isotropic_fluorescent_thick_3d."""
205+
206+
@pytest.fixture
207+
def optical_params(self):
208+
return dict(
209+
yx_pixel_size=6.5 / 40,
210+
z_pixel_size=0.25,
211+
wavelength_emission=0.532,
212+
z_padding=5,
213+
index_of_refraction_media=1.33,
214+
numerical_aperture_detection=1.2,
215+
)
216+
217+
def test_batched_reconstruct_matches_sequential(self, optical_params):
218+
torch.manual_seed(42)
219+
bzyx = torch.randn(B, Z, Y, X).abs() + 1
220+
221+
seq = []
222+
for b in range(B):
223+
r = isotropic_fluorescent_thick_3d.reconstruct(bzyx[b], **optical_params)
224+
seq.append(r)
225+
seq = torch.stack(seq)
226+
227+
bat = isotropic_fluorescent_thick_3d.reconstruct(bzyx, **optical_params)
228+
229+
torch.testing.assert_close(bat, seq, atol=1e-6, rtol=1e-6)
230+
231+
def test_b1_matches_unbatched(self, optical_params):
232+
torch.manual_seed(42)
233+
zyx = torch.randn(Z, Y, X).abs() + 1
234+
235+
ref = isotropic_fluorescent_thick_3d.reconstruct(zyx, **optical_params)
236+
bat = isotropic_fluorescent_thick_3d.reconstruct(zyx.unsqueeze(0), **optical_params)
237+
238+
torch.testing.assert_close(bat.squeeze(0), ref, atol=1e-6, rtol=1e-6)

0 commit comments

Comments
 (0)