Skip to content

Commit f976d8d

Browse files
fix: make differentiation wrappers robust for unitful jacobians
1 parent 3f28d4d commit f976d8d

File tree

3 files changed

+34
-127
lines changed

3 files changed

+34
-127
lines changed

lib/OrdinaryDiffEqCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
107107
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
108108

109109
[targets]
110-
test = ["DiffEqDevTools", "Random", "SafeTestsets", "SparseArrays", "Test", "Pkg", "DynamicQuantities", "Measurements", "OrdinaryDiffEqTsit5"]
110+
test = ["DiffEqDevTools", "SafeTestsets", "SparseArrays", "Test", "Pkg", "DynamicQuantities", "Measurements", "OrdinaryDiffEqTsit5"]
111111

112112
[extensions]
113113
OrdinaryDiffEqCoreMooncakeExt = "Mooncake"

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 27 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -16,58 +16,31 @@ function calc_tderivative!(integrator, cache, dtd1, repeat_step)
1616

1717
autodiff_alg = ADTypes.dense_ad(gpu_safe_autodiff(alg_autodiff(alg), u))
1818

19-
# If `t` isn't dual-able by ForwardDiff (e.g. some unitful scalar types),
20-
# differentiate w.r.t. its underlying primitive value and rescale.
21-
# This avoids constructing `Dual{...,t}`.
22-
vt = SciMLBase.value(t)
23-
is_forwarddiff_backend = autodiff_alg isa AutoForwardDiff ||
24-
(autodiff_alg isa DI.AutoForwardFromPrimitive && autodiff_alg.backend isa AutoForwardDiff)
25-
26-
# If `t` isn’t directly differentiable by the backend (ForwardDiff can’t dualize it,
27-
# or FiniteDiff would mix dimensionless `relstep` with dimensionful `absstep` at t=0),
28-
# differentiate w.r.t. the primitive value `vt = value(t)` and rescale.
29-
if (is_forwarddiff_backend && !ForwardDiff.can_dual(typeof(t)) && vt isa Real) ||
30-
(autodiff_alg isa AutoFiniteDiff && vt isa Real && typeof(vt) != typeof(t))
31-
ut = oneunit(t)
32-
tf_scaled!(y, τ) = tf(y, τ * ut)
33-
if integrator.iter == 1
34-
try
35-
DI.derivative!(tf_scaled!, linsolve_tmp, dT, autodiff_alg, vt)
36-
catch e
37-
throw(FirstAutodiffTgradError(e))
38-
end
39-
else
40-
DI.derivative!(tf_scaled!, linsolve_tmp, dT, autodiff_alg, vt)
41-
end
42-
dT ./= ut
43-
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
44-
else
45-
# Convert t to eltype(dT) if using ForwardDiff, to make FunctionWrappers work
46-
t = autodiff_alg isa AutoForwardDiff ? convert(eltype(dT), t) : t
19+
# Convert t to eltype(dT) if using ForwardDiff, to make FunctionWrappers work
20+
t = autodiff_alg isa AutoForwardDiff ? convert(eltype(dT), t) : t
4721

48-
grad_config_tup = cache.grad_config
22+
grad_config_tup = cache.grad_config
4923

50-
if autodiff_alg isa AutoFiniteDiff
51-
grad_config = diffdir(integrator) > 0 ? grad_config_tup[1] :
52-
grad_config_tup[2]
53-
else
54-
grad_config = grad_config_tup[1]
55-
end
24+
if autodiff_alg isa AutoFiniteDiff
25+
grad_config = diffdir(integrator) > 0 ? grad_config_tup[1] :
26+
grad_config_tup[2]
27+
else
28+
grad_config = grad_config_tup[1]
29+
end
5630

57-
if integrator.iter == 1
58-
try
59-
DI.derivative!(
60-
tf, linsolve_tmp, dT, grad_config, autodiff_alg, t
61-
)
62-
catch e
63-
throw(FirstAutodiffTgradError(e))
64-
end
65-
else
66-
DI.derivative!(tf, linsolve_tmp, dT, grad_config, autodiff_alg, t)
31+
if integrator.iter == 1
32+
try
33+
DI.derivative!(
34+
tf, linsolve_tmp, dT, grad_config, autodiff_alg, t
35+
)
36+
catch e
37+
throw(FirstAutodiffTgradError(e))
6738
end
68-
69-
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
39+
else
40+
DI.derivative!(tf, linsolve_tmp, dT, grad_config, autodiff_alg, t)
7041
end
42+
43+
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
7144
end
7245
end
7346

@@ -92,33 +65,14 @@ function calc_tderivative(integrator, cache)
9265
autodiff_alg = SciMLBase.@set autodiff_alg.dir = diffdir(integrator)
9366
end
9467

