Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,3 @@
function SciMLBase.interp_summary(
::Type{cacheType},
dense::Bool
) where {
cacheType <:
Union{
Rosenbrock23ConstantCache,
Rosenbrock32ConstantCache,
Rosenbrock23Cache,
Rosenbrock32Cache,
},
}
return dense ? "specialized 2nd order \"free\" stiffness-aware interpolation" :
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why removed?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

those cache types (Rosenbrock23Cache, Rosenbrock32Cache, Rosenbrock23ConstantCache, Rosenbrock32ConstantCache) no longer exist — they've been replaced by the generic RosenbrockCache and RosenbrockConstantCache. So this method would never dispatch. The generic Rosenbrock interp_summary already covers RosenbrockCache/RosenbrockConstantCache.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But then lower, the interpolation summary needs to get improved so it's based on the cache type's information in order to be correct? It should probably read the k-length in order to determine the order to share it, or have some other way to know it.

"1st order linear"
end
function SciMLBase.interp_summary(
::Type{cacheType},
dense::Bool
Expand Down Expand Up @@ -42,3 +27,31 @@ function SciMLBase.interp_summary(
"specialized 4th (Rodas6P = 5th) order \"free\" stiffness-aware interpolation" :
"1st order linear"
end

# strip_cache for RosenbrockCache: the generic OrdinaryDiffEqCore version passes
# all-Nothing args, but several fields (dense, dtC, dtd, ks, interp_order, jac_reuse) have
# concrete types that don't accept Nothing. Provide a custom override that clears
# only the interpolation-related fields to zero-length arrays / nothing.
function OrdinaryDiffEqCore.strip_cache(cache::RosenbrockCache)
return RosenbrockCache(
cache.u, cache.uprev,
similar(cache.dense, 0), # dense::Vector{rateType} — must not be Nothing
cache.du, cache.du1, cache.du2,
similar(cache.dtC, 0, 0), # dtC::Matrix{tabType} — must not be Nothing
similar(cache.dtd, 0), # dtd::Vector{tabType} — must not be Nothing
similar(cache.ks, 0), # ks::Vector{rateType} — must not be Nothing
cache.fsalfirst, cache.fsallast, cache.dT,
cache.J, cache.W,
cache.tmp, cache.atmp, cache.weight,
cache.tab,
nothing, # tf
nothing, # uf
cache.linsolve_tmp, cache.linsolve,
nothing, # jac_config
nothing, # grad_config
cache.reltol, cache.alg,
cache.step_limiter!, cache.stage_limiter!,
cache.interp_order, # interp_order::Int — must not be Nothing
cache.jac_reuse # jac_reuse — preserve state across strip
)
end
279 changes: 4 additions & 275 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,277 +115,6 @@ struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD, JRType}
jac_reuse::JRType
end

@cache mutable struct Rosenbrock23Cache{
uType, rateType, uNoUnitsType, JType, WType,
TabType, TFType, UFType, F, JCType, GCType,
RTolType, A, AV, StepLimiter, StageLimiter, JRType,
} <: RosenbrockMutableCache
u::uType
uprev::uType
k₁::rateType
k₂::rateType
k₃::rateType
du1::rateType
du2::rateType
f₁::rateType
fsalfirst::rateType
fsallast::rateType
dT::rateType
J::JType
W::WType
tmp::rateType
atmp::uNoUnitsType
weight::uNoUnitsType
tab::TabType
tf::TFType
uf::UFType
linsolve_tmp::rateType
linsolve::F
jac_config::JCType
grad_config::GCType
reltol::RTolType
alg::A
algebraic_vars::AV
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
jac_reuse::JRType
end

@cache mutable struct Rosenbrock32Cache{
uType, rateType, uNoUnitsType, JType, WType,
TabType, TFType, UFType, F, JCType, GCType,
RTolType, A, AV, StepLimiter, StageLimiter, JRType,
} <: RosenbrockMutableCache
u::uType
uprev::uType
k₁::rateType
k₂::rateType
k₃::rateType
du1::rateType
du2::rateType
f₁::rateType
fsalfirst::rateType
fsallast::rateType
dT::rateType
J::JType
W::WType
tmp::rateType
atmp::uNoUnitsType
weight::uNoUnitsType
tab::TabType
tf::TFType
uf::UFType
linsolve_tmp::rateType
linsolve::F
jac_config::JCType
grad_config::GCType
reltol::RTolType
alg::A
algebraic_vars::AV
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
jac_reuse::JRType
end

function alg_cache(
alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}, verbose
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
k₁ = zero(rate_prototype)
k₂ = zero(rate_prototype)
k₃ = zero(rate_prototype)
du1 = zero(rate_prototype)
du2 = zero(rate_prototype)
# f₀ = zero(u) fsalfirst
f₁ = zero(rate_prototype)
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
weight = similar(u, uEltypeNoUnits)
recursivefill!(weight, false)
tab = Rosenbrock23Tableau(constvalue(uBottomEltypeNoUnits))
tf = TimeGradientWrapper(f, uprev, p)
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)

grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)

J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))

linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl,
Pr = wrapprecs(
alg.precs(
W, nothing, u, p, t, nothing, nothing, nothing,
nothing
)..., weight, tmp
)
linsolve = init(
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
Pl = Pl, Pr = Pr,
abstol = reltol, reltol = reltol,
assumptions = LinearSolve.OperatorAssumptions(true),
verbose = verbose.linear_verbosity
)

algebraic_vars = f.mass_matrix === I ? nothing :
[all(iszero, x) for x in eachcol(f.mass_matrix)]

return Rosenbrock23Cache(
u, uprev, k₁, k₂, k₃, du1, du2, f₁,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!,
alg.stage_limiter!, _make_jac_reuse_state(zero(dt), alg.max_jac_age)
)
end

function alg_cache(
alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}, verbose
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
k₁ = zero(rate_prototype)
k₂ = zero(rate_prototype)
k₃ = zero(rate_prototype)
du1 = zero(rate_prototype)
du2 = zero(rate_prototype)
# f₀ = zero(u) fsalfirst
f₁ = zero(rate_prototype)
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
weight = similar(u, uEltypeNoUnits)
recursivefill!(weight, false)
tab = Rosenbrock32Tableau(constvalue(uBottomEltypeNoUnits))

tf = TimeGradientWrapper(f, uprev, p)
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)

grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)

J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))

linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))

Pl,
Pr = wrapprecs(
alg.precs(
W, nothing, u, p, t, nothing, nothing, nothing,
nothing
)..., weight, tmp
)
linsolve = init(
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
Pl = Pl, Pr = Pr,
abstol = reltol, reltol = reltol,
assumptions = LinearSolve.OperatorAssumptions(true),
verbose = verbose.linear_verbosity
)

algebraic_vars = f.mass_matrix === I ? nothing :
[all(iszero, x) for x in eachcol(f.mass_matrix)]

return Rosenbrock32Cache(
u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W,
tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config,
grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, alg.stage_limiter!,
_make_jac_reuse_state(zero(dt), alg.max_jac_age)
)
end

struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD, JRType} <:
RosenbrockConstantCache
c₃₂::T
d::T
tf::TF
uf::UF
J::JType
W::WType
linsolve::F
autodiff::AD
jac_reuse::JRType
end

function Rosenbrock23ConstantCache(
::Type{T}, tf, uf, J, W, linsolve, autodiff, max_jac_age::Int = 20
) where {T}
tab = Rosenbrock23Tableau(T)
return Rosenbrock23ConstantCache(
tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff,
_make_jac_reuse_state(zero(T), max_jac_age)
)
end

function alg_cache(
alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}, verbose
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tf = TimeDerivativeWrapper(f, u, p)
uf = UDerivativeWrapper(f, t, p)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false))
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
return Rosenbrock23ConstantCache(
constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve,
alg_autodiff(alg), alg.max_jac_age
)
end

struct Rosenbrock32ConstantCache{T, TF, UF, JType, WType, F, AD, JRType} <:
RosenbrockConstantCache
c₃₂::T
d::T
tf::TF
uf::UF
J::JType
W::WType
linsolve::F
autodiff::AD
jac_reuse::JRType
end

function Rosenbrock32ConstantCache(
::Type{T}, tf, uf, J, W, linsolve, autodiff, max_jac_age::Int = 20
) where {T}
tab = Rosenbrock32Tableau(T)
return Rosenbrock32ConstantCache(
tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff,
_make_jac_reuse_state(zero(T), max_jac_age)
)
end

function alg_cache(
alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}, verbose
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tf = TimeDerivativeWrapper(f, u, p)
uf = UDerivativeWrapper(f, t, p)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false))
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
return Rosenbrock32ConstantCache(
constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve,
alg_autodiff(alg), alg.max_jac_age
)
end

### Rodas4+ methods and consolidated Rosenbrock methods (using RodasTableau)

# Helper accessors for step_limiter!/stage_limiter! — algorithms that have these fields
Expand Down Expand Up @@ -438,6 +167,8 @@ tabtype(::GRK4T) = GRK4TRodasTableau
tabtype(::GRK4A) = GRK4ARodasTableau
tabtype(::Ros4LStab) = Ros4LStabRodasTableau
tabtype(::RosenbrockW6S4OS) = RosenbrockW6S4OSRodasTableau
tabtype(::Rosenbrock23) = Rosenbrock23RodasTableau
tabtype(::Rosenbrock32) = Rosenbrock32RodasTableau

# Union of all algorithms using RodasTableau-based RosenbrockCache
const RodasTableauAlgorithms = Union{
Expand All @@ -449,6 +180,7 @@ const RodasTableauAlgorithms = Union{
ROS34PRw, ROS3PRL, ROS3PRL2, ROK4a,
RosShamp4, Veldd4, Velds4, GRK4T, GRK4A, Ros4LStab,
RosenbrockW6S4OS,
Rosenbrock23, Rosenbrock32,
}

function alg_cache(
Expand Down Expand Up @@ -562,10 +294,7 @@ function alg_cache(
end

function get_fsalfirstlast(
cache::Union{
Rosenbrock23Cache, Rosenbrock32Cache,
RosenbrockCache,
},
cache::RosenbrockCache,
u
)
return (cache.fsalfirst, cache.fsallast)
Expand Down
Loading
Loading