@@ -16,58 +16,31 @@ function calc_tderivative!(integrator, cache, dtd1, repeat_step)
1616
1717 autodiff_alg = ADTypes. dense_ad (gpu_safe_autodiff (alg_autodiff (alg), u))
1818
19- # If `t` isn't dual-able by ForwardDiff (e.g. some unitful scalar types),
20- # differentiate w.r.t. its underlying primitive value and rescale.
21- # This avoids constructing `Dual{...,t}`.
22- vt = SciMLBase. value (t)
23- is_forwarddiff_backend = autodiff_alg isa AutoForwardDiff ||
24- (autodiff_alg isa DI. AutoForwardFromPrimitive && autodiff_alg. backend isa AutoForwardDiff)
25-
26- # If `t` isn’t directly differentiable by the backend (ForwardDiff can’t dualize it,
27- # or FiniteDiff would mix dimensionless `relstep` with dimensionful `absstep` at t=0),
28- # differentiate w.r.t. the primitive value `vt = value(t)` and rescale.
29- if (is_forwarddiff_backend && ! ForwardDiff. can_dual (typeof (t)) && vt isa Real) ||
30- (autodiff_alg isa AutoFiniteDiff && vt isa Real && typeof (vt) != typeof (t))
31- ut = oneunit (t)
32- tf_scaled! (y, τ) = tf (y, τ * ut)
33- if integrator. iter == 1
34- try
35- DI. derivative! (tf_scaled!, linsolve_tmp, dT, autodiff_alg, vt)
36- catch e
37- throw (FirstAutodiffTgradError (e))
38- end
39- else
40- DI. derivative! (tf_scaled!, linsolve_tmp, dT, autodiff_alg, vt)
41- end
42- dT ./= ut
43- OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
44- else
45- # Convert t to eltype(dT) if using ForwardDiff, to make FunctionWrappers work
46- t = autodiff_alg isa AutoForwardDiff ? convert (eltype (dT), t) : t
19+ # Convert t to eltype(dT) if using ForwardDiff, to make FunctionWrappers work
20+ t = autodiff_alg isa AutoForwardDiff ? convert (eltype (dT), t) : t
4721
48- grad_config_tup = cache. grad_config
22+ grad_config_tup = cache. grad_config
4923
50- if autodiff_alg isa AutoFiniteDiff
51- grad_config = diffdir (integrator) > 0 ? grad_config_tup[1 ] :
52- grad_config_tup[2 ]
53- else
54- grad_config = grad_config_tup[1 ]
55- end
24+ if autodiff_alg isa AutoFiniteDiff
25+ grad_config = diffdir (integrator) > 0 ? grad_config_tup[1 ] :
26+ grad_config_tup[2 ]
27+ else
28+ grad_config = grad_config_tup[1 ]
29+ end
5630
57- if integrator. iter == 1
58- try
59- DI. derivative! (
60- tf, linsolve_tmp, dT, grad_config, autodiff_alg, t
61- )
62- catch e
63- throw (FirstAutodiffTgradError (e))
64- end
65- else
66- DI. derivative! (tf, linsolve_tmp, dT, grad_config, autodiff_alg, t)
31+ if integrator. iter == 1
32+ try
33+ DI. derivative! (
34+ tf, linsolve_tmp, dT, grad_config, autodiff_alg, t
35+ )
36+ catch e
37+ throw (FirstAutodiffTgradError (e))
6738 end
68-
69- OrdinaryDiffEqCore . increment_nf! (integrator . stats, 1 )
39+ else
40+ DI . derivative! (tf, linsolve_tmp, dT, grad_config, autodiff_alg, t )
7041 end
42+
43+ OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
7144 end
7245 end
7346
@@ -92,33 +65,14 @@ function calc_tderivative(integrator, cache)
9265 autodiff_alg = SciMLBase. @set autodiff_alg. dir = diffdir (integrator)
9366 end
9467
95- vt = SciMLBase. value (t)
96- is_forwarddiff_backend = autodiff_alg isa AutoForwardDiff ||
97- (autodiff_alg isa DI. AutoForwardFromPrimitive && autodiff_alg. backend isa AutoForwardDiff)
98-
99- if (is_forwarddiff_backend && ! ForwardDiff. can_dual (typeof (t)) && vt isa Real) ||
100- (autodiff_alg isa AutoFiniteDiff && vt isa Real && typeof (vt) != typeof (t))
101- ut = oneunit (t)
102- tf_scaled (τ) = tf (τ * ut)
103- if integrator. iter == 1
104- try
105- dT = DI. derivative (tf_scaled, autodiff_alg, vt) ./ ut
106- catch e
107- throw (FirstAutodiffTgradError (e))
108- end
109- else
110- dT = DI. derivative (tf_scaled, autodiff_alg, vt) ./ ut
111- end
112- else
113- if integrator. iter == 1
114- try
115- dT = DI. derivative (tf, autodiff_alg, t)
116- catch e
117- throw (FirstAutodiffTgradError (e))
118- end
119- else
68+ if integrator. iter == 1
69+ try
12070 dT = DI. derivative (tf, autodiff_alg, t)
71+ catch e
72+ throw (FirstAutodiffTgradError (e))
12173 end
74+ else
75+ dT = DI. derivative (tf, autodiff_alg, t)
12276 end
12377
12478 OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
@@ -972,30 +926,7 @@ function build_J_W(
972926 elseif J isa StaticMatrix
973927 StaticWOperator (J, false )
974928 else
975- # For unitful eltypes (e.g. DynamicQuantities quantities), the downstream
976- # `calc_W` path uses `DiffEqBase.default_factorize(W)` which returns a
977- # wrapper factorization type. Seed the cache with the same factorization
978- # type to avoid type-instability / assignment conversion errors.
979- # Heuristic: DynamicQuantities quantity eltypes intentionally do not
980- # define `zero(::Type{<:Quantity})` (dimensions unknown at type-level).
981- # In that case, seed W using `default_factorize` so the cache type matches
982- # what `calc_W` will later produce.
983- unitful_eltype = try
984- zero (eltype (J))
985- false
986- catch
987- true
988- end
989- if unitful_eltype
990- # Seed an instance whose unit metadata matches the W/J matrices in
991- # Rosenbrock/SDIRK methods (units ~ 1/time). We only need the *type*
992- # here; the actual factorization is computed later in `calc_W`.
993- Aproto = Matrix {eltype(J)} (undef, 1 , 1 )
994- Aproto[1 ] = inv (oneunit (t))
995- DiffEqBase. default_factorize (Aproto)
996- else
997- ArrayInterface. lu_instance (J)
998- end
929+ ArrayInterface. lu_instance (J)
999930 end
1000931 end
1001932 return J, W
0 commit comments