Skip to content

Commit ff465b8

Browse files
Harsh SinghHarsh Singh
authored andcommitted
Fix FieldError for Magnus/Linear integrators with OrdinaryDiffEqDifferentiation (#3232)
Magnus integrators (MagnusGL6, MagnusGL8, etc.) and other OrdinaryDiffEqLinearExponentialAlgorithm subtypes have no `autodiff` field — only `krylov`, `m`, `iop`. When OrdinaryDiffEqDifferentiation is loaded (e.g. via DifferentialEquations.jl), the generic `_alg_autodiff(::OrdinaryDiffEqExponentialAlgorithm{CS,AD})` dispatch would call `alg.autodiff` on these types, causing a FieldError crash. Fix: - Import OrdinaryDiffEqLinearExponentialAlgorithm into OrdinaryDiffEqDifferentiation - Add _alg_autodiff(::OrdinaryDiffEqLinearExponentialAlgorithm) returning Val{false}(), intercepting calls before the generic ExponentialAlgorithm dispatch that accesses the nonexistent field - The existing prepare_alg override in OrdinaryDiffEqCore (line 298) already handles the prepare_alg path correctly Also adds regression tests verifying _alg_autodiff, prepare_alg, and forwarddiffs_model all work correctly for LinearExponentialAlgorithm subtypes.
1 parent a2c092a commit ff465b8

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplici
3636
OrdinaryDiffEqImplicitAlgorithm, CompositeAlgorithm,
3737
OrdinaryDiffEqExponentialAlgorithm,
3838
OrdinaryDiffEqAdaptiveExponentialAlgorithm,
39+
OrdinaryDiffEqLinearExponentialAlgorithm,
3940
StochasticDiffEqNewtonAlgorithm, StochasticDiffEqNewtonAdaptiveAlgorithm,
4041
StochasticDiffEqJumpNewtonAdaptiveAlgorithm,
4142
StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm,

lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ function _alg_autodiff(
3232
) where {CS, AD, FDT, ST, CJ, Controller}
3333
return Val{AD}()
3434
end
35+
# OrdinaryDiffEqLinearExponentialAlgorithm subtypes (Magnus integrators, LieEuler,
36+
# CG methods, etc.) have NO autodiff field — their only fields are krylov, m, iop.
37+
# They must be excluded before the generic ExponentialAlgorithm dispatch below.
38+
function _alg_autodiff(::OrdinaryDiffEqLinearExponentialAlgorithm)
39+
return Val{false}()
40+
end
3541
function _alg_autodiff(
3642
alg::Union{
3743
OrdinaryDiffEqExponentialAlgorithm{CS, AD},
@@ -66,6 +72,8 @@ Base.@pure function determine_chunksize(u, CS)
6672
end
6773
end
6874

75+
76+
6977
function DiffEqBase.prepare_alg(
7078
alg::Union{
7179
OrdinaryDiffEqAdaptiveImplicitAlgorithm{

lib/OrdinaryDiffEqDifferentiation/test/differentiation_traits_tests.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,31 @@ sol = solve(prob2, Rosenbrock23(autodiff = AutoForwardDiff(chunksize = 1)))
4343

4444
sol = solve(prob2, Rosenbrock23(autodiff = AutoFiniteDiff()))
4545
@test (good_sol[:, end], sol[:, end], rtol = 1.0e-2)
46+
47+
# Regression test for issue #3232:
48+
# OrdinaryDiffEqLinearExponentialAlgorithm subtypes (MagnusGL6, etc.)
49+
# have no `autodiff` field; _alg_autodiff and prepare_alg must not crash.
50+
using OrdinaryDiffEqDifferentiation: _alg_autodiff
51+
using OrdinaryDiffEqCore: OrdinaryDiffEqLinearExponentialAlgorithm
52+
using DiffEqBase: prepare_alg
53+
54+
struct MockMagnusAlg <: OrdinaryDiffEqLinearExponentialAlgorithm
55+
krylov::Bool
56+
m::Int
57+
iop::Int
58+
end
59+
60+
@testset "LinearExponentialAlgorithm autodiff traits (issue #3232)" begin
61+
mock = MockMagnusAlg(false, 30, 0)
62+
63+
# _alg_autodiff must return Val{false}() instead of accessing alg.autodiff
64+
@test _alg_autodiff(mock) == Val{false}()
65+
66+
# prepare_alg must return the algorithm unchanged (no AD preparation needed)
67+
u0 = ones(2)
68+
mock_prob = ODEProblem((du, u, p, t) -> du .= 0, u0, (0.0, 1.0))
69+
@test prepare_alg(mock, u0, nothing, mock_prob) === mock
70+
71+
# forwarddiffs_model must return false
72+
@test SciMLBase.forwarddiffs_model(mock) == false
73+
end

0 commit comments

Comments
 (0)