Skip to content

Commit 9eb0a25

Browse files
Harsh Singhclaude
authored andcommitted
refactor(Rosenbrock): migrate Rosenbrock23/32 to generic RodasTableau framework
Unify Rosenbrock23 and Rosenbrock32 with the existing generic Rosenbrock/Rodas infrastructure by expressing their coefficients as RodasTableau entries. This removes ~900 lines of duplicated cache structs, perform_step!, interpolation, and addsteps code. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent fd4631d commit 9eb0a25

File tree

6 files changed

+74
-926
lines changed

6 files changed

+74
-926
lines changed

lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,3 @@
1-
function SciMLBase.interp_summary(
2-
::Type{cacheType},
3-
dense::Bool
4-
) where {
5-
cacheType <:
6-
Union{
7-
Rosenbrock23ConstantCache,
8-
Rosenbrock32ConstantCache,
9-
Rosenbrock23Cache,
10-
Rosenbrock32Cache,
11-
},
12-
}
13-
return dense ? "specialized 2nd order \"free\" stiffness-aware interpolation" :
14-
"1st order linear"
15-
end
161
function SciMLBase.interp_summary(
172
::Type{cacheType},
183
dense::Bool

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 4 additions & 266 deletions
Original file line numberDiff line numberDiff line change
@@ -62,268 +62,6 @@ struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <:
6262
interp_order::Int
6363
end
6464

65-
@cache mutable struct Rosenbrock23Cache{
66-
uType, rateType, uNoUnitsType, JType, WType,
67-
TabType, TFType, UFType, F, JCType, GCType,
68-
RTolType, A, AV, StepLimiter, StageLimiter,
69-
} <: RosenbrockMutableCache
70-
u::uType
71-
uprev::uType
72-
k₁::rateType
73-
k₂::rateType
74-
k₃::rateType
75-
du1::rateType
76-
du2::rateType
77-
f₁::rateType
78-
fsalfirst::rateType
79-
fsallast::rateType
80-
dT::rateType
81-
J::JType
82-
W::WType
83-
tmp::rateType
84-
atmp::uNoUnitsType
85-
weight::uNoUnitsType
86-
tab::TabType
87-
tf::TFType
88-
uf::UFType
89-
linsolve_tmp::rateType
90-
linsolve::F
91-
jac_config::JCType
92-
grad_config::GCType
93-
reltol::RTolType
94-
alg::A
95-
algebraic_vars::AV
96-
step_limiter!::StepLimiter
97-
stage_limiter!::StageLimiter
98-
end
99-
100-
@cache mutable struct Rosenbrock32Cache{
101-
uType, rateType, uNoUnitsType, JType, WType,
102-
TabType, TFType, UFType, F, JCType, GCType,
103-
RTolType, A, AV, StepLimiter, StageLimiter,
104-
} <: RosenbrockMutableCache
105-
u::uType
106-
uprev::uType
107-
k₁::rateType
108-
k₂::rateType
109-
k₃::rateType
110-
du1::rateType
111-
du2::rateType
112-
f₁::rateType
113-
fsalfirst::rateType
114-
fsallast::rateType
115-
dT::rateType
116-
J::JType
117-
W::WType
118-
tmp::rateType
119-
atmp::uNoUnitsType
120-
weight::uNoUnitsType
121-
tab::TabType
122-
tf::TFType
123-
uf::UFType
124-
linsolve_tmp::rateType
125-
linsolve::F
126-
jac_config::JCType
127-
grad_config::GCType
128-
reltol::RTolType
129-
alg::A
130-
algebraic_vars::AV
131-
step_limiter!::StepLimiter
132-
stage_limiter!::StageLimiter
133-
end
134-
135-
function alg_cache(
136-
alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
137-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
138-
dt, reltol, p, calck,
139-
::Val{true}, verbose
140-
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
141-
k₁ = zero(rate_prototype)
142-
k₂ = zero(rate_prototype)
143-
k₃ = zero(rate_prototype)
144-
du1 = zero(rate_prototype)
145-
du2 = zero(rate_prototype)
146-
# f₀ = zero(u) fsalfirst
147-
f₁ = zero(rate_prototype)
148-
fsalfirst = zero(rate_prototype)
149-
fsallast = zero(rate_prototype)
150-
dT = zero(rate_prototype)
151-
tmp = zero(rate_prototype)
152-
atmp = similar(u, uEltypeNoUnits)
153-
recursivefill!(atmp, false)
154-
weight = similar(u, uEltypeNoUnits)
155-
recursivefill!(weight, false)
156-
tab = Rosenbrock23Tableau(constvalue(uBottomEltypeNoUnits))
157-
tf = TimeGradientWrapper(f, uprev, p)
158-
uf = UJacobianWrapper(f, t, p)
159-
linsolve_tmp = zero(rate_prototype)
160-
161-
grad_config = build_grad_config(alg, f, tf, du1, t)
162-
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
163-
164-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
165-
166-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
167-
Pl,
168-
Pr = wrapprecs(
169-
alg.precs(
170-
W, nothing, u, p, t, nothing, nothing, nothing,
171-
nothing
172-
)..., weight, tmp
173-
)
174-
linsolve = init(
175-
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
176-
Pl = Pl, Pr = Pr,
177-
assumptions = LinearSolve.OperatorAssumptions(true),
178-
verbose = verbose.linear_verbosity
179-
)
180-
181-
algebraic_vars = f.mass_matrix === I ? nothing :
182-
[all(iszero, x) for x in eachcol(f.mass_matrix)]
183-
184-
return Rosenbrock23Cache(
185-
u, uprev, k₁, k₂, k₃, du1, du2, f₁,
186-
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
187-
linsolve_tmp,
188-
linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!,
189-
alg.stage_limiter!
190-
)
191-
end
192-
193-
function alg_cache(
194-
alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
195-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
196-
dt, reltol, p, calck,
197-
::Val{true}, verbose
198-
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
199-
k₁ = zero(rate_prototype)
200-
k₂ = zero(rate_prototype)
201-
k₃ = zero(rate_prototype)
202-
du1 = zero(rate_prototype)
203-
du2 = zero(rate_prototype)
204-
# f₀ = zero(u) fsalfirst
205-
f₁ = zero(rate_prototype)
206-
fsalfirst = zero(rate_prototype)
207-
fsallast = zero(rate_prototype)
208-
dT = zero(rate_prototype)
209-
tmp = zero(rate_prototype)
210-
atmp = similar(u, uEltypeNoUnits)
211-
recursivefill!(atmp, false)
212-
weight = similar(u, uEltypeNoUnits)
213-
recursivefill!(weight, false)
214-
tab = Rosenbrock32Tableau(constvalue(uBottomEltypeNoUnits))
215-
216-
tf = TimeGradientWrapper(f, uprev, p)
217-
uf = UJacobianWrapper(f, t, p)
218-
linsolve_tmp = zero(rate_prototype)
219-
220-
grad_config = build_grad_config(alg, f, tf, du1, t)
221-
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
222-
223-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
224-
225-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
226-
227-
Pl,
228-
Pr = wrapprecs(
229-
alg.precs(
230-
W, nothing, u, p, t, nothing, nothing, nothing,
231-
nothing
232-
)..., weight, tmp
233-
)
234-
linsolve = init(
235-
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
236-
Pl = Pl, Pr = Pr,
237-
assumptions = LinearSolve.OperatorAssumptions(true),
238-
verbose = verbose.linear_verbosity
239-
)
240-
241-
algebraic_vars = f.mass_matrix === I ? nothing :
242-
[all(iszero, x) for x in eachcol(f.mass_matrix)]
243-
244-
return Rosenbrock32Cache(
245-
u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W,
246-
tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config,
247-
grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, alg.stage_limiter!
248-
)
249-
end
250-
251-
struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} <:
252-
RosenbrockConstantCache
253-
c₃₂::T
254-
d::T
255-
tf::TF
256-
uf::UF
257-
J::JType
258-
W::WType
259-
linsolve::F
260-
autodiff::AD
261-
end
262-
263-
function Rosenbrock23ConstantCache(
264-
::Type{T}, tf, uf, J, W, linsolve, autodiff
265-
) where {T}
266-
tab = Rosenbrock23Tableau(T)
267-
return Rosenbrock23ConstantCache(
268-
tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff
269-
)
270-
end
271-
272-
function alg_cache(
273-
alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
274-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
275-
dt, reltol, p, calck,
276-
::Val{false}, verbose
277-
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
278-
tf = TimeDerivativeWrapper(f, u, p)
279-
uf = UDerivativeWrapper(f, t, p)
280-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false))
281-
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
282-
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
283-
return Rosenbrock23ConstantCache(
284-
constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve,
285-
alg_autodiff(alg)
286-
)
287-
end
288-
289-
struct Rosenbrock32ConstantCache{T, TF, UF, JType, WType, F, AD} <:
290-
RosenbrockConstantCache
291-
c₃₂::T
292-
d::T
293-
tf::TF
294-
uf::UF
295-
J::JType
296-
W::WType
297-
linsolve::F
298-
autodiff::AD
299-
end
300-
301-
function Rosenbrock32ConstantCache(
302-
::Type{T}, tf, uf, J, W, linsolve, autodiff
303-
) where {T}
304-
tab = Rosenbrock32Tableau(T)
305-
return Rosenbrock32ConstantCache(
306-
tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff
307-
)
308-
end
309-
310-
function alg_cache(
311-
alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
312-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
313-
dt, reltol, p, calck,
314-
::Val{false}, verbose
315-
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
316-
tf = TimeDerivativeWrapper(f, u, p)
317-
uf = UDerivativeWrapper(f, t, p)
318-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false))
319-
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
320-
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
321-
return Rosenbrock32ConstantCache(
322-
constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve,
323-
alg_autodiff(alg)
324-
)
325-
end
326-
32765
### Rodas4+ methods and consolidated Rosenbrock methods (using RodasTableau)
32866

