Skip to content

Commit e7881f7

Browse files
Harsh Singhclaude
authored andcommitted
refactor(Rosenbrock): migrate Rosenbrock23/32 to generic RodasTableau framework
Unify Rosenbrock23 and Rosenbrock32 with the existing generic Rosenbrock/Rodas infrastructure by expressing their coefficients as RodasTableau entries. This removes ~900 lines of duplicated cache structs, perform_step!, interpolation, and addsteps code. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0d38107 commit e7881f7

File tree

6 files changed

+89
-934
lines changed

6 files changed

+89
-934
lines changed

lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,3 @@
1-
function SciMLBase.interp_summary(
2-
::Type{cacheType},
3-
dense::Bool
4-
) where {
5-
cacheType <:
6-
Union{
7-
Rosenbrock23ConstantCache,
8-
Rosenbrock32ConstantCache,
9-
Rosenbrock23Cache,
10-
Rosenbrock32Cache,
11-
},
12-
}
13-
return dense ? "specialized 2nd order \"free\" stiffness-aware interpolation" :
14-
"1st order linear"
15-
end
161
function SciMLBase.interp_summary(
172
::Type{cacheType},
183
dense::Bool

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 4 additions & 266 deletions
Original file line numberDiff line numberDiff line change
@@ -64,268 +64,6 @@ struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <:
6464
interp_order::Int
6565
end
6666

67-
@cache mutable struct Rosenbrock23Cache{
68-
uType, rateType, uNoUnitsType, JType, WType,
69-
TabType, TFType, UFType, F, JCType, GCType,
70-
RTolType, A, AV, StepLimiter, StageLimiter,
71-
} <: RosenbrockMutableCache
72-
u::uType
73-
uprev::uType
74-
k₁::rateType
75-
k₂::rateType
76-
k₃::rateType
77-
du1::rateType
78-
du2::rateType
79-
f₁::rateType
80-
fsalfirst::rateType
81-
fsallast::rateType
82-
dT::rateType
83-
J::JType
84-
W::WType
85-
tmp::rateType
86-
atmp::uNoUnitsType
87-
weight::uNoUnitsType
88-
tab::TabType
89-
tf::TFType
90-
uf::UFType
91-
linsolve_tmp::rateType
92-
linsolve::F
93-
jac_config::JCType
94-
grad_config::GCType
95-
reltol::RTolType
96-
alg::A
97-
algebraic_vars::AV
98-
step_limiter!::StepLimiter
99-
stage_limiter!::StageLimiter
100-
end
101-
102-
@cache mutable struct Rosenbrock32Cache{
103-
uType, rateType, uNoUnitsType, JType, WType,
104-
TabType, TFType, UFType, F, JCType, GCType,
105-
RTolType, A, AV, StepLimiter, StageLimiter,
106-
} <: RosenbrockMutableCache
107-
u::uType
108-
uprev::uType
109-
k₁::rateType
110-
k₂::rateType
111-
k₃::rateType
112-
du1::rateType
113-
du2::rateType
114-
f₁::rateType
115-
fsalfirst::rateType
116-
fsallast::rateType
117-
dT::rateType
118-
J::JType
119-
W::WType
120-
tmp::rateType
121-
atmp::uNoUnitsType
122-
weight::uNoUnitsType
123-
tab::TabType
124-
tf::TFType
125-
uf::UFType
126-
linsolve_tmp::rateType
127-
linsolve::F
128-
jac_config::JCType
129-
grad_config::GCType
130-
reltol::RTolType
131-
alg::A
132-
algebraic_vars::AV
133-
step_limiter!::StepLimiter
134-
stage_limiter!::StageLimiter
135-
end
136-
137-
function alg_cache(
138-
alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
139-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
140-
dt, reltol, p, calck,
141-
::Val{true}, verbose
142-
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
143-
k₁ = zero(rate_prototype)
144-
k₂ = zero(rate_prototype)
145-
k₃ = zero(rate_prototype)
146-
du1 = zero(rate_prototype)
147-
du2 = zero(rate_prototype)
148-
# f₀ = zero(u) fsalfirst
149-
f₁ = zero(rate_prototype)
150-
fsalfirst = zero(rate_prototype)
151-
fsallast = zero(rate_prototype)
152-
dT = zero(rate_prototype)
153-
tmp = zero(rate_prototype)
154-
atmp = similar(u, uEltypeNoUnits)
155-
recursivefill!(atmp, false)
156-
weight = similar(u, uEltypeNoUnits)
157-
recursivefill!(weight, false)
158-
tab = Rosenbrock23Tableau(constvalue(uBottomEltypeNoUnits))
159-
tf = TimeGradientWrapper(f, uprev, p)
160-
uf = UJacobianWrapper(f, t, p)
161-
linsolve_tmp = zero(rate_prototype)
162-
163-
grad_config = build_grad_config(alg, f, tf, du1, t)
164-
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
165-
166-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
167-
168-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
169-
Pl,
170-
Pr = wrapprecs(
171-
alg.precs(
172-
W, nothing, u, p, t, nothing, nothing, nothing,
173-
nothing
174-
)..., weight, tmp
175-
)
176-
linsolve = init(
177-
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
178-
Pl = Pl, Pr = Pr,
179-
assumptions = LinearSolve.OperatorAssumptions(true),
180-
verbose = verbose.linear_verbosity
181-
)
182-
183-
algebraic_vars = f.mass_matrix === I ? nothing :
184-
[all(iszero, x) for x in eachcol(f.mass_matrix)]
185-
186-
return Rosenbrock23Cache(
187-
u, uprev, k₁, k₂, k₃, du1, du2, f₁,
188-
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
189-
linsolve_tmp,
190-
linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!,
191-
alg.stage_limiter!
192-
)
193-
end
194-
195-
function alg_cache(
196-
alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
197-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
198-
dt, reltol, p, calck,
199-
::Val{true}, verbose
200-
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
201-
k₁ = zero(rate_prototype)
202-
k₂ = zero(rate_prototype)
203-
k₃ = zero(rate_prototype)
204-
du1 = zero(rate_prototype)
205-
du2 = zero(rate_prototype)
206-
# f₀ = zero(u) fsalfirst
207-
f₁ = zero(rate_prototype)
208-
fsalfirst = zero(rate_prototype)
209-
fsallast = zero(rate_prototype)
210-
dT = zero(rate_prototype)
211-
tmp = zero(rate_prototype)
212-
atmp = similar(u, uEltypeNoUnits)
213-
recursivefill!(atmp, false)
214-
weight = similar(u, uEltypeNoUnits)
215-
recursivefill!(weight, false)
216-
tab = Rosenbrock32Tableau(constvalue(uBottomEltypeNoUnits))
217-
218-
tf = TimeGradientWrapper(f, uprev, p)
219-
uf = UJacobianWrapper(f, t, p)
220-
linsolve_tmp = zero(rate_prototype)
221-
222-
grad_config = build_grad_config(alg, f, tf, du1, t)
223-
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
224-
225-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
226-
227-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
228-
229-
Pl,
230-
Pr = wrapprecs(
231-
alg.precs(
232-
W, nothing, u, p, t, nothing, nothing, nothing,
233-
nothing
234-
)..., weight, tmp
235-
)
236-
linsolve = init(
237-
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
238-
Pl = Pl, Pr = Pr,
239-
assumptions = LinearSolve.OperatorAssumptions(true),
240-
verbose = verbose.linear_verbosity
241-
)
242-
243-
algebraic_vars = f.mass_matrix === I ? nothing :
244-
[all(iszero, x) for x in eachcol(f.mass_matrix)]
245-
246-
return Rosenbrock32Cache(
247-
u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W,
248-
tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config,
249-
grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, alg.stage_limiter!
250-
)
251-
end
252-
253-
struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} <:
254-
RosenbrockConstantCache
255-
c₃₂::T
256-
d::T
257-
tf::TF
258-
uf::UF
259-
J::JType
260-
W::WType
261-
linsolve::F
262-
autodiff::AD
263-
end
264-
265-
function Rosenbrock23ConstantCache(
266-
::Type{T}, tf, uf, J, W, linsolve, autodiff
267-
) where {T}
268-
tab = Rosenbrock23Tableau(T)
269-
return Rosenbrock23ConstantCache(
270-
tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff
271-
)
272-
end
273-
274-
function alg_cache(
275-
alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
276-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
277-
dt, reltol, p, calck,
278-
::Val{false}, verbose
279-
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
280-
tf = TimeDerivativeWrapper(f, u, p)
281-
uf = UDerivativeWrapper(f, t, p)
282-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false))
283-
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
284-
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
285-
return Rosenbrock23ConstantCache(
286-
constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve,
287-
alg_autodiff(alg)
288-
)
289-
end
290-
291-
struct Rosenbrock32ConstantCache{T, TF, UF, JType, WType, F, AD} <:
292-
RosenbrockConstantCache
293-
c₃₂::T
294-
d::T
295-
tf::TF
296-
uf::UF
297-
J::JType
298-
W::WType
299-
linsolve::F
300-
autodiff::AD
301-
end
302-
303-
function Rosenbrock32ConstantCache(
304-
::Type{T}, tf, uf, J, W, linsolve, autodiff
305-
) where {T}
306-
tab = Rosenbrock32Tableau(T)
307-
return Rosenbrock32ConstantCache(
308-
tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff
309-
)
310-
end
311-
312-
function alg_cache(
313-
alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
314-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
315-
dt, reltol, p, calck,
316-
::Val{false}, verbose
317-
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
318-
tf = TimeDerivativeWrapper(f, u, p)
319-
uf = UDerivativeWrapper(f, t, p)
320-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false))
321-
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
322-
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
323-
return Rosenbrock32ConstantCache(
324-
constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve,
325-
alg_autodiff(alg)
326-
)
327-
end
328-
32967
### Rodas4+ methods and consolidated Rosenbrock methods (using RodasTableau)
33068