95-
vt = SciMLBase.value(t)
96-
is_forwarddiff_backend = autodiff_alg isa AutoForwardDiff ||
97-
(autodiff_alg isa DI.AutoForwardFromPrimitive && autodiff_alg.backend isa AutoForwardDiff)
98-
99-
if (is_forwarddiff_backend && !ForwardDiff.can_dual(typeof(t)) && vt isa Real) ||
100-
(autodiff_alg isa AutoFiniteDiff && vt isa Real && typeof(vt) != typeof(t))
101-
ut = oneunit(t)
102-
tf_scaled(τ) = tf* ut)
103-
if integrator.iter == 1
104-
try
105-
dT = DI.derivative(tf_scaled, autodiff_alg, vt) ./ ut
106-
catch e
107-
throw(FirstAutodiffTgradError(e))
108-
end
109-
else
110-
dT = DI.derivative(tf_scaled, autodiff_alg, vt) ./ ut
111-
end
112-
else
113-
if integrator.iter == 1
114-
try
115-
dT = DI.derivative(tf, autodiff_alg, t)
116-
catch e
117-
throw(FirstAutodiffTgradError(e))
118-
end
119-
else
68+
if integrator.iter == 1
69+
try
12070
dT = DI.derivative(tf, autodiff_alg, t)
71+
catch e
72+
throw(FirstAutodiffTgradError(e))
12173
end
74+
else
75+
dT = DI.derivative(tf, autodiff_alg, t)
12276
end
12377

12478
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
@@ -972,30 +926,7 @@ function build_J_W(
972926
elseif J isa StaticMatrix
973927
StaticWOperator(J, false)
974928
else
975-
# For unitful eltypes (e.g. DynamicQuantities quantities), the downstream
976-
# `calc_W` path uses `DiffEqBase.default_factorize(W)` which returns a
977-
# wrapper factorization type. Seed the cache with the same factorization
978-
# type to avoid type-instability / assignment conversion errors.
979-
# Heuristic: DynamicQuantities quantity eltypes intentionally do not
980-
# define `zero(::Type{<:Quantity})` (dimensions unknown at type-level).
981-
# In that case, seed W using `default_factorize` so the cache type matches
982-
# what `calc_W` will later produce.
983-
unitful_eltype = try
984-
zero(eltype(J))
985-
false
986-
catch
987-
true
988-
end
989-
if unitful_eltype
990-
# Seed an instance whose unit metadata matches the W/J matrices in
991-
# Rosenbrock/SDIRK methods (units ~ 1/time). We only need the *type*
992-
# here; the actual factorization is computed later in `calc_W`.
993-
Aproto = Matrix{eltype(J)}(undef, 1, 1)
994-
Aproto[1] = inv(oneunit(t))
995-
DiffEqBase.default_factorize(Aproto)
996-
else
997-
ArrayInterface.lu_instance(J)
998-
end
929+
ArrayInterface.lu_instance(J)
999930
end
1000931
end
1001932
return J, W

lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -124,38 +124,14 @@ function jacobian(f::F, x::AbstractArray{<:Number}, integrator) where {F}
124124
autodiff_alg = dense
125125
end
126126

127-
if autodiff_alg isa AutoFiniteDiff && x isa AbstractArray && !isempty(x)
128-
# For some scalar-like types (notably runtime-unit quantities), FiniteDiff’s
129-
# step-size selection mixes a dimensionless `relstep` with a dimensionful state
130-
# `x[i]`, triggering DimensionErrors in `max(relstep*abs(x[i]), absstep)`.
131-
# Work around this by differentiating w.r.t. primitive values and rescaling.
132-
U = oneunit.(x)
133-
vx = SciMLBase.value.(x)
134-
f_scaled(v) = f(U .* v)
135-
136-
if integrator.iter == 1
137-
try
138-
jac = DI.jacobian(f_scaled, autodiff_alg, vx)
139-
catch e
140-
throw(FirstAutodiffJacError(e))
141-
end
142-
else
143-
jac = DI.jacobian(f_scaled, autodiff_alg, vx)
144-
end
145-
146-
@inbounds for j in axes(jac, 2)
147-
jac[:, j] ./= U[j]
148-
end
149-
else
150-
if integrator.iter == 1
151-
try
152-
jac = DI.jacobian(f, autodiff_alg, x)
153-
catch e
154-
throw(FirstAutodiffJacError(e))
155-
end
156-
else
127+
if integrator.iter == 1
128+
try
157129
jac = DI.jacobian(f, autodiff_alg, x)
130+
catch e
131+
throw(FirstAutodiffJacError(e))
158132
end
133+
else
134+
jac = DI.jacobian(f, autodiff_alg, x)
159135
end
160136

161137
return jac

0 commit comments

Comments
 (0)