32967
# Helper accessors for step_limiter!/stage_limiter! — algorithms that have these fields
@@ -376,6 +114,8 @@ tabtype(::GRK4T) = GRK4TRodasTableau
376114
tabtype(::GRK4A) = GRK4ARodasTableau
377115
tabtype(::Ros4LStab) = Ros4LStabRodasTableau
378116
tabtype(::RosenbrockW6S4OS) = RosenbrockW6S4OSRodasTableau
117+
tabtype(::Rosenbrock23) = Rosenbrock23RodasTableau
118+
tabtype(::Rosenbrock32) = Rosenbrock32RodasTableau
379119

380120
# Union of all algorithms using RodasTableau-based RosenbrockCache
381121
const RodasTableauAlgorithms = Union{
@@ -387,6 +127,7 @@ const RodasTableauAlgorithms = Union{
387127
ROS34PRw, ROS3PRL, ROS3PRL2, ROK4a,
388128
RosShamp4, Veldd4, Velds4, GRK4T, GRK4A, Ros4LStab,
389129
RosenbrockW6S4OS,
130+
Rosenbrock23, Rosenbrock32,
390131
}
391132

392133
function alg_cache(
@@ -493,10 +234,7 @@ function alg_cache(
493234
end
494235

495236
function get_fsalfirstlast(
496-
cache::Union{
497-
Rosenbrock23Cache, Rosenbrock32Cache,
498-
RosenbrockCache,
499-
},
237+
cache::RosenbrockCache,
500238
u
501239
)
502240
return (cache.fsalfirst, cache.fsallast)

0 commit comments

Comments
 (0)