33169
# Helper accessors for step_limiter!/stage_limiter! — algorithms that have these fields
@@ -378,6 +116,8 @@ tabtype(::GRK4T) = GRK4TRodasTableau
378116
tabtype(::GRK4A) = GRK4ARodasTableau
379117
tabtype(::Ros4LStab) = Ros4LStabRodasTableau
380118
tabtype(::RosenbrockW6S4OS) = RosenbrockW6S4OSRodasTableau
119+
tabtype(::Rosenbrock23) = Rosenbrock23RodasTableau
120+
tabtype(::Rosenbrock32) = Rosenbrock32RodasTableau
381121

382122
# Union of all algorithms using RodasTableau-based RosenbrockCache
383123
const RodasTableauAlgorithms = Union{
@@ -389,6 +129,7 @@ const RodasTableauAlgorithms = Union{
389129
ROS34PRw, ROS3PRL, ROS3PRL2, ROK4a,
390130
RosShamp4, Veldd4, Velds4, GRK4T, GRK4A, Ros4LStab,
391131
RosenbrockW6S4OS,
132+
Rosenbrock23, Rosenbrock32,
392133
}
393134

394135
function alg_cache(
@@ -499,10 +240,7 @@ function alg_cache(
499240
end
500241

501242
function get_fsalfirstlast(
502-
cache::Union{
503-
Rosenbrock23Cache, Rosenbrock32Cache,
504-
RosenbrockCache,
505-
},
243+
cache::RosenbrockCache,
506244
u
507245
)
508246
return (cache.fsalfirst, cache.fsallast)

0 commit comments

Comments
 (0)