diff --git a/lib/OrdinaryDiffEqCore/src/algorithms.jl b/lib/OrdinaryDiffEqCore/src/algorithms.jl index 18c9fd25aa0..1313c0a45bb 100644 --- a/lib/OrdinaryDiffEqCore/src/algorithms.jl +++ b/lib/OrdinaryDiffEqCore/src/algorithms.jl @@ -20,6 +20,8 @@ abstract type OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD, FDT, ST, CJ} <: OrdinaryDiffEqAdaptiveAlgorithm end abstract type OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} <: OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD, FDT, ST, CJ} end +abstract type OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} <: +OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} end abstract type OrdinaryDiffEqRosenbrockAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} <: OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD, FDT, ST, CJ} end diff --git a/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl b/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl index 37eb64be2c2..45f6a03e2bb 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl @@ -6,6 +6,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, OrdinaryDiffEqAlgorithm, OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache, OrdinaryDiffEqNewtonAdaptiveAlgorithm, + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm, OrdinaryDiffEqNewtonAlgorithm, DEFAULT_PRECS, OrdinaryDiffEqAdaptiveAlgorithm, CompiledFloats, uses_uprev, @@ -37,13 +38,16 @@ include("kencarp_kvaerno_caches.jl") include("sdirk_perform_step.jl") include("kencarp_kvaerno_perform_step.jl") include("sdirk_tableaus.jl") +include("imex_tableaus.jl") +include("generic_imex_perform_step.jl") export ImplicitEuler, ImplicitMidpoint, Trapezoid, TRBDF2, SDIRK2, SDIRK22, Kvaerno3, KenCarp3, Cash4, Hairer4, Hairer42, SSPSDIRK2, Kvaerno4, Kvaerno5, KenCarp4, KenCarp47, KenCarp5, KenCarp58, ESDIRK54I8L2SA, SFSDIRK4, SFSDIRK5, CFNLIRK3, SFSDIRK6, SFSDIRK7, SFSDIRK8, Kvaerno5, KenCarp4, KenCarp5, SFSDIRK4, SFSDIRK5, CFNLIRK3, SFSDIRK6, - SFSDIRK7, SFSDIRK8, ESDIRK436L2SA2, ESDIRK437L2SA, ESDIRK547L2SA2, ESDIRK659L2SA + SFSDIRK7, SFSDIRK8, ESDIRK436L2SA2, ESDIRK437L2SA, ESDIRK547L2SA2, ESDIRK659L2SA, + ARS343 import PrecompileTools import Preferences diff --git a/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl b/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl index c6903a76dee..46b7206950b 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl @@ -57,3 +57,6 @@ issplit(alg::KenCarp47) = true issplit(alg::KenCarp5) = true issplit(alg::KenCarp58) = true issplit(alg::CFNLIRK3) = true +issplit(alg::ARS343) = true +alg_order(alg::ARS343) = 3 +isesdirk(alg::ARS343) = true diff --git a/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl b/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl index 56d80a25fb2..acee6e478a5 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl @@ -79,6 +79,7 @@ function SDIRK_docstring( ) end + @doc SDIRK_docstring( "A 1st order implicit solver. A-B-L-stable. Adaptive timestepping through a divided differences estimate. Strong-stability preserving (SSP). Good for highly stiff equations.", "ImplicitEuler"; @@ -486,7 +487,7 @@ end """ ) struct Kvaerno3{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: - OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P @@ -538,7 +539,7 @@ end """ ) struct KenCarp3{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: - OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P @@ -1045,7 +1046,7 @@ end """ ) struct Kvaerno4{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: - OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P @@ -1101,7 +1102,7 @@ end """ ) struct Kvaerno5{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: - OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P @@ -1153,7 +1154,7 @@ end """ ) struct KenCarp4{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: - OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P @@ -1207,14 +1208,15 @@ end controller = :PI, """ ) -struct KenCarp47{CS, AD, F, F2, P, FDT, ST, CJ} <: - OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} +struct KenCarp47{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P smooth_est::Bool extrapolant::Symbol controller::Symbol + step_limiter!::StepLimiter autodiff::AD end function KenCarp47(; @@ -1223,17 +1225,17 @@ function KenCarp47(; diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), smooth_est = true, extrapolant = :linear, - controller = :PI + controller = :PI, step_limiter! = trivial_limiter! ) AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) return KenCarp47{ _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), - _unwrap_val(concrete_jac), + _unwrap_val(concrete_jac), typeof(step_limiter!), }( linsolve, nlsolve, precs, smooth_est, extrapolant, - controller, AD_choice + controller, step_limiter!, AD_choice ) end @@ -1259,7 +1261,7 @@ end """ ) struct KenCarp5{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: - OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P @@ -1311,14 +1313,15 @@ end controller = :PI, """ ) -struct KenCarp58{CS, AD, F, F2, P, FDT, ST, CJ} <: - OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} +struct KenCarp58{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P smooth_est::Bool extrapolant::Symbol controller::Symbol + step_limiter!::StepLimiter autodiff::AD end function KenCarp58(; @@ -1327,17 +1330,17 @@ function KenCarp58(; diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), smooth_est = true, extrapolant = :linear, - controller = :PI + controller = :PI, step_limiter! = trivial_limiter! ) AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) return KenCarp58{ _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), - _unwrap_val(concrete_jac), + _unwrap_val(concrete_jac), typeof(step_limiter!), }( linsolve, nlsolve, precs, smooth_est, extrapolant, - controller, AD_choice + controller, step_limiter!, AD_choice ) end @@ -1364,13 +1367,15 @@ but are still being fully evaluated in context.", controller = :PI, """ ) -struct ESDIRK54I8L2SA{CS, AD, F, F2, P, FDT, ST, CJ} <: - OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} +struct ESDIRK54I8L2SA{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P + smooth_est::Bool extrapolant::Symbol controller::Symbol + step_limiter!::StepLimiter autodiff::AD end function ESDIRK54I8L2SA(; @@ -1378,17 +1383,18 @@ function ESDIRK54I8L2SA(; standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), - extrapolant = :linear, controller = :PI + smooth_est = true, extrapolant = :linear, + controller = :PI, step_limiter! = trivial_limiter! ) AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) return ESDIRK54I8L2SA{ _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), - _unwrap_val(concrete_jac), + _unwrap_val(concrete_jac), typeof(step_limiter!), }( - linsolve, nlsolve, precs, extrapolant, - controller, AD_choice + linsolve, nlsolve, precs, smooth_est, extrapolant, + controller, step_limiter!, AD_choice ) end @@ -1414,13 +1420,15 @@ but are still being fully evaluated in context.", controller = :PI, """ ) -struct ESDIRK436L2SA2{CS, AD, F, F2, P, FDT, ST, CJ} <: - OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} +struct ESDIRK436L2SA2{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P + smooth_est::Bool extrapolant::Symbol controller::Symbol + step_limiter!::StepLimiter autodiff::AD end function ESDIRK436L2SA2(; @@ -1428,17 +1436,18 @@ function ESDIRK436L2SA2(; standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), - extrapolant = :linear, controller = :PI + smooth_est = true, extrapolant = :linear, + controller = :PI, step_limiter! = trivial_limiter! ) AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) return ESDIRK436L2SA2{ _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), - _unwrap_val(concrete_jac), + _unwrap_val(concrete_jac), typeof(step_limiter!), }( - linsolve, nlsolve, precs, extrapolant, - controller, AD_choice + linsolve, nlsolve, precs, smooth_est, extrapolant, + controller, step_limiter!, AD_choice ) end @@ -1464,13 +1473,15 @@ but are still being fully evaluated in context.", controller = :PI, """ ) -struct ESDIRK437L2SA{CS, AD, F, F2, P, FDT, ST, CJ} <: - OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} +struct ESDIRK437L2SA{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P + smooth_est::Bool extrapolant::Symbol controller::Symbol + step_limiter!::StepLimiter autodiff::AD end function ESDIRK437L2SA(; @@ -1478,17 +1489,18 @@ function ESDIRK437L2SA(; standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), - extrapolant = :linear, controller = :PI + smooth_est = true, extrapolant = :linear, + controller = :PI, step_limiter! = trivial_limiter! ) AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) return ESDIRK437L2SA{ _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), - _unwrap_val(concrete_jac), + _unwrap_val(concrete_jac), typeof(step_limiter!), }( - linsolve, nlsolve, precs, extrapolant, - controller, AD_choice + linsolve, nlsolve, precs, smooth_est, extrapolant, + controller, step_limiter!, AD_choice ) end @@ -1514,13 +1526,15 @@ but are still being fully evaluated in context.", controller = :PI, """ ) -struct ESDIRK547L2SA2{CS, AD, F, F2, P, FDT, ST, CJ} <: - OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} +struct ESDIRK547L2SA2{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P + smooth_est::Bool extrapolant::Symbol controller::Symbol + step_limiter!::StepLimiter autodiff::AD end function ESDIRK547L2SA2(; @@ -1528,17 +1542,18 @@ function ESDIRK547L2SA2(; standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), - extrapolant = :linear, controller = :PI + smooth_est = true, extrapolant = :linear, + controller = :PI, step_limiter! = trivial_limiter! ) AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) return ESDIRK547L2SA2{ _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), - _unwrap_val(concrete_jac), + _unwrap_val(concrete_jac), typeof(step_limiter!), }( - linsolve, nlsolve, precs, extrapolant, - controller, AD_choice + linsolve, nlsolve, precs, smooth_est, extrapolant, + controller, step_limiter!, AD_choice ) end @@ -1566,13 +1581,15 @@ Check issue https://github.com/SciML/OrdinaryDiffEq.jl/issues/1933 for more deta controller = :PI, """ ) -struct ESDIRK659L2SA{CS, AD, F, F2, P, FDT, ST, CJ} <: - OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ} +struct ESDIRK659L2SA{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P + smooth_est::Bool extrapolant::Symbol controller::Symbol + step_limiter!::StepLimiter autodiff::AD end function ESDIRK659L2SA(; @@ -1580,16 +1597,73 @@ function ESDIRK659L2SA(; standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), - extrapolant = :linear, controller = :PI + smooth_est = true, extrapolant = :linear, + controller = :PI, step_limiter! = trivial_limiter! ) AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) return ESDIRK659L2SA{ _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), - _unwrap_val(concrete_jac), + _unwrap_val(concrete_jac), typeof(step_limiter!), }( - linsolve, nlsolve, precs, extrapolant, - controller, AD_choice + linsolve, nlsolve, precs, smooth_est, extrapolant, + controller, step_limiter!, AD_choice + ) +end + +@doc SDIRK_docstring( + "3rd order L-stable IMEX ARK method. Uses a generic tableau-driven implementation that supports both split and non-split forms.", + "ARS343"; + references = "@article{ascher1997implicit, + title={Implicit-explicit Runge-Kutta methods for time-dependent partial differential equations}, + author={Ascher, Uri M and Ruuth, Steven J and Spiteri, Raymond J}, + journal={Applied Numerical Mathematics}, + volume={25}, + number={2-3}, + pages={151--167}, + year={1997}, + publisher={Elsevier}}", + extra_keyword_description = """ + - `smooth_est`: TBD + - `extrapolant`: TBD + - `controller`: TBD + - `step_limiter!`: function of the form `limiter!(u, integrator, p, t)` + """, + extra_keyword_default = """ + smooth_est = true, + extrapolant = :linear, + controller = :PI, + step_limiter! = trivial_limiter!, + """ +) +struct ARS343{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm{CS, AD, FDT, ST, CJ} + linsolve::F + nlsolve::F2 + precs::P + smooth_est::Bool + extrapolant::Symbol + controller::Symbol + step_limiter!::StepLimiter + autodiff::AD +end +function ARS343(; + chunk_size = Val{0}(), autodiff = AutoForwardDiff(), + standardtag = Val{true}(), concrete_jac = nothing, + diff_type = Val{:forward}(), + linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), + smooth_est = true, extrapolant = :linear, + controller = :PI, step_limiter! = trivial_limiter! + ) + AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) + + return ARS343{ + _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), + typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), + _unwrap_val(concrete_jac), typeof(step_limiter!), + }( + linsolve, nlsolve, precs, + smooth_est, extrapolant, controller, step_limiter!, AD_choice ) end diff --git a/lib/OrdinaryDiffEqSDIRK/src/generic_imex_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/generic_imex_perform_step.jl new file mode 100644 index 00000000000..ec4f8267aeb --- /dev/null +++ b/lib/OrdinaryDiffEqSDIRK/src/generic_imex_perform_step.jl @@ -0,0 +1,326 @@ +mutable struct ESDIRKIMEXConstantCache{Tab, N} <: OrdinaryDiffEqConstantCache + nlsolver::N + tab::Tab +end + +mutable struct ESDIRKIMEXCache{uType, rateType, uNoUnitsType, N, Tab, kType, StepLimiter} <: + SDIRKMutableCache + u::uType + uprev::uType + fsalfirst::rateType + zs::Vector{uType} + ks::Vector{kType} + atmp::uNoUnitsType + nlsolver::N + tab::Tab + step_limiter!::StepLimiter +end + +function full_cache(c::ESDIRKIMEXCache) + base = (c.u, c.uprev, c.fsalfirst, c.zs..., c.atmp) + if eltype(c.ks) !== Nothing + return tuple(base..., c.ks...) + end + return base +end + +function alg_cache( + alg::OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{false}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = ESDIRKIMEXTableau(alg, constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ = tab.Ai[2, 2] + c = tab.c[2] + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose + ) + return ESDIRKIMEXConstantCache(nlsolver, tab) +end + +function alg_cache( + alg::OrdinaryDiffEqNewtonAdaptiveESDIRKAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, + ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{true}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = ESDIRKIMEXTableau(alg, constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ = tab.Ai[2, 2] + c = tab.c[2] + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose + ) + fsalfirst = zero(rate_prototype) + + s = tab.s + if f isa SplitFunction + ks = [zero(u) for _ in 1:s] + else + ks = Vector{Nothing}(nothing, s) + end + + zs = [zero(u) for _ in 1:(s - 1)] + push!(zs, nlsolver.z) + atmp = similar(u, uEltypeNoUnits) + recursivefill!(atmp, false) + + return ESDIRKIMEXCache( + u, uprev, fsalfirst, zs, ks, atmp, nlsolver, tab, alg.step_limiter! + ) +end + +function initialize!(integrator, cache::ESDIRKIMEXConstantCache) + integrator.kshortsize = 2 + integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) + integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + integrator.fsallast = zero(integrator.fsalfirst) + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + return nothing +end + +function initialize!(integrator, cache::ESDIRKIMEXCache) + integrator.kshortsize = 2 + resize!(integrator.k, integrator.kshortsize) + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + return nothing +end + +@muladd function perform_step!( + integrator, cache::ESDIRKIMEXConstantCache, repeat_step = false + ) + (; t, dt, uprev, u, p) = integrator + nlsolver = cache.nlsolver + tab = cache.tab + (; Ai, bi, Ae, be, c, btilde, ebtilde, α, s) = tab + alg = unwrap_alg(integrator, true) + γ = Ai[2, 2] + + f2 = nothing + k = Vector{typeof(u)}(undef, s) + if integrator.f isa SplitFunction + f_impl = integrator.f.f1 + f2 = integrator.f.f2 + else + f_impl = integrator.f + end + + markfirststage!(nlsolver) + + z = Vector{typeof(u)}(undef, s) + + if integrator.f isa SplitFunction + z[1] = dt * f_impl(uprev, p, t) + else + z[1] = dt * integrator.fsalfirst + end + + if integrator.f isa SplitFunction + k[1] = dt * integrator.fsalfirst - z[1] + end + + for i in 2:s + tmp = uprev + for j in 1:(i - 1) + tmp = tmp + Ai[i, j] * z[j] + end + + if integrator.f isa SplitFunction + for j in 1:(i - 1) + tmp = tmp + Ae[i, j] * k[j] + end + end + + if integrator.f isa SplitFunction + z_guess = z[1] + elseif α !== nothing && !iszero(α[i, 1]) + z_guess = zero(u) + for j in 1:(i - 1) + z_guess = z_guess + α[i, j] * z[j] + end + else + z_guess = zero(u) + end + + nlsolver.z = z_guess + nlsolver.tmp = tmp + nlsolver.c = c[i] + nlsolver.γ = γ + z[i] = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction && i < s + u_stage = tmp + γ * z[i] + k[i] = dt * f2(u_stage, p, t + c[i] * dt) + integrator.stats.nf2 += 1 + end + end + + u = nlsolver.tmp + γ * z[s] + if integrator.f isa SplitFunction + k[s] = dt * f2(u, p, t + dt) + integrator.stats.nf2 += 1 + u = uprev + for i in 1:s + u = u + bi[i] * z[i] + be[i] * k[i] + end + end + + if integrator.opts.adaptive && btilde !== nothing + tmp = zero(u) + for i in 1:s + tmp = tmp + btilde[i] * z[i] + end + if integrator.f isa SplitFunction && ebtilde !== nothing + for i in 1:s + tmp = tmp + ebtilde[i] * k[i] + end + end + if isnewton(nlsolver) && alg.smooth_est + integrator.stats.nsolve += 1 + est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) + else + est = tmp + end + atmp = calculate_residuals( + est, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t + ) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + end + + if integrator.f isa SplitFunction + integrator.k[1] = integrator.fsalfirst + integrator.fsallast = integrator.f(u, p, t + dt) + integrator.k[2] = integrator.fsallast + else + integrator.fsallast = z[s] ./ dt + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + end + integrator.u = u +end + +@muladd function perform_step!(integrator, cache::ESDIRKIMEXCache, repeat_step = false) + (; t, dt, uprev, u, p) = integrator + (; zs, ks, atmp, nlsolver, step_limiter!) = cache + (; tmp) = nlsolver + tab = cache.tab + (; Ai, bi, Ae, be, c, btilde, ebtilde, α, s) = tab + alg = unwrap_alg(integrator, true) + γ = Ai[2, 2] + + f2 = nothing + if integrator.f isa SplitFunction + f_impl = integrator.f.f1 + f2 = integrator.f.f2 + else + f_impl = integrator.f + end + + markfirststage!(nlsolver) + + if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail + f_impl(zs[1], integrator.uprev, p, integrator.t) + zs[1] .*= dt + else + @..zs[1] = dt * integrator.fsalfirst + end + + if integrator.f isa SplitFunction + @..ks[1] = dt * integrator.fsalfirst - zs[1] + end + + for i in 2:s + @..tmp = uprev + for j in 1:(i - 1) + @..tmp += Ai[i, j] * zs[j] + end + + if integrator.f isa SplitFunction + for j in 1:(i - 1) + @..tmp += Ae[i, j] * ks[j] + end + end + + if integrator.f isa SplitFunction + copyto!(zs[i], zs[1]) + elseif α !== nothing && !iszero(α[i, 1]) + fill!(zs[i], zero(eltype(u))) + for j in 1:(i - 1) + @..zs[i] += α[i, j] * zs[j] + end + else + fill!(zs[i], zero(eltype(u))) + end + + nlsolver.z = zs[i] + nlsolver.c = c[i] + nlsolver.γ = γ + zs[i] = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + if i > 2 + isnewton(nlsolver) && set_new_W!(nlsolver, false) + end + + if integrator.f isa SplitFunction && i < s + @..u = tmp + γ * zs[i] + f2(ks[i], u, p, t + c[i] * dt) + ks[i] .*= dt + integrator.stats.nf2 += 1 + end + end + + @..u = tmp + γ * zs[s] + if integrator.f isa SplitFunction + f2(ks[s], u, p, t + dt) + ks[s] .*= dt + integrator.stats.nf2 += 1 + @..u = uprev + for i in 1:s + @..u += bi[i] * zs[i] + be[i] * ks[i] + end + end + + step_limiter!(u, integrator, p, t + dt) + + if integrator.opts.adaptive && btilde !== nothing + @..tmp = zero(eltype(u)) + for i in 1:s + @..tmp += btilde[i] * zs[i] + end + if integrator.f isa SplitFunction && ebtilde !== nothing + for i in 1:s + @..tmp += ebtilde[i] * ks[i] + end + end + if isnewton(nlsolver) && alg.smooth_est + est = nlsolver.cache.dz + linres = dolinsolve( + integrator, nlsolver.cache.linsolve; b = _vec(tmp), + linu = _vec(est) + ) + integrator.stats.nsolve += 1 + else + est = tmp + end + calculate_residuals!( + atmp, est, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t + ) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + end + + if integrator.f isa SplitFunction + integrator.f(integrator.fsallast, u, p, t + dt) + else + @..integrator.fsallast = zs[s] / dt + end +end diff --git a/lib/OrdinaryDiffEqSDIRK/src/imex_tableaus.jl b/lib/OrdinaryDiffEqSDIRK/src/imex_tableaus.jl new file mode 100644 index 00000000000..a204c77718b --- /dev/null +++ b/lib/OrdinaryDiffEqSDIRK/src/imex_tableaus.jl @@ -0,0 +1,1760 @@ +struct ESDIRKIMEXTableau{T, T2} + Ai::Matrix{T} + bi::Vector{T} + Ae::Matrix{T} + be::Vector{T} + c::Vector{T2} + btilde::Union{Vector{T}, Nothing} + ebtilde::Union{Vector{T}, Nothing} + α::Union{Matrix{T2}, Nothing} + order::Int + s::Int +end + +# Dispatch: each algorithm type maps to its tableau constructor +ESDIRKIMEXTableau(::ARS343, T, T2) = ARS343Tableau(T, T2) +ESDIRKIMEXTableau(::KenCarp3, T, T2) = KenCarp3ESDIRKIMEXTableau(T, T2) +ESDIRKIMEXTableau(::Kvaerno3, T, T2) = Kvaerno3ESDIRKIMEXTableau(T, T2) +ESDIRKIMEXTableau(::Kvaerno4, T, T2) = Kvaerno4ESDIRKIMEXTableau(T, T2) +ESDIRKIMEXTableau(::Kvaerno5, T, T2) = Kvaerno5ESDIRKIMEXTableau(T, T2) +ESDIRKIMEXTableau(::KenCarp4, T, T2) = KenCarp4ESDIRKIMEXTableau(T, T2) +ESDIRKIMEXTableau(::KenCarp5, T, T2) = KenCarp5ESDIRKIMEXTableau(T, T2) +ESDIRKIMEXTableau(::KenCarp47, T, T2) = KenCarp47ESDIRKIMEXTableau(T, T2) +ESDIRKIMEXTableau(::KenCarp58, T, T2) = KenCarp58ESDIRKIMEXTableau(T, T2) +ESDIRKIMEXTableau(::ESDIRK54I8L2SA, T, T2) = ESDIRK54I8L2SAESDIRKIMEXTableau(T, T2) +ESDIRKIMEXTableau(::ESDIRK436L2SA2, T, T2) = ESDIRK436L2SA2ESDIRKIMEXTableau(T, T2) +ESDIRKIMEXTableau(::ESDIRK437L2SA, T, T2) = ESDIRK437L2SAESDIRKIMEXTableau(T, T2) +ESDIRKIMEXTableau(::ESDIRK547L2SA2, T, T2) = ESDIRK547L2SA2ESDIRKIMEXTableau(T, T2) +ESDIRKIMEXTableau(::ESDIRK659L2SA, T, T2) = ESDIRK659L2SAESDIRKIMEXTableau(T, T2) + +# +# KenCarp3 IMEX Tableau +# + +function KenCarp3ESDIRKIMEXTableau(T::Type{<:CompiledFloats}, T2::Type{<:CompiledFloats}) + γ = convert(T, 0.435866521508459) + + a31 = convert(T, 0.2576482460664272) + a32 = -convert(T, 0.09351476757488625) + a41 = convert(T, 0.18764102434672383) + a42 = -convert(T, 0.595297473576955) + a43 = convert(T, 0.9717899277217721) + + btilde1 = convert(T, 0.027099261876665316) + btilde2 = convert(T, 0.11013520969201586) + btilde3 = -convert(T, 0.10306492520138458) + btilde4 = -convert(T, 0.0341695463672966) + + c3 = convert(T2, 0.6) + c2 = 2γ + θ = c3 / c2 + α31 = ((1 + (-4θ + 3θ^2)) + (6θ * (1 - θ) / c2) * γ) + α32 = ((-2θ + 3θ^2) + (6θ * (1 - θ) / c2) * γ) + θ = 1 / c2 + α41 = ((1 + (-4θ + 3θ^2)) + (6θ * (1 - θ) / c2) * γ) + α42 = ((-2θ + 3θ^2) + (6θ * (1 - θ) / c2) * γ) + + ea21 = convert(T, 0.871733043016918) + ea31 = convert(T, 0.5275890119763004) + ea32 = convert(T, 0.0724109880236996) + ea41 = convert(T, 0.3990960076760701) + ea42 = -convert(T, 0.4375576546135194) + ea43 = convert(T, 1.0384616469374492) + eb1 = convert(T, 0.18764102434672383) + eb2 = -convert(T, 0.595297473576955) + eb3 = convert(T, 0.9717899277217721) + eb4 = convert(T, 0.435866521508459) + ebtilde1 = convert(T, 0.027099261876665316) + ebtilde2 = convert(T, 0.11013520969201586) + ebtilde3 = -convert(T, 0.10306492520138458) + ebtilde4 = -convert(T, 0.0341695463672966) + + s = 4 + Ai = zeros(T, s, s) + Ai[2, 1] = γ + Ai[2, 2] = γ + Ai[3, 1] = a31 + Ai[3, 2] = a32 + Ai[3, 3] = γ + Ai[4, 1] = a41 + Ai[4, 2] = a42 + Ai[4, 3] = a43 + Ai[4, 4] = γ + + bi_vec = zeros(T, s) + bi_vec[1] = a41 + bi_vec[2] = a42 + bi_vec[3] = a43 + bi_vec[4] = γ + + Ae = zeros(T, s, s) + Ae[2, 1] = ea21 + Ae[3, 1] = ea31 + Ae[3, 2] = ea32 + Ae[4, 1] = ea41 + Ae[4, 2] = ea42 + Ae[4, 3] = ea43 + + be_vec = zeros(T, s) + be_vec[1] = eb1 + be_vec[2] = eb2 + be_vec[3] = eb3 + be_vec[4] = eb4 + + c_vec = zeros(T2, s) + c_vec[1] = zero(T2) + c_vec[2] = convert(T2, 2γ) + c_vec[3] = c3 + c_vec[4] = one(T2) + + btilde_vec = zeros(T, s) + btilde_vec[1] = btilde1 + btilde_vec[2] = btilde2 + btilde_vec[3] = btilde3 + btilde_vec[4] = btilde4 + + ebtilde_vec = zeros(T, s) + ebtilde_vec[1] = ebtilde1 + ebtilde_vec[2] = ebtilde2 + ebtilde_vec[3] = ebtilde3 + ebtilde_vec[4] = ebtilde4 + + α_mat = zeros(T2, s, s) + α_mat[3, 1] = α31 + α_mat[3, 2] = α32 + α_mat[4, 1] = α41 + α_mat[4, 2] = α42 + + return ESDIRKIMEXTableau( + Ai, bi_vec, Ae, be_vec, c_vec, + btilde_vec, ebtilde_vec, α_mat, 3, s + ) +end + +function KenCarp3ESDIRKIMEXTableau(T, T2) + γ = convert(T, 1767732205903 // 4055673282236) + + a31 = convert(T, 2746238789719 // 10658868560708) + a32 = -convert(T, 640167445237 // 6845629431997) + a41 = convert(T, 1471266399579 // 7840856788654) + a42 = -convert(T, 4482444167858 // 7529755066697) + a43 = convert(T, 11266239266428 // 11593286722821) + + btilde1 = convert( + T, + BigInt(681815649026867975666107) // + BigInt(25159934323302256049469295) + ) + btilde2 = convert( + T, + BigInt(18411887981491912264464127) // + BigInt(167175311446532472108584143) + ) + btilde3 = -convert( + T, + BigInt(12719313754959329011138489) // + BigInt(123410692144842870217698057) + ) + btilde4 = -convert( + T, + BigInt(47289384293135913063989) // + BigInt(1383962894467812063558225) + ) + + c3 = convert(T2, 3 // 5) + c2 = 2γ + θ = c3 / c2 + α31 = ((1 + (-4θ + 3θ^2)) + (6θ * (1 - θ) / c2) * γ) + α32 = ((-2θ + 3θ^2) + (6θ * (1 - θ) / c2) * γ) + θ = 1 / c2 + α41 = ((1 + (-4θ + 3θ^2)) + (6θ * (1 - θ) / c2) * γ) + α42 = ((-2θ + 3θ^2) + (6θ * (1 - θ) / c2) * γ) + + ea21 = convert(T, 1767732205903 // 2027836641118) + ea31 = convert(T, 5535828885825 // 10492691773637) + ea32 = convert(T, 788022342437 // 10882634858940) + ea41 = convert(T, 6485989280629 // 16251701735622) + ea42 = -convert(T, 4246266847089 // 9704473918619) + ea43 = convert(T, 10755448449292 // 10357097424841) + eb1 = convert(T, 1471266399579 // 7840856788654) + eb2 = convert(T, -4482444167858 // 7529755066697) + eb3 = convert(T, 11266239266428 // 11593286722821) + eb4 = convert(T, 1767732205903 // 4055673282236) + ebtilde1 = convert( + T, + BigInt(681815649026867975666107) // + BigInt(25159934323302256049469295) + ) + ebtilde2 = convert( + T, + BigInt(18411887981491912264464127) // + BigInt(167175311446532472108584143) + ) + ebtilde3 = -convert( + T, + BigInt(12719313754959329011138489) // + BigInt(123410692144842870217698057) + ) + ebtilde4 = -convert( + T, + BigInt(47289384293135913063989) // + BigInt(1383962894467812063558225) + ) + + s = 4 + Ai = zeros(T, s, s) + Ai[2, 1] = γ + Ai[2, 2] = γ + Ai[3, 1] = a31 + Ai[3, 2] = a32 + Ai[3, 3] = γ + Ai[4, 1] = a41 + Ai[4, 2] = a42 + Ai[4, 3] = a43 + Ai[4, 4] = γ + + bi_vec = zeros(T, s) + bi_vec[1] = a41 + bi_vec[2] = a42 + bi_vec[3] = a43 + bi_vec[4] = γ + + Ae = zeros(T, s, s) + Ae[2, 1] = ea21 + Ae[3, 1] = ea31 + Ae[3, 2] = ea32 + Ae[4, 1] = ea41 + Ae[4, 2] = ea42 + Ae[4, 3] = ea43 + + be_vec = zeros(T, s) + be_vec[1] = eb1 + be_vec[2] = eb2 + be_vec[3] = eb3 + be_vec[4] = eb4 + + c_vec = zeros(T2, s) + c_vec[1] = zero(T2) + c_vec[2] = convert(T2, 2γ) + c_vec[3] = c3 + c_vec[4] = one(T2) + + btilde_vec = zeros(T, s) + btilde_vec[1] = btilde1 + btilde_vec[2] = btilde2 + btilde_vec[3] = btilde3 + btilde_vec[4] = btilde4 + + ebtilde_vec = zeros(T, s) + ebtilde_vec[1] = ebtilde1 + ebtilde_vec[2] = ebtilde2 + ebtilde_vec[3] = ebtilde3 + ebtilde_vec[4] = ebtilde4 + + α_mat = zeros(T2, s, s) + α_mat[3, 1] = α31 + α_mat[3, 2] = α32 + α_mat[4, 1] = α41 + α_mat[4, 2] = α42 + + return ESDIRKIMEXTableau( + Ai, bi_vec, Ae, be_vec, c_vec, + btilde_vec, ebtilde_vec, α_mat, 3, s + ) +end + +# +# ARS343 Tableau +# + +function ARS343Tableau(T::Type{<:CompiledFloats}, T2::Type{<:CompiledFloats}) + γ = convert(T, 0.435866521508459) + + s = 4 + + c2 = convert(T2, 0.435866521508459) + c3 = convert(T2, 0.7179332607542295) + c4 = one(T2) + + a32_i = convert(T, 0.2820667392457705) + + b3_i = convert(T, -0.644363170684469) + b2_i = convert(T, 1.20849664917601) + + Ai = zeros(T, s, s) + Ai[2, 2] = γ + Ai[3, 2] = a32_i + Ai[3, 3] = γ + Ai[4, 2] = b2_i + Ai[4, 3] = b3_i + Ai[4, 4] = γ + + bi_vec = T[zero(T), b2_i, b3_i, γ] + + ae21 = convert(T, 0.435866521508459) + ae31 = convert(T, 0.321278886) + ae32 = convert(T, 0.3966543748) + ae41 = -convert(T, 0.105858296) + ae42 = convert(T, 0.5529291479) + ae43 = convert(T, 0.5529291479) + + Ae = zeros(T, s, s) + Ae[2, 1] = ae21 + Ae[3, 1] = ae31 + Ae[3, 2] = ae32 + Ae[4, 1] = ae41 + Ae[4, 2] = ae42 + Ae[4, 3] = ae43 + + be_vec = T[zero(T), b2_i, b3_i, γ] + + c_vec = T2[zero(T2), c2, c3, c4] + + α_mat = zeros(T2, s, s) + + return ESDIRKIMEXTableau( + Ai, bi_vec, Ae, be_vec, c_vec, + nothing, nothing, α_mat, 3, s + ) +end + +# +# KenCarp4 IMEX Tableau +# + +function KenCarp4ESDIRKIMEXTableau(T, T2) + γ = convert(T2, 1 // 4) + + a31 = convert(T, 8611 // 62500) + a32 = -convert(T, 1743 // 31250) + a41 = convert(T, 5012029 // 34652500) + a42 = -convert(T, 654441 // 2922500) + a43 = convert(T, 174375 // 388108) + a51 = convert(T, 15267082809 // 155376265600) + a52 = -convert(T, 71443401 // 120774400) + a53 = convert(T, 730878875 // 902184768) + a54 = convert(T, 2285395 // 8070912) + a61 = convert(T, 82889 // 524892) + a63 = convert(T, 15625 // 83664) + a64 = convert(T, 69875 // 102672) + a65 = -convert(T, 2260 // 8211) + + btilde1 = convert(T, -31666707 // 9881966720) + btilde3 = convert(T, 256875 // 105007616) + btilde4 = convert(T, 2768025 // 128864768) + btilde5 = -convert(T, 169839 // 3864644) + btilde6 = convert(T, 5247 // 225920) + + c3 = convert(T2, 83 // 250) + c4 = convert(T2, 31 // 50) + c5 = convert(T2, 17 // 20) + + α21 = convert(T2, 2) + α31 = convert(T2, 42 // 125) + α32 = convert(T2, 83 // 125) + α41 = convert(T2, -6 // 25) + α42 = convert(T2, 31 // 25) + α51 = convert(T2, 914470432 // 2064665255) + α52 = convert(T2, 798813 // 724780) + α53 = convert(T2, -824765625 // 372971788) + α54 = convert(T2, 49640 // 29791) + α61 = convert(T2, 288521442795 // 954204491116) + α62 = convert(T2, 2224881 // 2566456) + α63 = convert(T2, -1074821875 // 905317354) + α64 = convert(T2, -3360875 // 8098936) + α65 = convert(T2, 7040 // 4913) + + ea21 = convert(T, 1 // 2) + ea31 = convert(T, 13861 // 62500) + ea32 = convert(T, 6889 // 62500) + ea41 = -convert(T, 116923316275 // 2393684061468) + ea42 = -convert(T, 2731218467317 // 15368042101831) + ea43 = convert(T, 9408046702089 // 11113171139209) + ea51 = -convert(T, 451086348788 // 2902428689909) + ea52 = -convert(T, 2682348792572 // 7519795681897) + ea53 = convert(T, 12662868775082 // 11960479115383) + ea54 = convert(T, 3355817975965 // 11060851509271) + ea61 = convert(T, 647845179188 // 3216320057751) + ea62 = convert(T, 73281519250 // 8382639484533) + ea63 = convert(T, 552539513391 // 3454668386233) + ea64 = convert(T, 3354512671639 // 8306763924573) + ea65 = convert(T, 4040 // 17871) + + eb1 = convert(T, 82889 // 524892) + eb3 = convert(T, 15625 // 83664) + eb4 = convert(T, 69875 // 102672) + eb5 = -convert(T, 2260 // 8211) + eb6 = convert(T, 1 // 4) + + ebtilde1 = -convert(T, 31666707 // 9881966720) + ebtilde3 = convert(T, 256875 // 105007616) + ebtilde4 = convert(T, 2768025 // 128864768) + ebtilde5 = -convert(T, 169839 // 3864644) + ebtilde6 = convert(T, 5247 // 225920) + + s = 6 + Ai = zeros(T, s, s) + Ai[2, 1] = convert(T, γ) + Ai[2, 2] = convert(T, γ) + Ai[3, 1] = a31 + Ai[3, 2] = a32 + Ai[3, 3] = convert(T, γ) + Ai[4, 1] = a41 + Ai[4, 2] = a42 + Ai[4, 3] = a43 + Ai[4, 4] = convert(T, γ) + Ai[5, 1] = a51 + Ai[5, 2] = a52 + Ai[5, 3] = a53 + Ai[5, 4] = a54 + Ai[5, 5] = convert(T, γ) + Ai[6, 1] = a61 + Ai[6, 3] = a63 + Ai[6, 4] = a64 + Ai[6, 5] = a65 + Ai[6, 6] = convert(T, γ) + + bi_vec = zeros(T, s) + bi_vec[1] = a61 + bi_vec[3] = a63 + bi_vec[4] = a64 + bi_vec[5] = a65 + bi_vec[6] = convert(T, γ) + + Ae = zeros(T, s, s) + Ae[2, 1] = ea21 + Ae[3, 1] = ea31 + Ae[3, 2] = ea32 + Ae[4, 1] = ea41 + Ae[4, 2] = ea42 + Ae[4, 3] = ea43 + Ae[5, 1] = ea51 + Ae[5, 2] = ea52 + Ae[5, 3] = ea53 + Ae[5, 4] = ea54 + Ae[6, 1] = ea61 + Ae[6, 2] = ea62 + Ae[6, 3] = ea63 + Ae[6, 4] = ea64 + Ae[6, 5] = ea65 + + be_vec = zeros(T, s) + be_vec[1] = eb1 + be_vec[3] = eb3 + be_vec[4] = eb4 + be_vec[5] = eb5 + be_vec[6] = eb6 + + c_vec = zeros(T2, s) + c_vec[1] = zero(T2) + c_vec[2] = convert(T2, 2γ) + c_vec[3] = c3 + c_vec[4] = c4 + c_vec[5] = c5 + c_vec[6] = one(T2) + + btilde_vec = zeros(T, s) + btilde_vec[1] = btilde1 + btilde_vec[3] = btilde3 + btilde_vec[4] = btilde4 + btilde_vec[5] = btilde5 + btilde_vec[6] = btilde6 + + ebtilde_vec = zeros(T, s) + ebtilde_vec[1] = ebtilde1 + ebtilde_vec[3] = ebtilde3 + ebtilde_vec[4] = ebtilde4 + ebtilde_vec[5] = ebtilde5 + ebtilde_vec[6] = ebtilde6 + + α_mat = zeros(T2, s, s) + α_mat[2, 1] = α21 + α_mat[3, 1] = α31 + α_mat[3, 2] = α32 + α_mat[4, 1] = α41 + α_mat[4, 2] = α42 + α_mat[5, 1] = α51 + α_mat[5, 2] = α52 + α_mat[5, 3] = α53 + α_mat[5, 4] = α54 + α_mat[6, 1] = α61 + α_mat[6, 2] = α62 + α_mat[6, 3] = α63 + α_mat[6, 4] = α64 + α_mat[6, 5] = α65 + + return ESDIRKIMEXTableau( + Ai, bi_vec, Ae, be_vec, c_vec, + btilde_vec, ebtilde_vec, α_mat, 4, s + ) +end + +# +# KenCarp5 IMEX Tableau +# + +function KenCarp5ESDIRKIMEXTableau(T, T2) + γ = convert(T2, 41 // 200) + + a31 = convert(T, 41 // 400) + a32 = -convert(T, 567603406766 // 11931857230679) + a41 = convert(T, 683785636431 // 9252920307686) + a43 = -convert(T, 110385047103 // 1367015193373) + a51 = convert(T, 3016520224154 // 10081342136671) + a53 = convert(T, 30586259806659 // 12414158314087) + a54 = -convert(T, 22760509404356 // 11113319521817) + a61 = convert(T, 218866479029 // 1489978393911) + a63 = convert(T, 638256894668 // 5436446318841) + a64 = -convert(T, 1179710474555 // 5321154724896) + a65 = -convert(T, 60928119172 // 8023461067671) + a71 = convert(T, 1020004230633 // 5715676835656) + a73 = convert(T, 25762820946817 // 25263940353407) + a74 = -convert(T, 2161375909145 // 9755907335909) + a75 = -convert(T, 211217309593 // 5846859502534) + a76 = -convert(T, 4269925059573 // 7827059040749) + a81 = -convert(T, 872700587467 // 9133579230613) + a84 = convert(T, 22348218063261 // 9555858737531) + a85 = -convert(T, 1143369518992 // 8141816002931) + a86 = -convert(T, 39379526789629 // 19018526304540) + a87 = convert(T, 32727382324388 // 42900044865799) + + btilde1 = -convert(T, 360431431567533808054934 // 89473089856732078284381229) + btilde4 = convert(T, 21220331609936495351431026 // 309921249937726682547321949) + btilde5 = -convert(T, 42283193605833819490634 // 2144566741190883522084871) + btilde6 = -convert(T, 21843466548811234473856609 // 296589222149359214696574660) + btilde7 = convert(T, 3333910710978735057753642 // 199750492790973993533703797) + btilde8 = convert(T, 45448919757 // 3715198317040) + + c3 = convert(T2, 2935347310677 // 11292855782101) + c4 = convert(T2, 1426016391358 // 7196633302097) + c5 = convert(T2, 92 // 100) + c6 = convert(T2, 24 // 100) + c7 = convert(T2, 3 // 5) + + α31 = convert(T2, 169472355998441 // 463007087066141) + α32 = convert(T2, 293534731067700 // 463007087066141) + α41 = convert(T2, 152460326250177 // 295061965385977) + α42 = convert(T2, 142601639135800 // 295061965385977) + α51 = convert(T2, -51 // 41) + α52 = convert(T2, 92 // 41) + α61 = convert(T2, 17 // 41) + α62 = convert(T2, 24 // 41) + α71 = convert(T2, 13488091065527792 // 122659689776876057) + α72 = convert(T2, -3214953045 // 3673655312) + α73 = convert(T2, 550552676519862000 // 151043064207496529) + α74 = convert(T2, -409689169278408000 // 135215758621947439) + α75 = convert(T2, 3345 // 12167) + α81 = convert(T2, 1490668709762032 // 122659689776876057) + α82 = convert(T2, 5358255075 // 14694621248) + α83 = convert(T2, -229396948549942500 // 151043064207496529) + α84 = convert(T2, 170703820532670000 // 135215758621947439) + α85 = convert(T2, 30275 // 24334) + + ea21 = convert(T, 41 // 100) + ea31 = convert(T, 367902744464 // 2072280473677) + ea32 = convert(T, 677623207551 // 8224143866563) + ea41 = convert(T, 1268023523408 // 10340822734521) + ea43 = convert(T, 1029933939417 // 13636558850479) + ea51 = convert(T, 14463281900351 // 6315353703477) + ea53 = convert(T, 66114435211212 // 5879490589093) + ea54 = -convert(T, 54053170152839 // 4284798021562) + ea61 = convert(T, 14090043504691 // 34967701212078) + ea63 = convert(T, 15191511035443 // 11219624916014) + ea64 = -convert(T, 18461159152457 // 12425892160975) + ea65 = -convert(T, 281667163811 // 9011619295870) + ea71 = convert(T, 19230459214898 // 13134317526959) + ea73 = convert(T, 21275331358303 // 2942455364971) + ea74 = -convert(T, 38145345988419 // 4862620318723) + ea75 = -convert(T, 1 // 8) + ea76 = -convert(T, 1 // 8) + ea81 = -convert(T, 19977161125411 // 11928030595625) + ea83 = -convert(T, 40795976796054 // 6384907823539) + ea84 = convert(T, 177454434618887 // 12078138498510) + ea85 = convert(T, 782672205425 // 8267701900261) + ea86 = -convert(T, 69563011059811 // 9646580694205) + ea87 = convert(T, 7356628210526 // 4942186776405) + + eb1 = -convert(T, 872700587467 // 9133579230613) + eb4 = convert(T, 22348218063261 // 9555858737531) + eb5 = -convert(T, 1143369518992 // 8141816002931) + eb6 = -convert(T, 39379526789629 // 19018526304540) + eb7 = convert(T, 32727382324388 // 42900044865799) + eb8 = convert(T, 41 // 200) + + ebtilde1 = -convert(T, 360431431567533808054934 // 89473089856732078284381229) + ebtilde4 = convert(T, 21220331609936495351431026 // 309921249937726682547321949) + ebtilde5 = -convert(T, 42283193605833819490634 // 2144566741190883522084871) + ebtilde6 = -convert(T, 21843466548811234473856609 // 296589222149359214696574660) + ebtilde7 = convert(T, 3333910710978735057753642 // 199750492790973993533703797) + ebtilde8 = convert(T, 45448919757 // 3715198317040) + + s = 8 + Ai = zeros(T, s, s) + Ai[2, 1] = convert(T, γ) + Ai[2, 2] = convert(T, γ) + Ai[3, 1] = a31 + Ai[3, 2] = a32 + Ai[3, 3] = convert(T, γ) + Ai[4, 1] = a41 + Ai[4, 3] = a43 + Ai[4, 4] = convert(T, γ) + Ai[5, 1] = a51 + Ai[5, 3] = a53 + Ai[5, 4] = a54 + Ai[5, 5] = convert(T, γ) + Ai[6, 1] = a61 + Ai[6, 3] = a63 + Ai[6, 4] = a64 + Ai[6, 5] = a65 + Ai[6, 6] = convert(T, γ) + Ai[7, 1] = a71 + Ai[7, 3] = a73 + Ai[7, 4] = a74 + Ai[7, 5] = a75 + Ai[7, 6] = a76 + Ai[7, 7] = convert(T, γ) + Ai[8, 1] = a81 + Ai[8, 4] = a84 + Ai[8, 5] = a85 + Ai[8, 6] = a86 + Ai[8, 7] = a87 + Ai[8, 8] = convert(T, γ) + + bi_vec = zeros(T, s) + bi_vec[1] = a81 + bi_vec[4] = a84 + bi_vec[5] = a85 + bi_vec[6] = a86 + bi_vec[7] = a87 + bi_vec[8] = convert(T, γ) + + Ae = zeros(T, s, s) + Ae[2, 1] = ea21 + Ae[3, 1] = ea31 + Ae[3, 2] = ea32 + Ae[4, 1] = ea41 + Ae[4, 3] = ea43 + Ae[5, 1] = ea51 + Ae[5, 3] = ea53 + Ae[5, 4] = ea54 + Ae[6, 1] = ea61 + Ae[6, 3] = ea63 + Ae[6, 4] = ea64 + Ae[6, 5] = ea65 + Ae[7, 1] = ea71 + Ae[7, 3] = ea73 + Ae[7, 4] = ea74 + Ae[7, 5] = ea75 + Ae[7, 6] = ea76 + Ae[8, 1] = ea81 + Ae[8, 3] = ea83 + Ae[8, 4] = ea84 + Ae[8, 5] = ea85 + Ae[8, 6] = ea86 + Ae[8, 7] = ea87 + + be_vec = zeros(T, s) + be_vec[1] = eb1 + be_vec[4] = eb4 + be_vec[5] = eb5 + be_vec[6] = eb6 + be_vec[7] = eb7 + be_vec[8] = eb8 + + c_vec = zeros(T2, s) + c_vec[1] = zero(T2) + c_vec[2] = convert(T2, 2γ) + c_vec[3] = c3 + c_vec[4] = c4 + c_vec[5] = c5 + c_vec[6] = c6 + c_vec[7] = c7 + c_vec[8] = one(T2) + + btilde_vec = zeros(T, s) + btilde_vec[1] = btilde1 + btilde_vec[4] = btilde4 + btilde_vec[5] = btilde5 + btilde_vec[6] = btilde6 + btilde_vec[7] = btilde7 + btilde_vec[8] = btilde8 + + ebtilde_vec = zeros(T, s) + ebtilde_vec[1] = ebtilde1 + ebtilde_vec[4] = ebtilde4 + ebtilde_vec[5] = ebtilde5 + ebtilde_vec[6] = ebtilde6 + ebtilde_vec[7] = ebtilde7 + ebtilde_vec[8] = ebtilde8 + + α_mat = zeros(T2, s, s) + α_mat[3, 1] = α31 + α_mat[3, 2] = α32 + α_mat[4, 1] = α41 + α_mat[4, 2] = α42 + α_mat[5, 1] = α51 + α_mat[5, 2] = α52 + α_mat[6, 1] = α61 + α_mat[6, 2] = α62 + α_mat[7, 1] = α71 + α_mat[7, 2] = α72 + α_mat[7, 3] = α73 + α_mat[7, 4] = α74 + α_mat[7, 5] = α75 + α_mat[8, 1] = α81 + α_mat[8, 2] = α82 + α_mat[8, 3] = α83 + α_mat[8, 4] = α84 + α_mat[8, 5] = α85 + + return ESDIRKIMEXTableau( + Ai, bi_vec, Ae, be_vec, c_vec, + btilde_vec, ebtilde_vec, α_mat, 5, s + ) +end + +function ARS343Tableau(T, T2) + γ = convert(T, 4358665215084590 // 10000000000000000) + + s = 4 + + c2 = convert(T2, γ) + c3 = (one(T2) + convert(T2, γ)) / 2 + c4 = one(T2) + + a32_i = (one(T) - γ) / 2 + + b3_i = (one(T) / 2 - 2γ + γ^2) / ((one(T) - γ) / 2) + b2_i = one(T) - γ - b3_i + + Ai = zeros(T, s, s) + Ai[2, 2] = convert(T, γ) + Ai[3, 2] = convert(T, a32_i) + Ai[3, 3] = convert(T, γ) + Ai[4, 2] = convert(T, b2_i) + Ai[4, 3] = convert(T, b3_i) + Ai[4, 4] = convert(T, γ) + + bi_vec = T[zero(T), convert(T, b2_i), convert(T, b3_i), convert(T, γ)] + + ae21 = convert(T, γ) + ae31 = convert(T, 3212788860 // 10000000000) + ae32 = convert(T, 3966543748 // 10000000000) + ae41 = -convert(T, 1058582960 // 10000000000) + ae42 = convert(T, 5529291479 // 10000000000) + ae43 = convert(T, 5529291479 // 10000000000) + + Ae = zeros(T, s, s) + Ae[2, 1] = ae21 + Ae[3, 1] = ae31 + Ae[3, 2] = ae32 + Ae[4, 1] = ae41 + Ae[4, 2] = ae42 + Ae[4, 3] = ae43 + + be_vec = T[zero(T), convert(T, b2_i), convert(T, b3_i), convert(T, γ)] + + c_vec = T2[zero(T2), convert(T2, c2), convert(T2, c3), convert(T2, c4)] + + α_mat = zeros(T2, s, s) + + return ESDIRKIMEXTableau( + Ai, bi_vec, Ae, be_vec, c_vec, + nothing, nothing, α_mat, 3, s + ) +end + +# +# Kvaerno3 IMEX Tableau +# + +function Kvaerno3ESDIRKIMEXTableau(T, T2) + γ = convert(T, 0.4358665215) + + a31 = convert(T, 0.490563388419108) + a32 = convert(T, 0.073570090080892) + a41 = convert(T, 0.308809969973036) + a42 = convert(T, 1.490563388254106) + a43 = -convert(T, 1.235239879727145) + + btilde1 = convert(T, 0.181753418446072) + btilde2 = convert(T, -1.416993298173214) + btilde3 = convert(T, 1.671106401227145) + btilde4 = -convert(T, 0.4358665215) + + c3 = convert(T2, 1) + c2 = convert(T2, 2) * convert(T2, 0.4358665215) + θ = c3 / c2 + α31 = ((1 + (-4θ + 3θ^2)) + (6θ * (1 - θ) / c2) * convert(T2, 0.4358665215)) + α32 = ((-2θ + 3θ^2) + (6θ * (1 - θ) / c2) * convert(T2, 0.4358665215)) + α41 = convert(T2, 0.0) + α42 = convert(T2, 0.0) + + s = 4 + Ai = zeros(T, s, s) + Ai[2, 1] = γ + Ai[2, 2] = γ + Ai[3, 1] = a31 + Ai[3, 2] = a32 + Ai[3, 3] = γ + Ai[4, 1] = a41 + Ai[4, 2] = a42 + Ai[4, 3] = a43 + Ai[4, 4] = γ + + bi_vec = zeros(T, s) + bi_vec[1] = a41 + bi_vec[2] = a42 + bi_vec[3] = a43 + bi_vec[4] = γ + + Ae = zeros(T, s, s) + be_vec = zeros(T, s) + + c_vec = zeros(T2, s) + c_vec[1] = zero(T2) + c_vec[2] = c2 + c_vec[3] = c3 + c_vec[4] = one(T2) + + btilde_vec = zeros(T, s) + btilde_vec[1] = btilde1 + btilde_vec[2] = btilde2 + btilde_vec[3] = btilde3 + btilde_vec[4] = btilde4 + + α_mat = zeros(T2, s, s) + α_mat[3, 1] = α31 + α_mat[3, 2] = α32 + α_mat[4, 1] = α41 + α_mat[4, 2] = α42 + + return ESDIRKIMEXTableau( + Ai, bi_vec, Ae, be_vec, c_vec, + btilde_vec, nothing, α_mat, 3, s + ) +end + +# +# Kvaerno4 IMEX Tableau +# + +function Kvaerno4ESDIRKIMEXTableau(T, T2) + γ = convert(T, 0.4358665215) + + a31 = convert(T, 0.140737774731968) + a32 = convert(T, -0.108365551378832) + a41 = convert(T, 0.102399400616089) + a42 = convert(T, -0.376878452267324) + a43 = convert(T, 0.838612530151233) + a51 = convert(T, 0.157024897860995) + a52 = convert(T, 0.117330441357768) + a53 = convert(T, 0.61667803039168) + a54 = convert(T, -0.326899891110444) + + btilde1 = convert(T, -0.054625497244906) + btilde2 = convert(T, -0.494208893625092) + btilde3 = convert(T, 0.221934499759553) + btilde4 = convert(T, 0.762766412610444) + btilde5 = -convert(T, 0.4358665215) + + c3 = convert(T2, 0.468238744853136) + c4 = convert(T2, 1) + c2 = convert(T2, 2) * convert(T2, 0.4358665215) + + α21 = convert(T2, 2) + α31 = convert(T2, 0.462864521870446) + α32 = convert(T2, 0.537135478129554) + α41 = convert(T2, -0.14714018016178376) + α42 = convert(T2, 1.1471401801617838) + + s = 5 + Ai = zeros(T, s, s) + Ai[2, 1] = γ + Ai[2, 2] = γ + Ai[3, 1] = a31 + Ai[3, 2] = a32 + Ai[3, 3] = γ + Ai[4, 1] = a41 + Ai[4, 2] = a42 + Ai[4, 3] = a43 + Ai[4, 4] = γ + Ai[5, 1] = a51 + Ai[5, 2] = a52 + Ai[5, 3] = a53 + Ai[5, 4] = a54 + Ai[5, 5] = γ + + bi_vec = zeros(T, s) + bi_vec[1] = a51 + bi_vec[2] = a52 + bi_vec[3] = a53 + bi_vec[4] = a54 + bi_vec[5] = γ + + Ae = zeros(T, s, s) + be_vec = zeros(T, s) + + c_vec = zeros(T2, s) + c_vec[1] = zero(T2) + c_vec[2] = c2 + c_vec[3] = c3 + c_vec[4] = c4 + c_vec[5] = one(T2) + + btilde_vec = zeros(T, s) + btilde_vec[1] = btilde1 + btilde_vec[2] = btilde2 + btilde_vec[3] = btilde3 + btilde_vec[4] = btilde4 + btilde_vec[5] = btilde5 + + α_mat = zeros(T2, s, s) + α_mat[2, 1] = α21 + α_mat[3, 1] = α31 + α_mat[3, 2] = α32 + α_mat[4, 1] = α41 + α_mat[4, 2] = α42 + + return ESDIRKIMEXTableau( + Ai, bi_vec, Ae, be_vec, c_vec, + btilde_vec, nothing, α_mat, 4, s + ) +end + +# +# Kvaerno5 IMEX Tableau +# + +function Kvaerno5ESDIRKIMEXTableau(T, T2) + γ = convert(T, 0.26) + + a31 = convert(T, 0.13) + a32 = convert(T, 0.84033320996790809) + a41 = convert(T, 0.22371961478320505) + a42 = convert(T, 0.47675532319799699) + a43 = -convert(T, 0.06470895363112615) + a51 = convert(T, 0.16648564323248321) + a52 = convert(T, 0.1045001884159172) + a53 = convert(T, 0.03631482272098715) + a54 = -convert(T, 0.13090704451073998) + a61 = convert(T, 0.13855640231268224) + a63 = -convert(T, 0.04245337201752043) + a64 = convert(T, 0.02446657898003141) + a65 = convert(T, 0.61943039072480676) + a71 = convert(T, 0.13659751177640291) + a73 = -convert(T, 0.05496908796538376) + a74 = -convert(T, 0.04118626728321046) + a75 = convert(T, 0.62993304899016403) + a76 = convert(T, 0.06962479448202728) + + btilde1 = convert(T, 0.00195889053627933) + btilde3 = convert(T, 0.01251571594786333) + btilde4 = convert(T, 0.06565284626324187) + btilde5 = -convert(T, 0.01050265826535727) + btilde6 = convert(T, 0.19037520551797272) + btilde7 = -convert(T, 0.26) + + c3 = convert(T2, 1.230333209967908) + c4 = convert(T2, 0.895765984350076) + c5 = convert(T2, 0.436393609858648) + c6 = convert(T2, 1) + c2 = convert(T2, 2) * convert(T2, 0.26) + + α21 = convert(T2, 2) + α31 = convert(T2, -1.366025403784441) + α32 = convert(T2, 2.3660254037844357) + α41 = convert(T2, -0.19650552613122207) + α42 = convert(T2, 0.8113579546496623) + α43 = convert(T2, 0.38514757148155954) + α51 = convert(T2, 0.10375304369958693) + α52 = convert(T2, 0.937994698066431) + α53 = convert(T2, -0.04174774176601781) + α61 = convert(T2, -0.17281112873898072) + α62 = convert(T2, 0.6235784481025847) + α63 = convert(T2, 0.5492326806363959) + + s = 7 + Ai = zeros(T, s, s) + Ai[2, 1] = γ + Ai[2, 2] = γ + Ai[3, 1] = a31 + Ai[3, 2] = a32 + Ai[3, 3] = γ + Ai[4, 1] = a41 + Ai[4, 2] = a42 + Ai[4, 3] = a43 + Ai[4, 4] = γ + Ai[5, 1] = a51 + Ai[5, 2] = a52 + Ai[5, 3] = a53 + Ai[5, 4] = a54 + Ai[5, 5] = γ + Ai[6, 1] = a61 + Ai[6, 3] = a63 + Ai[6, 4] = a64 + Ai[6, 5] = a65 + Ai[6, 6] = γ + Ai[7, 1] = a71 + Ai[7, 3] = a73 + Ai[7, 4] = a74 + Ai[7, 5] = a75 + Ai[7, 6] = a76 + Ai[7, 7] = γ + + bi_vec = zeros(T, s) + bi_vec[1] = a71 + bi_vec[3] = a73 + bi_vec[4] = a74 + bi_vec[5] = a75 + bi_vec[6] = a76 + bi_vec[7] = γ + + Ae = zeros(T, s, s) + be_vec = zeros(T, s) + + c_vec = zeros(T2, s) + c_vec[1] = zero(T2) + c_vec[2] = c2 + c_vec[3] = c3 + c_vec[4] = c4 + c_vec[5] = c5 + c_vec[6] = c6 + c_vec[7] = one(T2) + + btilde_vec = zeros(T, s) + btilde_vec[1] = btilde1 + btilde_vec[3] = btilde3 + btilde_vec[4] = btilde4 + btilde_vec[5] = btilde5 + btilde_vec[6] = btilde6 + btilde_vec[7] = btilde7 + + α_mat = zeros(T2, s, s) + α_mat[2, 1] = α21 + α_mat[3, 1] = α31 + α_mat[3, 2] = α32 + α_mat[4, 1] = α41 + α_mat[4, 2] = α42 + α_mat[4, 3] = α43 + α_mat[5, 1] = α51 + α_mat[5, 2] = α52 + α_mat[5, 3] = α53 + α_mat[6, 1] = α61 + α_mat[6, 2] = α62 + α_mat[6, 3] = α63 + + return ESDIRKIMEXTableau( + Ai, bi_vec, Ae, be_vec, c_vec, + btilde_vec, nothing, α_mat, 5, s + ) +end + +# +# KenCarp47 IMEX Tableau +# + +function KenCarp47ESDIRKIMEXTableau(T, T2) + γ = convert(T2, 1235 // 10000) + + a31 = convert(T, 624185399699 // 4186980696204) + a32 = a31 + a41 = convert(T, 1258591069120 // 10082082980243) + a42 = a41 + a43 = -convert(T, 322722984531 // 8455138723562) + a51 = -convert(T, 436103496990 // 5971407786587) + a52 = a51 + a53 = -convert(T, 2689175662187 // 11046760208243) + a54 = convert(T, 4431412449334 // 12995360898505) + a61 = -convert(T, 2207373168298 // 14430576638973) + a62 = a61 + a63 = convert(T, 242511121179 // 3358618340039) + a64 = convert(T, 3145666661981 // 7780404714551) + a65 = convert(T, 5882073923981 // 14490790706663) + a73 = convert(T, 9164257142617 // 17756377923965) + a74 = -convert(T, 10812980402763 // 74029279521829) + a75 = convert(T, 1335994250573 // 5691609445217) + a76 = convert(T, 2273837961795 // 8368240463276) + + btilde3 = convert(T, 216367897668138065439709 // 153341716340757627089664345) + btilde4 = -convert(T, 1719969231640509698414113 // 303097339249411872572263321) + btilde5 = convert(T, 33321949854538424751892 // 16748125370719759490730723) + btilde6 = convert(T, 4033362550194444079469 // 1083063207508329376479196) + btilde7 = -convert(T, 29 // 20000) + + c3 = convert(T2, 4276536705230 // 10142255878289) + c4 = convert(T2, 67 // 200) + c5 = convert(T2, 3 // 40) + c6 = convert(T2, 7 // 10) + + α21 = convert(T2, 2) + α31 = -convert(T2, 796131459065721 // 1125899906842624) + α32 = convert(T2, 961015682954173 // 562949953421312) + α41 = convert(T2, 139710975840363 // 2251799813685248) + α42 = convert(T2, 389969885861609 // 1125899906842624) + α43 = convert(T2, 2664298132243335 // 4503599627370496) + α51 = convert(T2, 6272219723949193 // 9007199254740992) + α52 = convert(T2, 2734979530791799 // 9007199254740992) + α61 = convert(T2, 42616678320173 // 140737488355328) + α62 = -convert(T2, 2617409280098421 // 1125899906842624) + α63 = convert(T2, 1701187880189829 // 562949953421312) + α71 = convert(T2, 4978493057967061 // 2251799813685248) + α72 = convert(T2, 7230365118049293 // 9007199254740992) + α73 = -convert(T2, 6826045129237249 // 18014398509481984) + α74 = -convert(T2, 2388848894891525 // 1125899906842624) + α75 = -convert(T2, 4796744191239075 // 2251799813685248) + α76 = convert(T2, 2946706549191323 // 1125899906842624) + + ea21 = convert(T, 247 // 1000) + ea31 = convert(T, 247 // 4000) + ea32 = convert(T, 2694949928731 // 7487940209513) + ea41 = convert(T, 464650059369 // 8764239774964) + ea42 = convert(T, 878889893998 // 2444806327765) + ea43 = -convert(T, 952945855348 // 12294611323341) + ea51 = convert(T, 476636172619 // 8159180917465) + ea52 = -convert(T, 1271469283451 // 7793814740893) + ea53 = -convert(T, 859560642026 // 4356155882851) + ea54 = convert(T, 1723805262919 // 4571918432560) + ea61 = convert(T, 6338158500785 // 11769362343261) + ea62 = -convert(T, 4970555480458 // 10924838743837) + ea63 = convert(T, 3326578051521 // 2647936831840) + ea64 = -convert(T, 880713585975 // 1841400956686) + ea65 = -convert(T, 1428733748635 // 8843423958496) + ea71 = convert(T, 760814592956 // 3276306540349) + ea72 = convert(T, 760814592956 // 3276306540349) + ea73 = -convert(T, 47223648122716 // 6934462133451) + ea74 = convert(T, 71187472546993 // 9669769126921) + ea75 = -convert(T, 13330509492149 // 9695768672337) + ea76 = convert(T, 11565764226357 // 8513123442827) + + eb3 = convert(T, 9164257142617 // 17756377923965) + eb4 = -convert(T, 10812980402763 // 74029279521829) + eb5 = convert(T, 1335994250573 // 5691609445217) + eb6 = convert(T, 2273837961795 // 8368240463276) + eb7 = convert(T, 247 // 2000) + + ebtilde3 = convert(T, 216367897668138065439709 // 153341716340757627089664345) + ebtilde4 = -convert(T, 1719969231640509698414113 // 303097339249411872572263321) + ebtilde5 = convert(T, 33321949854538424751892 // 16748125370719759490730723) + ebtilde6 = convert(T, 4033362550194444079469 // 1083063207508329376479196) + ebtilde7 = -convert(T, 29 // 20000) + + s = 7 + Ai = zeros(T, s, s) + Ai[2, 1] = convert(T, γ) + Ai[2, 2] = convert(T, γ) + Ai[3, 1] = a31 + Ai[3, 2] = a32 + Ai[3, 3] = convert(T, γ) + Ai[4, 1] = a41 + Ai[4, 2] = a42 + Ai[4, 3] = a43 + Ai[4, 4] = convert(T, γ) + Ai[5, 1] = a51 + Ai[5, 2] = a52 + Ai[5, 3] = a53 + Ai[5, 4] = a54 + Ai[5, 5] = convert(T, γ) + Ai[6, 1] = a61 + Ai[6, 2] = a62 + Ai[6, 3] = a63 + Ai[6, 4] = a64 + Ai[6, 5] = a65 + Ai[6, 6] = convert(T, γ) + Ai[7, 3] = a73 + Ai[7, 4] = a74 + Ai[7, 5] = a75 + Ai[7, 6] = a76 + Ai[7, 7] = convert(T, γ) + + bi_vec = zeros(T, s) + bi_vec[3] = a73 + bi_vec[4] = a74 + bi_vec[5] = a75 + bi_vec[6] = a76 + bi_vec[7] = convert(T, γ) + + Ae = zeros(T, s, s) + Ae[2, 1] = ea21 + Ae[3, 1] = ea31 + Ae[3, 2] = ea32 + Ae[4, 1] = ea41 + Ae[4, 2] = ea42 + Ae[4, 3] = ea43 + Ae[5, 1] = ea51 + Ae[5, 2] = ea52 + Ae[5, 3] = ea53 + Ae[5, 4] = ea54 + Ae[6, 1] = ea61 + Ae[6, 2] = ea62 + Ae[6, 3] = ea63 + Ae[6, 4] = ea64 + Ae[6, 5] = ea65 + Ae[7, 1] = ea71 + Ae[7, 2] = ea72 + Ae[7, 3] = ea73 + Ae[7, 4] = ea74 + Ae[7, 5] = ea75 + Ae[7, 6] = ea76 + + be_vec = zeros(T, s) + be_vec[3] = eb3 + be_vec[4] = eb4 + be_vec[5] = eb5 + be_vec[6] = eb6 + be_vec[7] = eb7 + + c_vec = zeros(T2, s) + c_vec[1] = zero(T2) + c_vec[2] = convert(T2, 2γ) + c_vec[3] = c3 + c_vec[4] = c4 + c_vec[5] = c5 + c_vec[6] = c6 + c_vec[7] = one(T2) + + btilde_vec = zeros(T, s) + btilde_vec[3] = btilde3 + btilde_vec[4] = btilde4 + btilde_vec[5] = btilde5 + btilde_vec[6] = btilde6 + btilde_vec[7] = btilde7 + + ebtilde_vec = zeros(T, s) + ebtilde_vec[3] = ebtilde3 + ebtilde_vec[4] = ebtilde4 + ebtilde_vec[5] = ebtilde5 + ebtilde_vec[6] = ebtilde6 + ebtilde_vec[7] = ebtilde7 + + α_mat = zeros(T2, s, s) + α_mat[2, 1] = α21 + α_mat[3, 1] = α31 + α_mat[3, 2] = α32 + α_mat[4, 1] = α41 + α_mat[4, 2] = α42 + α_mat[4, 3] = α43 + α_mat[5, 1] = α51 + α_mat[5, 2] = α52 + α_mat[6, 1] = α61 + α_mat[6, 2] = α62 + α_mat[6, 3] = α63 + α_mat[7, 1] = α71 + α_mat[7, 2] = α72 + α_mat[7, 3] = α73 + α_mat[7, 4] = α74 + α_mat[7, 5] = α75 + α_mat[7, 6] = α76 + + return ESDIRKIMEXTableau( + Ai, bi_vec, Ae, be_vec, c_vec, + btilde_vec, ebtilde_vec, α_mat, 4, s + ) +end + +# +# KenCarp58 IMEX Tableau +# + +function KenCarp58ESDIRKIMEXTableau(T, T2) + γ = convert(T2, 2 // 9) + + a31 = convert(T, 2366667076620 // 8822750406821) + a32 = a31 + a41 = -convert(T, 257962897183 // 4451812247028) + a42 = a41 + a43 = convert(T, 128530224461 // 14379561246022) + a51 = -convert(T, 486229321650 // 11227943450093) + a52 = a51 + a53 = -convert(T, 225633144460 // 6633558740617) + a54 = convert(T, 1741320951451 // 6824444397158) + a61 = convert(T, 621307788657 // 4714163060173) + a62 = a61 + a63 = -convert(T, 125196015625 // 3866852212004) + a64 = convert(T, 940440206406 // 7593089888465) + a65 = convert(T, 961109811699 // 6734810228204) + a71 = convert(T, 2036305566805 // 6583108094622) + a72 = a71 + a73 = -convert(T, 3039402635899 // 4450598839912) + a74 = -convert(T, 1829510709469 // 31102090912115) + a75 = -convert(T, 286320471013 // 6931253422520) + a76 = convert(T, 8651533662697 // 9642993110008) + a83 = convert(T, 3517720773327 // 20256071687669) + a84 = convert(T, 4569610470461 // 17934693873752) + a85 = convert(T, 2819471173109 // 11655438449929) + a86 = convert(T, 3296210113763 // 10722700128969) + a87 = -convert(T, 1142099968913 // 5710983926999) + + btilde3 = -convert(T, 18652552508630163520943320 // 168134443655105334713783643) + btilde4 = convert(T, 141161430501477620145807 // 319735394533244397237135736) + btilde5 = -convert(T, 207757214437709595456056 // 72283007456311581445415925) + btilde6 = convert(T, 13674542533282477231637762 // 149163814411398370516486131) + btilde7 = convert(T, 11939168497868428048898551 // 210101209758476969753215083) + btilde8 = -convert(T, 1815023333875 // 51666766064334) + + c3 = convert(T2, 6456083330201 // 8509243623797) + c4 = convert(T2, 1632083962415 // 14158861528103) + c5 = convert(T2, 6365430648612 // 17842476412687) + c6 = convert(T2, 18 // 25) + c7 = convert(T2, 191 // 200) + + α31 = -convert(T2, 796131459065721 // 1125899906842624) + α32 = convert(T2, 961015682954173 // 562949953421312) + α41 = convert(T2, 3335563016633385 // 4503599627370496) + α42 = convert(T2, 2336073221474223 // 9007199254740992) + α51 = convert(T2, 1777088537295433 // 9007199254740992) + α52 = convert(T2, 7230110717445555 // 9007199254740992) + α61 = convert(T2, 305461594360167 // 36028797018963968) + α62 = convert(T2, 3700851199347703 // 36028797018963968) + α63 = convert(T2, 8005621056314023 // 9007199254740992) + α71 = convert(T2, 247009276011491 // 9007199254740992) + α72 = -convert(T2, 6222030107065861 // 9007199254740992) + α73 = convert(T2, 1872777510724421 // 1125899906842624) + α81 = convert(T2, 180631849429283 // 36028797018963968) + α82 = -convert(T2, 3454740038041085 // 36028797018963968) + α83 = convert(T2, 476708848972457 // 2251799813685248) + α84 = convert(T2, 5255799236757313 // 288230376151711744) + α85 = convert(T2, 3690914796734375 // 288230376151711744) + α86 = -convert(T2, 5010195363762467 // 18014398509481984) + α87 = convert(T2, 5072201887169367 // 4503599627370496) + + ea21 = convert(T, 4 // 9) + ea31 = convert(T, 1 // 9) + ea32 = convert(T, 1183333538310 // 1827251437969) + ea41 = convert(T, 895379019517 // 9750411845327) + ea42 = convert(T, 477606656805 // 13473228687314) + ea43 = -convert(T, 112564739183 // 9373365219272) + ea51 = -convert(T, 4458043123994 // 13015289567637) + ea52 = -convert(T, 2500665203865 // 9342069639922) + ea53 = convert(T, 983347055801 // 8893519644487) + ea54 = convert(T, 2185051477207 // 2551468980502) + ea61 = -convert(T, 167316361917 // 17121522574472) + ea62 = convert(T, 1605541814917 // 7619724128744) + ea63 = convert(T, 991021770328 // 13052792161721) + ea64 = convert(T, 2342280609577 // 11279663441611) + ea65 = convert(T, 3012424348531 // 12792462456678) + ea71 = convert(T, 6680998715867 // 14310383562358) + ea72 = convert(T, 5029118570809 // 3897454228471) + ea73 = convert(T, 2415062538259 // 6382199904604) + ea74 = -convert(T, 3924368632305 // 6964820224454) + ea75 = -convert(T, 4331110370267 // 15021686902756) + ea76 = -convert(T, 3944303808049 // 11994238218192) + ea81 = convert(T, 2193717860234 // 3570523412979) + ea82 = convert(T, 2193717860234 // 3570523412979) + ea83 = convert(T, 5952760925747 // 18750164281544) + ea84 = -convert(T, 4412967128996 // 6196664114337) + ea85 = convert(T, 4151782504231 // 36106512998704) + ea86 = convert(T, 572599549169 // 6265429158920) + ea87 = -convert(T, 457874356192 // 11306498036315) + + eb3 = convert(T, 3517720773327 // 20256071687669) + eb4 = convert(T, 4569610470461 // 17934693873752) + eb5 = convert(T, 2819471173109 // 11655438449929) + eb6 = convert(T, 3296210113763 // 10722700128969) + eb7 = -convert(T, 1142099968913 // 5710983926999) + eb8 = convert(T, 2 // 9) + + ebtilde3 = -convert(T, 18652552508630163520943320 // 168134443655105334713783643) + ebtilde4 = convert(T, 141161430501477620145807 // 319735394533244397237135736) + ebtilde5 = -convert(T, 207757214437709595456056 // 72283007456311581445415925) + ebtilde6 = convert(T, 13674542533282477231637762 // 149163814411398370516486131) + ebtilde7 = convert(T, 11939168497868428048898551 // 210101209758476969753215083) + ebtilde8 = -convert(T, 1815023333875 // 51666766064334) + + s = 8 + Ai = zeros(T, s, s) + Ai[2, 1] = convert(T, γ) + Ai[2, 2] = convert(T, γ) + Ai[3, 1] = a31 + Ai[3, 2] = a32 + Ai[3, 3] = convert(T, γ) + Ai[4, 1] = a41 + Ai[4, 2] = a42 + Ai[4, 3] = a43 + Ai[4, 4] = convert(T, γ) + Ai[5, 1] = a51 + Ai[5, 2] = a52 + Ai[5, 3] = a53 + Ai[5, 4] = a54 + Ai[5, 5] = convert(T, γ) + Ai[6, 1] = a61 + Ai[6, 2] = a62 + Ai[6, 3] = a63 + Ai[6, 4] = a64 + Ai[6, 5] = a65 + Ai[6, 6] = convert(T, γ) + Ai[7, 1] = a71 + Ai[7, 2] = a72 + Ai[7, 3] = a73 + Ai[7, 4] = a74 + Ai[7, 5] = a75 + Ai[7, 6] = a76 + Ai[7, 7] = convert(T, γ) + Ai[8, 3] = a83 + Ai[8, 4] = a84 + Ai[8, 5] = a85 + Ai[8, 6] = a86 + Ai[8, 7] = a87 + Ai[8, 8] = convert(T, γ) + + bi_vec = zeros(T, s) + bi_vec[3] = a83 + bi_vec[4] = a84 + bi_vec[5] = a85 + bi_vec[6] = a86 + bi_vec[7] = a87 + bi_vec[8] = convert(T, γ) + + Ae = zeros(T, s, s) + Ae[2, 1] = ea21 + Ae[3, 1] = ea31 + Ae[3, 2] = ea32 + Ae[4, 1] = ea41 + Ae[4, 2] = ea42 + Ae[4, 3] = ea43 + Ae[5, 1] = ea51 + Ae[5, 2] = ea52 + Ae[5, 3] = ea53 + Ae[5, 4] = ea54 + Ae[6, 1] = ea61 + Ae[6, 2] = ea62 + Ae[6, 3] = ea63 + Ae[6, 4] = ea64 + Ae[6, 5] = ea65 + Ae[7, 1] = ea71 + Ae[7, 2] = ea72 + Ae[7, 3] = ea73 + Ae[7, 4] = ea74 + Ae[7, 5] = ea75 + Ae[7, 6] = ea76 + Ae[8, 1] = ea81 + Ae[8, 2] = ea82 + Ae[8, 3] = ea83 + Ae[8, 4] = ea84 + Ae[8, 5] = ea85 + Ae[8, 6] = ea86 + Ae[8, 7] = ea87 + + be_vec = zeros(T, s) + be_vec[3] = eb3 + be_vec[4] = eb4 + be_vec[5] = eb5 + be_vec[6] = eb6 + be_vec[7] = eb7 + be_vec[8] = eb8 + + c_vec = zeros(T2, s) + c_vec[1] = zero(T2) + c_vec[2] = convert(T2, 2γ) + c_vec[3] = c3 + c_vec[4] = c4 + c_vec[5] = c5 + c_vec[6] = c6 + c_vec[7] = c7 + c_vec[8] = one(T2) + + btilde_vec = zeros(T, s) + btilde_vec[3] = btilde3 + btilde_vec[4] = btilde4 + btilde_vec[5] = btilde5 + btilde_vec[6] = btilde6 + btilde_vec[7] = btilde7 + btilde_vec[8] = btilde8 + + ebtilde_vec = zeros(T, s) + ebtilde_vec[3] = ebtilde3 + ebtilde_vec[4] = ebtilde4 + ebtilde_vec[5] = ebtilde5 + ebtilde_vec[6] = ebtilde6 + ebtilde_vec[7] = ebtilde7 + ebtilde_vec[8] = ebtilde8 + + α_mat = zeros(T2, s, s) + α_mat[3, 1] = α31 + α_mat[3, 2] = α32 + α_mat[4, 1] = α41 + α_mat[4, 2] = α42 + α_mat[5, 1] = α51 + α_mat[5, 2] = α52 + α_mat[6, 1] = α61 + α_mat[6, 2] = α62 + α_mat[6, 3] = α63 + α_mat[7, 1] = α71 + α_mat[7, 2] = α72 + α_mat[7, 3] = α73 + α_mat[8, 1] = α81 + α_mat[8, 2] = α82 + α_mat[8, 3] = α83 + α_mat[8, 4] = α84 + α_mat[8, 5] = α85 + α_mat[8, 6] = α86 + α_mat[8, 7] = α87 + + return ESDIRKIMEXTableau( + Ai, bi_vec, Ae, be_vec, c_vec, + btilde_vec, ebtilde_vec, α_mat, 5, s + ) +end + +# +# Helper to build an ESDIRKIMEXTableau for pure-implicit ESDIRK methods +# (no explicit tableau, no extrapolation guess) +# +function _pure_esdirk_to_imex_tableau(Ai_mat::Matrix{T}, c_vec::Vector{T2}, + btilde_vec::Vector{T}, order::Int) where {T, T2} + s = size(Ai_mat, 1) + bi_vec = Ai_mat[s, :] + Ae = zeros(T, s, s) + be_vec = zeros(T, s) + return ESDIRKIMEXTableau( + Ai_mat, bi_vec, Ae, be_vec, c_vec, + btilde_vec, nothing, nothing, order, s + ) +end + +# +# ESDIRK54I8L2SA IMEX Tableau (8 stages, order 5) +# +function ESDIRK54I8L2SAESDIRKIMEXTableau(T, T2) + tab = ESDIRK54I8L2SATableau(T, T2) + s = 8 + γ = convert(T, tab.γ) + + Ai = zeros(T, s, s) + Ai[2, 1] = γ + Ai[2, 2] = γ + Ai[3, 1] = tab.a31 + Ai[3, 2] = tab.a32 + Ai[3, 3] = γ + Ai[4, 1] = tab.a41 + Ai[4, 2] = tab.a42 + Ai[4, 3] = tab.a43 + Ai[4, 4] = γ + Ai[5, 1] = tab.a51 + Ai[5, 2] = tab.a52 + Ai[5, 3] = tab.a53 + Ai[5, 4] = tab.a54 + Ai[5, 5] = γ + Ai[6, 1] = tab.a61 + Ai[6, 2] = tab.a62 + Ai[6, 3] = tab.a63 + Ai[6, 4] = tab.a64 + Ai[6, 5] = tab.a65 + Ai[6, 6] = γ + Ai[7, 1] = tab.a71 + Ai[7, 2] = tab.a72 + Ai[7, 3] = tab.a73 + Ai[7, 4] = tab.a74 + Ai[7, 5] = tab.a75 + Ai[7, 6] = tab.a76 + Ai[7, 7] = γ + Ai[8, 1] = tab.a81 + Ai[8, 2] = tab.a82 + Ai[8, 3] = tab.a83 + Ai[8, 4] = tab.a84 + Ai[8, 5] = tab.a85 + Ai[8, 6] = tab.a86 + Ai[8, 7] = tab.a87 + Ai[8, 8] = γ + + c_vec = T2[zero(T2), convert(T2, 2) * convert(T2, tab.γ), + tab.c3, tab.c4, tab.c5, tab.c6, tab.c7, one(T2)] + + btilde_vec = T[tab.btilde1, tab.btilde2, tab.btilde3, tab.btilde4, + tab.btilde5, tab.btilde6, tab.btilde7, tab.btilde8] + + return _pure_esdirk_to_imex_tableau(Ai, c_vec, btilde_vec, 5) +end + +# +# ESDIRK436L2SA2 IMEX Tableau (6 stages, order 4) +# +function ESDIRK436L2SA2ESDIRKIMEXTableau(T, T2) + tab = ESDIRK436L2SA2Tableau(T, T2) + s = 6 + γ = tab.γ + + Ai = zeros(T, s, s) + Ai[2, 1] = γ + Ai[2, 2] = γ + Ai[3, 1] = tab.a31 + Ai[3, 2] = tab.a32 + Ai[3, 3] = γ + Ai[4, 1] = tab.a41 + Ai[4, 2] = tab.a42 + Ai[4, 3] = tab.a43 + Ai[4, 4] = γ + Ai[5, 1] = tab.a51 + Ai[5, 2] = tab.a52 + Ai[5, 3] = tab.a53 + Ai[5, 4] = tab.a54 + Ai[5, 5] = γ + Ai[6, 1] = tab.a61 + Ai[6, 2] = tab.a62 + Ai[6, 3] = tab.a63 + Ai[6, 4] = tab.a64 + Ai[6, 5] = tab.a65 + Ai[6, 6] = γ + + c_vec = T2[zero(T2), convert(T2, 2) * convert(T2, γ), + tab.c3, tab.c4, tab.c5, tab.c6] + + btilde_vec = T[tab.btilde1, tab.btilde2, tab.btilde3, + tab.btilde4, tab.btilde5, tab.btilde6] + + return _pure_esdirk_to_imex_tableau(Ai, c_vec, btilde_vec, 4) +end + +# +# ESDIRK437L2SA IMEX Tableau (7 stages, order 4) +# +function ESDIRK437L2SAESDIRKIMEXTableau(T, T2) + tab = ESDIRK437L2SATableau(T, T2) + s = 7 + γ = tab.γ + + Ai = zeros(T, s, s) + Ai[2, 1] = γ + Ai[2, 2] = γ + Ai[3, 1] = tab.a31 + Ai[3, 2] = tab.a32 + Ai[3, 3] = γ + Ai[4, 1] = tab.a41 + Ai[4, 2] = tab.a42 + Ai[4, 3] = tab.a43 + Ai[4, 4] = γ + Ai[5, 1] = tab.a51 + Ai[5, 2] = tab.a52 + Ai[5, 3] = tab.a53 + Ai[5, 4] = tab.a54 + Ai[5, 5] = γ + Ai[6, 1] = tab.a61 + Ai[6, 2] = tab.a62 + Ai[6, 3] = tab.a63 + Ai[6, 4] = tab.a64 + Ai[6, 5] = tab.a65 + Ai[6, 6] = γ + Ai[7, 1] = tab.a71 + Ai[7, 2] = tab.a72 + Ai[7, 3] = tab.a73 + Ai[7, 4] = tab.a74 + Ai[7, 5] = tab.a75 + Ai[7, 6] = tab.a76 + Ai[7, 7] = γ + + c_vec = T2[zero(T2), convert(T2, 2) * convert(T2, γ), + tab.c3, tab.c4, tab.c5, tab.c6, tab.c7] + + btilde_vec = T[tab.btilde1, tab.btilde2, tab.btilde3, + tab.btilde4, tab.btilde5, tab.btilde6, tab.btilde7] + + return _pure_esdirk_to_imex_tableau(Ai, c_vec, btilde_vec, 4) +end + +# +# ESDIRK547L2SA2 IMEX Tableau (7 stages, order 5) +# +function ESDIRK547L2SA2ESDIRKIMEXTableau(T, T2) + tab = ESDIRK547L2SA2Tableau(T, T2) + s = 7 + γ = tab.γ + + Ai = zeros(T, s, s) + Ai[2, 1] = γ + Ai[2, 2] = γ + Ai[3, 1] = tab.a31 + Ai[3, 2] = tab.a32 + Ai[3, 3] = γ + Ai[4, 1] = tab.a41 + Ai[4, 2] = tab.a42 + Ai[4, 3] = tab.a43 + Ai[4, 4] = γ + Ai[5, 1] = tab.a51 + Ai[5, 2] = tab.a52 + Ai[5, 3] = tab.a53 + Ai[5, 4] = tab.a54 + Ai[5, 5] = γ + Ai[6, 1] = tab.a61 + Ai[6, 2] = tab.a62 + Ai[6, 3] = tab.a63 + Ai[6, 4] = tab.a64 + Ai[6, 5] = tab.a65 + Ai[6, 6] = γ + Ai[7, 1] = tab.a71 + Ai[7, 2] = tab.a72 + Ai[7, 3] = tab.a73 + Ai[7, 4] = tab.a74 + Ai[7, 5] = tab.a75 + Ai[7, 6] = tab.a76 + Ai[7, 7] = γ + + c_vec = T2[zero(T2), convert(T2, 2) * convert(T2, γ), + tab.c3, tab.c4, tab.c5, tab.c6, tab.c7] + + btilde_vec = T[tab.btilde1, tab.btilde2, tab.btilde3, + tab.btilde4, tab.btilde5, tab.btilde6, tab.btilde7] + + return _pure_esdirk_to_imex_tableau(Ai, c_vec, btilde_vec, 5) +end + +# +# ESDIRK659L2SA IMEX Tableau (9 stages, order 6) +# Note: stage 9 has a91=a92=a93=0 (only depends on stages 4-8) +# +function ESDIRK659L2SAESDIRKIMEXTableau(T, T2) + tab = ESDIRK659L2SATableau(T, T2) + s = 9 + γ = tab.γ + + Ai = zeros(T, s, s) + Ai[2, 1] = γ + Ai[2, 2] = γ + Ai[3, 1] = tab.a31 + Ai[3, 2] = tab.a32 + Ai[3, 3] = γ + Ai[4, 1] = tab.a41 + Ai[4, 2] = tab.a42 + Ai[4, 3] = tab.a43 + Ai[4, 4] = γ + Ai[5, 1] = tab.a51 + Ai[5, 2] = tab.a52 + Ai[5, 3] = tab.a53 + Ai[5, 4] = tab.a54 + Ai[5, 5] = γ + Ai[6, 1] = tab.a61 + Ai[6, 2] = tab.a62 + Ai[6, 3] = tab.a63 + Ai[6, 4] = tab.a64 + Ai[6, 5] = tab.a65 + Ai[6, 6] = γ + Ai[7, 1] = tab.a71 + Ai[7, 2] = tab.a72 + Ai[7, 3] = tab.a73 + Ai[7, 4] = tab.a74 + Ai[7, 5] = tab.a75 + Ai[7, 6] = tab.a76 + Ai[7, 7] = γ + Ai[8, 1] = tab.a81 + Ai[8, 2] = tab.a82 + Ai[8, 3] = tab.a83 + Ai[8, 4] = tab.a84 + Ai[8, 5] = tab.a85 + Ai[8, 6] = tab.a86 + Ai[8, 7] = tab.a87 + Ai[8, 8] = γ + # Stage 9: a91=a92=a93=0 (zeros from initialization) + Ai[9, 4] = tab.a94 + Ai[9, 5] = tab.a95 + Ai[9, 6] = tab.a96 + Ai[9, 7] = tab.a97 + Ai[9, 8] = tab.a98 + Ai[9, 9] = γ + + c_vec = T2[zero(T2), convert(T2, 2) * convert(T2, γ), + tab.c3, tab.c4, tab.c5, tab.c6, tab.c7, tab.c8, tab.c9] + + btilde_vec = T[tab.btilde1, tab.btilde2, tab.btilde3, tab.btilde4, + tab.btilde5, tab.btilde6, tab.btilde7, tab.btilde8, tab.btilde9] + + return _pure_esdirk_to_imex_tableau(Ai, c_vec, btilde_vec, 6) +end diff --git a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl index 5ac57d8ac10..29ea4301f09 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl @@ -1,146 +1,3 @@ -mutable struct Kvaerno3ConstantCache{Tab, N} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::Kvaerno3, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Kvaerno3Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, 2tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - return Kvaerno3ConstantCache(nlsolver, tab) -end - -@cache mutable struct Kvaerno3Cache{uType, rateType, uNoUnitsType, Tab, N, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache( - alg::Kvaerno3, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Kvaerno3Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, 2tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - return Kvaerno3Cache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, atmp, nlsolver, tab, alg.step_limiter! - ) -end - -@cache mutable struct KenCarp3ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::KenCarp3, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp3Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - - return KenCarp3ConstantCache(nlsolver, tab) -end - -@cache mutable struct KenCarp3Cache{ - uType, rateType, uNoUnitsType, N, Tab, kType, StepLimiter, - } <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - k1::kType - k2::kType - k3::kType - k4::kType - atmp::uNoUnitsType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache( - alg::KenCarp3, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp3Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - k4 = zero(u) - else - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - uf = UJacobianWrapper(f, t, p) - end - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - return KenCarp3Cache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, k1, k2, - k3, k4, atmp, nlsolver, tab, alg.step_limiter! - ) -end - @cache mutable struct CFNLIRK3ConstantCache{N, Tab} <: SDIRKConstantCache nlsolver::N tab::Tab @@ -208,523 +65,3 @@ function alg_cache( return CFNLIRK3Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, k1, k2, k3, k4, atmp, nlsolver, tab) end - -@cache mutable struct Kvaerno4ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::Kvaerno4, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Kvaerno4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - return Kvaerno4ConstantCache(nlsolver, tab) -end - -@cache mutable struct Kvaerno4Cache{uType, rateType, uNoUnitsType, N, Tab, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache( - alg::Kvaerno4, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Kvaerno4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - return Kvaerno4Cache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, atmp, nlsolver, tab, alg.step_limiter! - ) -end - -@cache mutable struct KenCarp4ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::KenCarp4, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - return KenCarp4ConstantCache(nlsolver, tab) -end - -@cache mutable struct KenCarp4Cache{ - uType, rateType, uNoUnitsType, N, Tab, kType, StepLimiter, - } <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - k1::kType - k2::kType - k3::kType - k4::kType - k5::kType - k6::kType - atmp::uNoUnitsType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -@truncate_stacktrace KenCarp4Cache 1 - -function alg_cache( - alg::KenCarp4, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - k4 = zero(u) - k5 = zero(u) - k6 = zero(u) - else - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - uf = UJacobianWrapper(f, t, p) - end - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - return KenCarp4Cache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, k1, k2, k3, k4, k5, k6, atmp, - nlsolver, tab, alg.step_limiter! - ) -end - -@cache mutable struct Kvaerno5ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::Kvaerno5, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Kvaerno5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - - return Kvaerno5ConstantCache(nlsolver, tab) -end - -@cache mutable struct Kvaerno5Cache{uType, rateType, uNoUnitsType, N, Tab, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache( - alg::Kvaerno5, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = Kvaerno5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - return Kvaerno5Cache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, - z₇, atmp, nlsolver, tab, alg.step_limiter! - ) -end - -@cache mutable struct KenCarp5ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::KenCarp5, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - - return KenCarp5ConstantCache(nlsolver, tab) -end - -@cache mutable struct KenCarp5Cache{ - uType, rateType, uNoUnitsType, N, Tab, kType, StepLimiter, - } <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - z₈::uType - k1::kType - k2::kType - k3::kType - k4::kType - k5::kType - k6::kType - k7::kType - k8::kType - atmp::uNoUnitsType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache( - alg::KenCarp5, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - k4 = zero(u) - k5 = zero(u) - k6 = zero(u) - k7 = zero(u) - k8 = zero(u) - else - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - k7 = nothing - k8 = nothing - end - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = zero(u) - z₈ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - return KenCarp5Cache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, - k1, k2, k3, k4, k5, k6, k7, k8, atmp, nlsolver, tab, alg.step_limiter! - ) -end - -@cache mutable struct KenCarp47ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::KenCarp47, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp47Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - - return KenCarp47ConstantCache(nlsolver, tab) -end - -@cache mutable struct KenCarp47Cache{uType, rateType, uNoUnitsType, N, Tab, kType} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - k1::kType - k2::kType - k3::kType - k4::kType - k5::kType - k6::kType - k7::kType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end -@truncate_stacktrace KenCarp47Cache 1 - -function alg_cache( - alg::KenCarp47, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp47Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - k4 = zero(u) - k5 = zero(u) - k6 = zero(u) - k7 = zero(u) - else - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - k7 = nothing - end - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - return KenCarp47Cache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, - k1, k2, k3, k4, k5, k6, k7, atmp, nlsolver, tab - ) -end - -@cache mutable struct KenCarp58ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::KenCarp58, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp58Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - - return KenCarp58ConstantCache(nlsolver, tab) -end - -@cache mutable struct KenCarp58Cache{uType, rateType, uNoUnitsType, N, Tab, kType} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - z₈::uType - k1::kType - k2::kType - k3::kType - k4::kType - k5::kType - k6::kType - k7::kType - k8::kType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -@truncate_stacktrace KenCarp58Cache 1 - -function alg_cache( - alg::KenCarp58, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = KenCarp58Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c3 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - k4 = zero(u) - k5 = zero(u) - k6 = zero(u) - k7 = zero(u) - k8 = zero(u) - else - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - k7 = nothing - k8 = nothing - end - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = zero(u) - z₈ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - return KenCarp58Cache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, - k1, k2, k3, k4, k5, k6, k7, k8, atmp, nlsolver, tab - ) -end diff --git a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl index c4a8110d3e7..2210d91a711 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl @@ -1,168 +1,16 @@ @muladd function perform_step!( - integrator, cache::Kvaerno3ConstantCache, - repeat_step = false - ) - (; t, dt, uprev, u, f, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, btilde1, btilde2, btilde3, btilde4, c3, α31, α32) = cache.tab - alg = unwrap_alg(integrator, true) - - # calculate W - markfirststage!(nlsolver) - - # FSAL Step 1 - nlsolver.z = z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - nlsolver.z = z₂ = z₁ - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - # Guess is from Hermite derivative on z₁ and z₂ - nlsolver.z = z₃ = α31 * z₁ + α32 * z₂ - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = a31 * z₁ + a32 * z₂ + γ * z₃ # use yhat as prediction - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = 1 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₄ - - ################################### Finalize - - if integrator.opts.adaptive - tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals( - est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₄ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::Kvaerno3Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, atmp, nlsolver, step_limiter!) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, btilde1, btilde2, btilde3, btilde4, c3, α31, α32) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - # FSAL Step 1 - @.. broadcast = false z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - copyto!(z₂, z₁) - nlsolver.z = z₂ - - @.. broadcast = false tmp = uprev + γ * z₁ - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - # Guess is from Hermite derivative on z₁ and z₂ - @.. broadcast = false z₃ = α31 * z₁ + α32 * z₂ - nlsolver.z = z₃ - - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if cache isa Kvaerno3Cache - @.. broadcast = false z₄ = a31 * z₁ + a32 * z₂ + γ * z₃ # use yhat as prediction - elseif cache isa KenCarp3Cache - (; α41, α42) = cache.tab - @.. broadcast = false z₄ = α41 * z₁ + α42 * z₂ - end - nlsolver.z = z₄ - - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = 1 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast = false u = tmp + γ * z₄ - - step_limiter!(u, integrator, p, t + dt) - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast = false tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve( - integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est) - ) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!( - atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast = false integrator.fsallast = z₄ / dt -end - -@muladd function perform_step!( - integrator, cache::KenCarp3ConstantCache, + integrator, cache::CFNLIRK3ConstantCache, repeat_step = false ) (; t, dt, uprev, u, p) = integrator nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, btilde1, btilde2, btilde3, btilde4, c3, α31, α32, ea21, ea31, ea32, ea41, ea42, ea43, eb1, eb2, eb3, eb4, ebtilde1, ebtilde2, ebtilde3, ebtilde4) = cache.tab + (; γ, a31, a32, a41, a42, a43, c2, c3, ea21, ea31, ea32, ea41, ea42, ea43, eb1, eb2, eb3, eb4) = cache.tab alg = unwrap_alg(integrator, true) f2 = nothing k1 = nothing k2 = nothing k3 = nothing - k4 = nothing if integrator.f isa SplitFunction f = integrator.f.f1 f2 = integrator.f.f2 @@ -179,7 +27,7 @@ end if integrator.f isa SplitFunction # Explicit tableau is not FSAL # Make this not compute on repeat - z₁ = dt * f(uprev, p, t) + z₁ = dt .* f(uprev, p, t) else # FSAL Step 1 z₁ = dt * integrator.fsalfirst @@ -190,15 +38,15 @@ end # TODO: Add extrapolation for guess nlsolver.z = z₂ = z₁ - nlsolver.tmp = uprev + γ * z₁ + nlsolver.tmp = uprev if integrator.f isa SplitFunction # This assumes the implicit part is cheaper than the explicit part - k1 = dt * integrator.fsalfirst - z₁ + k1 = dt .* f2(uprev, p, t) nlsolver.tmp += ea21 * k1 end - nlsolver.c = 2γ + nlsolver.c = c2 z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return @@ -207,12 +55,11 @@ end if integrator.f isa SplitFunction z₃ = z₂ u = nlsolver.tmp + γ * z₂ - k2 = dt * f2(u, p, t + 2γdt) + k2 = dt * f2(u, p, t + c2 * dt) integrator.stats.nf2 += 1 tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 else - # Guess is from Hermite derivative on z₁ and z₂ - z₃ = α31 * z₁ + α32 * z₂ + z₃ = z₂ tmp = uprev + a31 * z₁ + a32 * z₂ end nlsolver.z = z₃ @@ -225,14 +72,13 @@ end ################################## Solve Step 4 if integrator.f isa SplitFunction - z₄ = z₂ + z₄ = z₃ u = nlsolver.tmp + γ * z₃ k3 = dt * f2(u, p, t + c3 * dt) integrator.stats.nf2 += 1 tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 else - (; α41, α42) = cache.tab - z₄ = α41 * z₁ + α42 * z₂ + z₄ = z₃ tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ end nlsolver.z = z₄ @@ -252,26 +98,6 @@ end ################################### Finalize - if integrator.opts.adaptive - if integrator.f isa SplitFunction - tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - ebtilde1 * k1 + ebtilde2 * k2 + ebtilde3 * k3 + ebtilde4 * k4 - else - tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ - end - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals( - est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - if integrator.f isa SplitFunction integrator.k[1] = integrator.fsalfirst integrator.fsallast = integrator.f(u, p, t + dt) @@ -284,13 +110,13 @@ end integrator.u = u end -@muladd function perform_step!(integrator, cache::KenCarp3Cache, repeat_step = false) +@muladd function perform_step!(integrator, cache::CFNLIRK3Cache, repeat_step = false) (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, k1, k2, k3, k4, atmp, nlsolver, step_limiter!) = cache + (; z₁, z₂, z₃, z₄, k1, k2, k3, k4, atmp, nlsolver) = cache (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, btilde1, btilde2, btilde3, btilde4, c3, α31, α32) = cache.tab + (; γ, a31, a32, a41, a42, a43, c2, c3) = cache.tab (; ea21, ea31, ea32, ea41, ea42, ea43, eb1, eb2, eb3, eb4) = cache.tab - (; ebtilde1, ebtilde2, ebtilde3, ebtilde4) = cache.tab + alg = unwrap_alg(integrator, true) f2 = nothing @@ -307,13 +133,11 @@ end markfirststage!(nlsolver) if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - # Explicit tableau is not FSAL - # Make this not compute on repeat f(z₁, integrator.uprev, p, integrator.t) z₁ .*= dt else # FSAL Step 1 - @.. broadcast = false z₁ = dt * integrator.fsalfirst + @..z₁ = dt * integrator.fsalfirst end ##### Step 2 @@ -322,15 +146,15 @@ end copyto!(z₂, z₁) nlsolver.z = z₂ - @.. broadcast = false tmp = uprev + γ * z₁ + @..tmp = uprev if integrator.f isa SplitFunction # This assumes the implicit part is cheaper than the explicit part - @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ - @.. broadcast = false tmp += ea21 * k1 + @..k1 = dt * integrator.fsalfirst - z₁ + @..tmp += ea21 * k1 end - nlsolver.c = 2γ + nlsolver.c = c2 z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return isnewton(nlsolver) && set_new_W!(nlsolver, false) @@ -339,15 +163,14 @@ end if integrator.f isa SplitFunction z₃ .= z₂ - @.. broadcast = false u = tmp + γ * z₂ - f2(k2, u, p, t + 2γdt) + @..u = tmp + γ * z₂ + f2(k2, u, p, t + c2 * dt) k2 .*= dt integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 + @..tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 else - # Guess is from Hermite derivative on z₁ and z₂ - @.. broadcast = false z₃ = α31 * z₁ + α32 * z₂ - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ + @..z₃ = z₂ + @..tmp = uprev + a31 * z₁ + a32 * z₂ end nlsolver.z = z₃ @@ -359,16 +182,15 @@ end if integrator.f isa SplitFunction z₄ .= z₂ - @.. broadcast = false u = tmp + γ * z₃ + @..u = tmp + γ * z₃ f2(k3, u, p, t + c3 * dt) k3 .*= dt integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + + @..tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 else - (; α41, α42) = cache.tab - @.. broadcast = false z₄ = α41 * z₁ + α42 * z₂ - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + @..z₄ = z₂ + @..tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ end nlsolver.z = z₄ @@ -376,2361 +198,18 @@ end z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - @.. broadcast = false u = tmp + γ * z₄ + @..u = tmp + γ * z₄ if integrator.f isa SplitFunction f2(k4, u, p, t + dt) k4 .*= dt integrator.stats.nf2 += 1 - @.. broadcast = false u = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ + eb1 * k1 + + @..u = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ + eb1 * k1 + eb2 * k2 + eb3 * k3 + eb4 * k4 end - step_limiter!(u, integrator, p, t + dt) - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - @.. broadcast = false tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + - btilde4 * z₄ + ebtilde1 * k1 + ebtilde2 * k2 + - ebtilde3 * k3 + ebtilde4 * k4 - else - @.. broadcast = false tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + - btilde4 * z₄ - end - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve( - integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est) - ) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!( - atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast = false integrator.fsallast = z₄ / dt - end -end - -@muladd function perform_step!( - integrator, cache::CFNLIRK3ConstantCache, - repeat_step = false - ) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, c2, c3, ea21, ea31, ea32, ea41, ea42, ea43, eb1, eb2, eb3, eb4) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - if integrator.f isa SplitFunction - # Explicit tableau is not FSAL - # Make this not compute on repeat - z₁ = dt .* f(uprev, p, t) - else - # FSAL Step 1 - z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation for guess - nlsolver.z = z₂ = z₁ - - nlsolver.tmp = uprev - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - k1 = dt .* f2(uprev, p, t) - nlsolver.tmp += ea21 * k1 - end - - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ = z₂ - u = nlsolver.tmp + γ * z₂ - k2 = dt * f2(u, p, t + c2 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - z₃ = z₂ - tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - nlsolver.tmp = tmp - nlsolver.c = c3 - - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ = z₃ - u = nlsolver.tmp + γ * z₃ - k3 = dt * f2(u, p, t + c3 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 - else - z₄ = z₃ - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - nlsolver.c = 1 - nlsolver.tmp = tmp - - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₄ - if integrator.f isa SplitFunction - k4 = dt * f2(u, p, t + dt) - integrator.stats.nf2 += 1 - u = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ + eb1 * k1 + eb2 * k2 + - eb3 * k3 + eb4 * k4 - end - - ################################### Finalize - - if integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₄ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::CFNLIRK3Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, k1, k2, k3, k4, atmp, nlsolver) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, c2, c3) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, eb1, eb2, eb3, eb4) = cache.tab - - alg = unwrap_alg(integrator, true) - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - f(z₁, integrator.uprev, p, integrator.t) - z₁ .*= dt - else - # FSAL Step 1 - @.. broadcast = false z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation for guess - copyto!(z₂, z₁) - nlsolver.z = z₂ - - @.. broadcast = false tmp = uprev - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ - @.. broadcast = false tmp += ea21 * k1 - end - - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ .= z₂ - @.. broadcast = false u = tmp + γ * z₂ - f2(k2, u, p, t + c2 * dt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - @.. broadcast = false z₃ = z₂ - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ .= z₂ - @.. broadcast = false u = tmp + γ * z₃ - f2(k3, u, p, t + c3 * dt) - k3 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + - ea42 * k2 + ea43 * k3 - else - @.. broadcast = false z₄ = z₂ - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - - nlsolver.c = 1 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast = false u = tmp + γ * z₄ - if integrator.f isa SplitFunction - f2(k4, u, p, t + dt) - k4 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false u = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ + eb1 * k1 + - eb2 * k2 + eb3 * k3 + eb4 * k4 - end - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast = false integrator.fsallast = z₄ / dt - end -end - -@muladd function perform_step!( - integrator, cache::Kvaerno4ConstantCache, - repeat_step = false - ) - (; t, dt, uprev, u, f, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, c3, c4) = cache.tab - (; α21, α31, α32, α41, α42) = cache.tab - (; btilde1, btilde2, btilde3, btilde4, btilde5) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = zero(u) - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = α31 * z₁ + α32 * z₂ - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = α41 * z₁ + α42 * z₂ - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use yhat2 for prediction - nlsolver.z = z₅ = a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = 1 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₄ - - ################################### Finalize - - if integrator.opts.adaptive - tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals( - est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₅ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::Kvaerno4Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, atmp, nlsolver, step_limiter!) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, c3, c4) = cache.tab - (; α21, α31, α32, α41, α42) = cache.tab - (; btilde1, btilde2, btilde3, btilde4, btilde5) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast = false z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Allow other choices here - z₂ .= zero(eltype(u)) - nlsolver.z = z₂ - - @.. broadcast = false tmp = uprev + γ * z₁ - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - @.. broadcast = false z₃ = α31 * z₁ + α32 * z₂ - nlsolver.z = z₃ - - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - @.. broadcast = false z₄ = α41 * z₁ + α42 * z₂ - nlsolver.z = z₄ - - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - # Use yhat prediction - @.. broadcast = false z₅ = a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ - nlsolver.z = z₅ - - @.. broadcast = false tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = 1 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast = false u = tmp + γ * z₅ - - step_limiter!(u, integrator, p, t + dt) - - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast = false tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve( - integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est) - ) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!( - atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast = false integrator.fsallast = z₅ / dt -end - -@muladd function perform_step!( - integrator, cache::KenCarp4ConstantCache, - repeat_step = false - ) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a63, a64, a65, c3, c4, c5) = cache.tab - (; α31, α32, α41, α42, α51, α52, α53, α54, α61, α62, α63, α64, α65) = cache.tab - (; btilde1, btilde3, btilde4, btilde5, btilde6) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, ea61, ea62, ea63, ea64, ea65) = cache.tab - (; eb1, eb3, eb4, eb5, eb6) = cache.tab - (; ebtilde1, ebtilde3, ebtilde4, ebtilde5, ebtilde6) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - if integrator.f isa SplitFunction - # Explicit tableau is not FSAL - # Make this not compute on repeat - z₁ = dt .* f(uprev, p, t) - else - # FSAL Step 1 - z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = z₁ - - tmp = uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - k1 = dt * integrator.fsalfirst - z₁ - tmp += ea21 * k1 - end - nlsolver.tmp = tmp - nlsolver.c = 2γ - - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ = z₂ - u = nlsolver.tmp + γ * z₂ - k2 = dt * f2(u, p, t + 2γdt) - integrator.stats.nf2 += 1 - tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - z₃ = α31 * z₁ + α32 * z₂ - tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - nlsolver.tmp = tmp - nlsolver.c = c3 - - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ = z₂ - u = nlsolver.tmp + γ * z₃ - k3 = dt * f2(u, p, t + c3 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 - else - z₄ = α41 * z₁ + α42 * z₂ - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - nlsolver.tmp = tmp - nlsolver.c = c4 - - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ = z₄ - u = nlsolver.tmp + γ * z₄ - k4 = dt * f2(u, p, t + c4 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ + ea51 * k1 + ea52 * k2 + - ea53 * k3 + ea54 * k4 - else - z₅ = α51 * z₁ + α52 * z₂ + α53 * z₃ + α54 * z₄ - tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - nlsolver.tmp = tmp - nlsolver.c = c5 - - u = nlsolver.tmp + γ * z₅ - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ = z₅ - u = nlsolver.tmp + γ * z₅ - k5 = dt * f2(u, p, t + c5 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + ea61 * k1 + ea62 * k2 + - ea63 * k3 + ea64 * k4 + ea65 * k5 - else - z₆ = α61 * z₁ + α62 * z₂ + α63 * z₃ + α64 * z₄ + α65 * z₅ - tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - nlsolver.tmp = tmp - nlsolver.c = 1 - - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₆ - if integrator.f isa SplitFunction - k6 = dt * f2(u, p, t + dt) - integrator.stats.nf2 += 1 - u = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + γ * z₆ + eb1 * k1 + - eb3 * k3 + eb4 * k4 + eb5 * k5 + eb6 * k6 - end - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - tmp = btilde1 * z₁ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + - ebtilde1 * k1 + ebtilde3 * k3 + ebtilde4 * k4 + ebtilde5 * k5 + - ebtilde6 * k6 - else - tmp = btilde1 * z₁ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ - end - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals( - est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₆ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::KenCarp4Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, atmp, nlsolver, step_limiter!) = cache - (; tmp) = nlsolver - (; k1, k2, k3, k4, k5, k6) = cache - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a63, a64, a65, c3, c4, c5) = cache.tab - (; α31, α32, α41, α42, α51, α52, α53, α54, α61, α62, α63, α64, α65) = cache.tab - (; btilde1, btilde3, btilde4, btilde5, btilde6) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, ea61, ea62, ea63, ea64, ea65) = cache.tab - (; eb1, eb3, eb4, eb5, eb6) = cache.tab - (; ebtilde1, ebtilde3, ebtilde4, ebtilde5, ebtilde6) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - # Explicit tableau is not FSAL - # Make this not compute on repeat - f(z₁, integrator.uprev, p, integrator.t) - z₁ .*= dt - else - # FSAL Step 1 - @.. broadcast = false z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Allow other choices here - copyto!(z₂, z₁) - nlsolver.z = z₂ - - @.. broadcast = false tmp = uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ - @.. broadcast = false tmp += ea21 * k1 - end - - nlsolver.c = 2γ - markfirststage!(nlsolver) - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ .= z₂ - @.. broadcast = false u = tmp + γ * z₂ - f2(k2, u, p, t + 2γdt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - @.. broadcast = false z₃ = α31 * z₁ + α32 * z₂ - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ .= z₂ - @.. broadcast = false u = tmp + γ * z₃ - f2(k3, u, p, t + c3 * dt) - k3 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + - ea42 * k2 + ea43 * k3 - else - @.. broadcast = false z₄ = α41 * z₁ + α42 * z₂ - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ .= z₄ - @.. broadcast = false u = tmp + γ * z₄ - f2(k4, u, p, t + c4 * dt) - k4 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ + - ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 - else - @.. broadcast = false z₅ = α51 * z₁ + α52 * z₂ + α53 * z₃ + α54 * z₄ - @.. broadcast = false tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ .= z₅ - @.. broadcast = false u = tmp + γ * z₅ - f2(k5, u, p, t + c5 * dt) - k5 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + - ea61 * k1 + ea62 * k2 + ea63 * k3 + ea64 * k4 + ea65 * k5 - else - @.. broadcast = false z₆ = α61 * z₁ + α62 * z₂ + α63 * z₃ + α64 * z₄ + α65 * z₅ - @.. broadcast = false tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - - nlsolver.c = 1 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast = false u = tmp + γ * z₆ - if integrator.f isa SplitFunction - f2(k6, u, p, t + dt) - k6 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false u = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + γ * z₆ + - eb1 * k1 + eb3 * k3 + eb4 * k4 + eb5 * k5 + eb6 * k6 - end - - step_limiter!(u, integrator, p, t + dt) - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - @.. broadcast = false tmp = btilde1 * z₁ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + btilde6 * z₆ + ebtilde1 * k1 + - ebtilde3 * k3 + ebtilde4 * k4 + ebtilde5 * k5 + - ebtilde6 * k6 - else - @.. broadcast = false tmp = btilde1 * z₁ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + btilde6 * z₆ - end - - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve( - integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est) - ) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!( - atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast = false integrator.fsallast = z₆ / dt - end -end - -@muladd function perform_step!( - integrator, cache::Kvaerno5ConstantCache, - repeat_step = false - ) - (; t, dt, uprev, u, f, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a63, a64, a65, a71, a73, a74, a75, a76, c3, c4, c5, c6) = cache.tab - (; btilde1, btilde3, btilde4, btilde5, btilde6, btilde7) = cache.tab - (; α31, α32, α41, α42, α43, α51, α52, α53, α61, α62, α63) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = z₁ - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = α31 * z₁ + α32 * z₂ - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = α41 * z₁ + α42 * z₂ + α43 * z₃ - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = z₅ = α51 * z₁ + α52 * z₂ + α53 * z₃ - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = z₆ = α61 * z₁ + α62 * z₂ + α63 * z₃ - - nlsolver.tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - # Prediction from embedding - nlsolver.z = z₇ = a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + γ * z₆ - - nlsolver.tmp = uprev + a71 * z₁ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = 1 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₇ - - ################################### Finalize - - if integrator.opts.adaptive - tmp = btilde1 * z₁ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + - btilde7 * z₇ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals( - est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₇ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::Kvaerno5Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, atmp, nlsolver, step_limiter!) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a63, a64, a65, a71, a73, a74, a75, a76, c3, c4, c5, c6) = cache.tab - (; btilde1, btilde3, btilde4, btilde5, btilde6, btilde7) = cache.tab - (; α31, α32, α41, α42, α43, α51, α52, α53, α61, α62, α63) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast = false z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Allow other choices here - copyto!(z₂, z₁) - nlsolver.z = z₂ - - @.. broadcast = false tmp = uprev + γ * z₁ - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - @.. broadcast = false z₃ = α31 * z₁ + α32 * z₂ - nlsolver.z = z₃ - - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - @.. broadcast = false z₄ = α41 * z₁ + α42 * z₂ + α43 * z₃ - nlsolver.z = z₄ - - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - @.. broadcast = false z₅ = α51 * z₁ + α52 * z₂ + α53 * z₃ - nlsolver.z = z₅ - - @.. broadcast = false tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - @.. broadcast = false z₆ = α61 * z₁ + α62 * z₂ + α63 * z₃ - nlsolver.z = z₆ - - @.. broadcast = false tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - # Prediction is embedded method - @.. broadcast = false z₇ = a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + γ * z₆ - nlsolver.z = z₇ - - @.. broadcast = false tmp = uprev + a71 * z₁ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = 1 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast = false u = tmp + γ * z₇ - - step_limiter!(u, integrator, p, t + dt) - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast = false tmp = btilde1 * z₁ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve( - integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est) - ) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!( - atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast = false integrator.fsallast = z₇ / dt -end - -@muladd function perform_step!( - integrator, cache::KenCarp5ConstantCache, - repeat_step = false - ) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a43, a51, a53, a54, a61, a63, a64, a65, a71, a73, a74, a75, a76, a81, a84, a85, a86, a87, c3, c4, c5, c6, c7) = cache.tab - (; α31, α32, α41, α42, α51, α52, α61, α62, α71, α72, α73, α74, α75, α81, α82, α83, α84, α85) = cache.tab - (; btilde1, btilde4, btilde5, btilde6, btilde7, btilde8) = cache.tab - (; ea21, ea31, ea32, ea41, ea43, ea51, ea53, ea54, ea61, ea63, ea64, ea65) = cache.tab - (; ea71, ea73, ea74, ea75, ea76, ea81, ea83, ea84, ea85, ea86, ea87) = cache.tab - (; eb1, eb4, eb5, eb6, eb7, eb8) = cache.tab - (; ebtilde1, ebtilde4, ebtilde5, ebtilde6, ebtilde7, ebtilde8) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - k7 = nothing - k8 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction - # Explicit tableau is not FSAL - # Make this not compute on repeat - z₁ = dt .* f(uprev, p, t) - else - # FSAL Step 1 - z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = z₁ - - tmp = uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - k1 = dt * integrator.fsalfirst - z₁ - tmp += ea21 * k1 - end - nlsolver.tmp = tmp - nlsolver.c = 2γ - - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ = z₂ - u = nlsolver.tmp + γ * z₂ - k2 = dt * f2(u, p, t + 2γdt) - integrator.stats.nf2 += 1 - tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - z₃ = α31 * z₁ + α32 * z₂ - tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - nlsolver.c = c3 - nlsolver.tmp = tmp - - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ = z₂ - u = nlsolver.tmp + γ * z₃ - k3 = dt * f2(u, p, t + c3 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a41 * z₁ + a43 * z₃ + ea41 * k1 + ea43 * k3 - else - z₄ = α41 * z₁ + α42 * z₂ - tmp = uprev + a41 * z₁ + a43 * z₃ - end - nlsolver.z = z₄ - nlsolver.c = c4 - nlsolver.tmp = tmp - - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ = z₂ - u = nlsolver.tmp + γ * z₄ - k4 = dt * f2(u, p, t + c4 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ + ea51 * k1 + ea53 * k3 + ea54 * k4 - else - z₅ = α51 * z₁ + α52 * z₂ - tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - nlsolver.c = c5 - nlsolver.tmp = tmp - - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ = z₃ - u = nlsolver.tmp + γ * z₅ - k5 = dt * f2(u, p, t + c5 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + ea61 * k1 + ea63 * k3 + - ea64 * k4 + ea65 * k5 - else - z₆ = α61 * z₁ + α62 * z₂ - tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - nlsolver.c = c6 - nlsolver.tmp = tmp - - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - if integrator.f isa SplitFunction - z₇ = z₂ - u = nlsolver.tmp + γ * z₆ - k6 = dt * f2(u, p, t + c6 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a71 * z₁ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ + ea71 * k1 + - ea73 * k3 + ea74 * k4 + ea75 * k5 + ea76 * k6 - else - z₇ = α71 * z₁ + α72 * z₂ + α73 * z₃ + α74 * z₄ + α75 * z₅ - tmp = uprev + a71 * z₁ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - end - nlsolver.z = z₇ - nlsolver.c = c7 - nlsolver.tmp = tmp - - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - if integrator.f isa SplitFunction - z₈ = z₅ - u = nlsolver.tmp + γ * z₇ - k7 = dt * f2(u, p, t + c7 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a81 * z₁ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ + ea81 * k1 + - ea83 * k3 + ea84 * k4 + ea85 * k5 + ea86 * k6 + ea87 * k7 - else - z₈ = α81 * z₁ + α82 * z₂ + α83 * z₃ + α84 * z₄ + α85 * z₅ - tmp = uprev + a81 * z₁ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ - end - nlsolver.z = z₈ - nlsolver.c = 1 - nlsolver.tmp = tmp - - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₈ - if integrator.f isa SplitFunction - k8 = dt * f2(u, p, t + dt) - integrator.stats.nf2 += 1 - u = uprev + a81 * z₁ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ + γ * z₈ + - eb1 * k1 + eb4 * k4 + eb5 * k5 + eb6 * k6 + eb7 * k7 + eb8 * k8 - end - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - tmp = btilde1 * z₁ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ + - btilde8 * z₈ + ebtilde1 * k1 + ebtilde4 * k4 + ebtilde5 * k5 + - ebtilde6 * k6 + ebtilde7 * k7 + ebtilde8 * k8 - else - tmp = btilde1 * z₁ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ + - btilde8 * z₈ - end - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals( - est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₈ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::KenCarp5Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, atmp, nlsolver, step_limiter!) = cache - (; k1, k2, k3, k4, k5, k6, k7, k8) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a43, a51, a53, a54, a61, a63, a64, a65, a71, a73, a74, a75, a76, a81, a84, a85, a86, a87, c3, c4, c5, c6, c7) = cache.tab - (; α31, α32, α41, α42, α51, α52, α61, α62, α71, α72, α73, α74, α75, α81, α82, α83, α84, α85) = cache.tab - (; btilde1, btilde4, btilde5, btilde6, btilde7, btilde8) = cache.tab - (; ea21, ea31, ea32, ea41, ea43, ea51, ea53, ea54, ea61, ea63, ea64, ea65) = cache.tab - (; ea71, ea73, ea74, ea75, ea76, ea81, ea83, ea84, ea85, ea86, ea87) = cache.tab - (; eb1, eb4, eb5, eb6, eb7, eb8) = cache.tab - (; ebtilde1, ebtilde4, ebtilde5, ebtilde6, ebtilde7, ebtilde8) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - # Explicit tableau is not FSAL - # Make this not compute on repeat - f(z₁, integrator.uprev, p, integrator.t) - z₁ .*= dt - else - # FSAL Step 1 - @.. broadcast = false z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Allow other choices here - copyto!(z₂, z₁) - nlsolver.z = z₂ - - @.. broadcast = false tmp = uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ - @.. broadcast = false tmp += ea21 * k1 - end - - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ .= z₂ - @.. broadcast = false u = tmp + γ * z₂ - f2(k2, u, p, t + 2γdt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - @.. broadcast = false z₃ = a31 * z₁ + α32 * z₂ - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ .= z₃ - @.. broadcast = false u = tmp + γ * z₃ - f2(k3, u, p, t + c3 * dt) - k3 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a41 * z₁ + a43 * z₃ + ea41 * k1 + ea43 * k3 - else - @.. broadcast = false z₄ = α41 * z₁ + α42 * z₂ - @.. broadcast = false tmp = uprev + a41 * z₁ + a43 * z₃ - end - nlsolver.z = z₄ - - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ .= z₂ - @.. broadcast = false u = tmp + γ * z₄ - f2(k4, u, p, t + c4 * dt) - k4 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ + ea51 * k1 + - ea53 * k3 + ea54 * k4 - else - @.. broadcast = false z₅ = α51 * z₁ + α52 * z₂ - @.. broadcast = false tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ .= z₃ - @.. broadcast = false u = tmp + γ * z₅ - f2(k5, u, p, t + c5 * dt) - k5 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ + - ea61 * k1 + ea63 * k3 + ea64 * k4 + ea65 * k5 - else - @.. broadcast = false z₆ = α61 * z₁ + α62 * z₂ - @.. broadcast = false tmp = uprev + a61 * z₁ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - if integrator.f isa SplitFunction - z₇ .= z₂ - @.. broadcast = false u = tmp + γ * z₆ - f2(k6, u, p, t + c6 * dt) - k6 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a71 * z₁ + a73 * z₃ + a74 * z₄ + a75 * z₅ + - a76 * z₆ + ea71 * k1 + ea73 * k3 + ea74 * k4 + ea75 * k5 + - ea76 * k6 - else - @.. broadcast = false z₇ = α71 * z₁ + α72 * z₂ + α73 * z₃ + α74 * z₄ + α75 * z₅ - @.. broadcast = false tmp = uprev + a71 * z₁ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - end - nlsolver.z = z₇ - - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - if integrator.f isa SplitFunction - z₈ .= z₅ - @.. broadcast = false u = tmp + γ * z₇ - f2(k7, u, p, t + c7 * dt) - k7 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a81 * z₁ + a84 * z₄ + a85 * z₅ + a86 * z₆ + - a87 * z₇ + ea81 * k1 + ea83 * k3 + ea84 * k4 + ea85 * k5 + - ea86 * k6 + ea87 * k7 - else - @.. broadcast = false z₈ = α81 * z₁ + α82 * z₂ + α83 * z₃ + α84 * z₄ + α85 * z₅ - @.. broadcast = false tmp = uprev + a81 * z₁ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ - end - nlsolver.z = z₈ - - nlsolver.c = 1 - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast = false u = tmp + γ * z₈ - if integrator.f isa SplitFunction - f2(k8, u, p, t + dt) - k8 .*= dt - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - @.. broadcast = false u = uprev + a81 * z₁ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ + - γ * z₈ + eb1 * k1 + eb4 * k4 + eb5 * k5 + eb6 * k6 + - eb7 * k7 + eb8 * k8 - end - - step_limiter!(u, integrator, p, t + dt) - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - @.. broadcast = false tmp = btilde1 * z₁ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ + - ebtilde1 * k1 + ebtilde4 * k4 + ebtilde5 * k5 + - ebtilde6 * k6 + ebtilde7 * k7 + ebtilde8 * k8 - else - @.. broadcast = false tmp = btilde1 * z₁ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ - end - - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve( - integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est) - ) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!( - atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast = false integrator.fsallast = z₈ / dt - end -end - -@muladd function perform_step!( - integrator, cache::KenCarp47ConstantCache, - repeat_step = false - ) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a73, a74, a75, a76, c3, c4, c5, c6) = cache.tab - (; α31, α32, α41, α42, α43, α51, α52, α61, α62, α63, α71, α72, α73, α74, α75, α76) = cache.tab - (; btilde3, btilde4, btilde5, btilde6, btilde7) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, ea61, ea62, ea63, ea64, ea65, ea71, ea72, ea73, ea74, ea75, ea76) = cache.tab - (; eb3, eb4, eb5, eb6, eb7) = cache.tab - (; ebtilde3, ebtilde4, ebtilde5, ebtilde6, ebtilde7) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - k7 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction - # Explicit tableau is not FSAL - # Make this not compute on repeat - z₁ = dt .* f(uprev, p, t) - else - # FSAL Step 1 - z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = z₁ - - tmp = uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - k1 = dt * integrator.fsalfirst - z₁ - tmp += ea21 * k1 - end - nlsolver.tmp = tmp - nlsolver.c = 2γ - - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ = z₂ - u = nlsolver.tmp + γ * z₂ - k2 = dt * f2(u, p, t + 2γdt) - integrator.stats.nf2 += 1 - tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - z₃ = α31 * z₁ + α32 * z₂ - tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - nlsolver.tmp = tmp - nlsolver.c = c3 - - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ = z₃ - u = nlsolver.tmp + γ * z₃ - k3 = dt * f2(u, p, t + c3 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 - else - z₄ = α41 * z₁ + α42 * z₂ + α43 * z₃ - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - nlsolver.tmp = tmp - nlsolver.c = c4 - - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ = z₁ - u = nlsolver.tmp + γ * z₄ - k4 = dt * f2(u, p, t + c4 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ + ea51 * k1 + ea52 * k2 + - ea53 * k3 + ea54 * k4 - else - z₅ = α51 * z₁ + α52 * z₂ - tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - nlsolver.tmp = tmp - nlsolver.c = c5 - - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ = z₃ - u = nlsolver.tmp + γ * z₅ - k5 = dt * f2(u, p, t + c5 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ + ea61 * k1 + - ea62 * k2 + ea63 * k3 + ea64 * k4 + ea65 * k5 - else - z₆ = α61 * z₁ + α62 * z₂ + α63 * z₃ - tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - nlsolver.tmp = tmp - nlsolver.c = c6 - - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - if integrator.f isa SplitFunction - z₇ = z₆ - u = nlsolver.tmp + γ * z₆ - k6 = dt * f2(u, p, t + c6 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ + ea71 * k1 + ea72 * k2 + - ea73 * k3 + ea74 * k4 + ea75 * k5 + ea76 * k6 - else - z₇ = α71 * z₁ + α72 * z₂ + α73 * z₃ + α74 * z₄ + α75 * z₅ + +α76 * z₆ - tmp = uprev + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - end - nlsolver.z = z₇ - nlsolver.c = 1 - nlsolver.tmp = tmp - - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₇ - if integrator.f isa SplitFunction - k7 = dt * f2(u, p, t + dt) - u = uprev + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ + γ * z₇ + eb3 * k3 + - eb4 * k4 + eb5 * k5 + eb6 * k6 + eb7 * k7 - end - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - tmp = btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ + - ebtilde3 * k3 + ebtilde4 * k4 + ebtilde5 * k5 + ebtilde6 * k6 + - ebtilde7 * k7 - else - tmp = btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ - end - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals( - est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₇ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::KenCarp47Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, atmp, nlsolver) = cache - (; k1, k2, k3, k4, k5, k6, k7) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a73, a74, a75, a76, c3, c4, c5, c6) = cache.tab - (; α31, α32, α41, α42, α43, α51, α52, α61, α62, α63, α71, α72, α73, α74, α75, α76) = cache.tab - (; btilde3, btilde4, btilde5, btilde6, btilde7) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, ea61, ea62, ea63, ea64, ea65, ea71, ea72, ea73, ea74, ea75, ea76) = cache.tab - (; eb3, eb4, eb5, eb6, eb7) = cache.tab - (; ebtilde3, ebtilde4, ebtilde5, ebtilde6, ebtilde7) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - # Explicit tableau is not FSAL - # Make this not compute on repeat - f(z₁, integrator.uprev, p, integrator.t) - z₁ .*= dt - else - # FSAL Step 1 - @.. broadcast = false z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Allow other choices here - z₂ .= z₁ - nlsolver.z = z₂ - - @.. broadcast = false tmp = uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ - @.. broadcast = false tmp += ea21 * k1 - end - - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ .= z₂ - @.. broadcast = false u = tmp + γ * z₂ - f2(k2, u, p, t + 2γdt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - #Guess is from Hermite derivative on z₁ and z₂ - @.. broadcast = false z₃ = a31 * z₁ + α32 * z₂ - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ .= z₃ - @.. broadcast = false u = tmp + γ * z₃ - f2(k3, u, p, t + c3 * dt) - k3 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + - ea42 * k2 + ea43 * k3 - else - @.. broadcast = false z₄ = α41 * z₁ + α42 * z₂ + α43 * z₃ - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ .= z₁ - @.. broadcast = false u = tmp + γ * z₄ - f2(k4, u, p, t + c4 * dt) - k4 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ + - ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 - else - @.. broadcast = false z₅ = α51 * z₁ + α52 * z₂ - @.. broadcast = false tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ .= z₃ - @.. broadcast = false u = tmp + γ * z₅ - f2(k5, u, p, t + c5 * dt) - k5 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + - a65 * z₅ + ea61 * k1 + ea62 * k2 + ea63 * k3 + ea64 * k4 + - ea65 * k5 - else - @.. broadcast = false z₆ = α61 * z₁ + α62 * z₂ + α63 * z₃ - @.. broadcast = false tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - if integrator.f isa SplitFunction - z₇ .= z₆ - @.. broadcast = false u = tmp + γ * z₆ - f2(k6, u, p, t + c6 * dt) - k6 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ + - ea71 * k1 + ea72 * k2 + ea73 * k3 + ea74 * k4 + ea75 * k5 + - ea76 * k6 - else - @.. broadcast = false z₇ = α71 * z₁ + α72 * z₂ + α73 * z₃ + α74 * z₄ + α75 * z₅ + - α76 * z₆ - @.. broadcast = false tmp = uprev + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - end - nlsolver.z = z₇ - - nlsolver.c = 1 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast = false u = tmp + γ * z₇ - if integrator.f isa SplitFunction - f2(k7, u, p, t + dt) - k7 .*= dt - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - @.. broadcast = false u = uprev + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ + γ * z₇ + - eb3 * k3 + eb4 * k4 + eb5 * k5 + eb6 * k6 + eb7 * k7 - end - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - @.. broadcast = false tmp = btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + ebtilde3 * k3 + - ebtilde4 * k4 + ebtilde5 * k5 + ebtilde6 * k6 + - ebtilde7 * k7 - else - @.. broadcast = false tmp = btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ - end - - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve( - integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est) - ) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!( - atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast = false integrator.fsallast = z₇ / dt - end -end - -@muladd function perform_step!( - integrator, cache::KenCarp58ConstantCache, - repeat_step = false - ) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, a83, a84, a85, a86, a87, c3, c4, c5, c6, c7) = cache.tab - (; α31, α32, α41, α42, α51, α52, α61, α62, α63, α71, α72, α73, α81, α82, α83, α84, α85, α86, α87) = cache.tab - (; btilde3, btilde4, btilde5, btilde6, btilde7, btilde8) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, ea61, ea62, ea63, ea64, ea65) = cache.tab - (; ea71, ea72, ea73, ea74, ea75, ea76, ea81, ea82, ea83, ea84, ea85, ea86, ea87) = cache.tab - (; eb3, eb4, eb5, eb6, eb7, eb8) = cache.tab - (; ebtilde3, ebtilde4, ebtilde5, ebtilde6, ebtilde7, ebtilde8) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - k6 = nothing - k7 = nothing - k8 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - # calculate W - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction - # Explicit tableau is not FSAL - # Make this not compute on repeat - z₁ = dt .* f(uprev, p, t) - else - # FSAL Step 1 - z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Add extrapolation choice - - nlsolver.z = z₂ = z₁ - - tmp = uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - k1 = dt * integrator.fsalfirst - z₁ - tmp += ea21 * k1 - end - nlsolver.tmp = tmp - nlsolver.c = 2γ - - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ = z₂ - u = nlsolver.tmp + γ * z₂ - k2 = dt * f2(u, p, t + 2γdt) - integrator.stats.nf2 += 1 - tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - z₃ = α31 * z₁ + α32 * z₂ - tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - nlsolver.c = c3 - nlsolver.tmp = tmp - - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ = z₁ - u = nlsolver.tmp + γ * z₃ - k3 = dt * f2(u, p, t + c3 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 - else - z₄ = α41 * z₁ + α42 * z₂ - tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - nlsolver.c = c4 - nlsolver.tmp = tmp - - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ = z₂ - u = nlsolver.tmp + γ * z₄ - k4 = dt * f2(u, p, t + c4 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ + ea51 * k1 + ea52 * k2 + - ea53 * k3 + ea54 * k4 - else - z₅ = α51 * z₁ + α52 * z₂ - tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - nlsolver.c = c5 - nlsolver.tmp = tmp - - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ = z₃ - u = nlsolver.tmp + γ * z₅ - k5 = dt * f2(u, p, t + c5 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ + ea61 * k1 + - ea62 * k2 + ea63 * k3 + ea64 * k4 + ea65 * k5 - else - z₆ = α61 * z₁ + α62 * z₂ + α63 * z₃ - tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - nlsolver.c = c6 - nlsolver.tmp = tmp - - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - if integrator.f isa SplitFunction - z₇ = z₃ - u = nlsolver.tmp + γ * z₆ - k6 = dt * f2(u, p, t + c6 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ + - ea71 * k1 + ea72 * k2 + ea73 * k3 + ea74 * k4 + ea75 * k5 + ea76 * k6 - else - z₇ = α71 * z₁ + α72 * z₂ + α73 * z₃ - tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - end - nlsolver.z = z₇ - nlsolver.c = c7 - nlsolver.tmp = tmp - - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - if integrator.f isa SplitFunction - z₈ = z₇ - u = nlsolver.tmp + γ * z₇ - k7 = dt * f2(u, p, t + c7 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ + ea81 * k1 + - ea82 * k2 + ea83 * k3 + ea84 * k4 + ea85 * k5 + ea86 * k6 + ea87 * k7 - else - z₈ = α81 * z₁ + α82 * z₂ + α83 * z₃ + α84 * z₄ + α85 * z₅ + α86 * z₆ + α87 * z₇ - tmp = uprev + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ - end - nlsolver.z = z₈ - nlsolver.c = 1 - nlsolver.tmp = tmp - - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₈ - if integrator.f isa SplitFunction - k8 = dt * f2(u, p, t + dt) - integrator.stats.nf2 += 1 - u = uprev + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ + γ * z₈ + - eb3 * k3 + eb4 * k4 + eb5 * k5 + eb6 * k6 + eb7 * k7 + eb8 * k8 - end - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - tmp = btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ + - btilde8 * z₈ + ebtilde3 * k3 + ebtilde4 * k4 + ebtilde5 * k5 + - ebtilde6 * k6 + ebtilde7 * k7 + ebtilde8 * k8 - else - tmp = btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ + - btilde8 * z₈ - end - if isnewton(nlsolver) && alg.smooth_est # From Shampine - integrator.stats.nsolve += 1 - est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) - else - est = tmp - end - atmp = calculate_residuals( - est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - if integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₈ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::KenCarp58Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, atmp, nlsolver) = cache - (; k1, k2, k3, k4, k5, k6, k7, k8) = cache - (; tmp) = nlsolver - (; γ, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, a83, a84, a85, a86, a87, c3, c4, c5, c6, c7) = cache.tab - (; α31, α32, α41, α42, α51, α52, α61, α62, α63, α71, α72, α73, α81, α82, α83, α84, α85, α86, α87) = cache.tab - (; btilde3, btilde4, btilde5, btilde6, btilde7, btilde8) = cache.tab - (; ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, ea61, ea62, ea63, ea64, ea65) = cache.tab - (; ea71, ea72, ea73, ea74, ea75, ea76, ea81, ea82, ea83, ea84, ea85, ea86, ea87) = cache.tab - (; eb3, eb4, eb5, eb6, eb7, eb8) = cache.tab - (; ebtilde3, ebtilde4, ebtilde5, ebtilde6, ebtilde7, ebtilde8) = cache.tab - alg = unwrap_alg(integrator, true) - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - # precalculations - γdt = γ * dt - - markfirststage!(nlsolver) - - ##### Step 1 - - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - # Explicit tableau is not FSAL - # Make this not compute on repeat - f(z₁, integrator.uprev, p, integrator.t) - z₁ .*= dt - else - # FSAL Step 1 - @.. broadcast = false z₁ = dt * integrator.fsalfirst - end - - ##### Step 2 - - # TODO: Allow other choices here - z₂ .= z₁ - nlsolver.z = z₂ - - @.. broadcast = false tmp = uprev + γ * z₁ - - if integrator.f isa SplitFunction - # This assumes the implicit part is cheaper than the explicit part - @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ - @.. broadcast = false tmp += ea21 * k1 - end - - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - if integrator.f isa SplitFunction - z₃ .= z₂ - @.. broadcast = false u = tmp + γ * z₂ - f2(k2, u, p, t + 2γdt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - # Guess is from Hermite derivative on z₁ and z₂ - @.. broadcast = false z₃ = α31 * z₁ + α32 * z₂ - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ - end - nlsolver.z = z₃ - - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - if integrator.f isa SplitFunction - z₄ .= z₁ - @.. broadcast = false u = tmp + γ * z₃ - f2(k3, u, p, t + c3 * dt) - k3 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + - ea42 * k2 + ea43 * k3 - else - @.. broadcast = false z₄ = α41 * z₁ + α42 * z₂ - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ - - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - if integrator.f isa SplitFunction - z₅ .= z₂ - @.. broadcast = false u = tmp + γ * z₄ - f2(k4, u, p, t + c4 * dt) - k4 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ + - ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 - else - @.. broadcast = false z₅ = α51 * z₁ + α52 * z₂ - @.. broadcast = false tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ - - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - if integrator.f isa SplitFunction - z₆ .= z₃ - @.. broadcast = false u = tmp + γ * z₅ - f2(k5, u, p, t + c5 * dt) - k5 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + - a65 * z₅ + ea61 * k1 + ea62 * k2 + ea63 * k3 + ea64 * k4 + - ea65 * k5 - else - @.. broadcast = false z₆ = α61 * z₁ + α62 * z₂ + α63 * z₃ - @.. broadcast = false tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - end - nlsolver.z = z₆ - - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - if integrator.f isa SplitFunction - z₇ .= z₃ - @.. broadcast = false u = tmp + γ * z₆ - f2(k6, u, p, t + c6 * dt) - k6 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + - a75 * z₅ + a76 * z₆ + ea71 * k1 + ea72 * k2 + ea73 * k3 + - ea74 * k4 + ea75 * k5 + ea76 * k6 - else - @.. broadcast = false z₇ = α71 * z₁ + α72 * z₂ + α73 * z₃ - @.. broadcast = false tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + - a75 * z₅ + a76 * z₆ - end - nlsolver.z = z₇ - - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - if integrator.f isa SplitFunction - z₈ .= z₇ - @.. broadcast = false u = tmp + γ * z₇ - f2(k7, u, p, t + c7 * dt) - k7 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + - a87 * z₇ + ea81 * k1 + ea82 * k2 + ea83 * k3 + ea84 * k4 + - ea85 * k5 + ea86 * k6 + ea87 * k7 - else - @.. broadcast = false z₈ = α81 * z₁ + α82 * z₂ + α83 * z₃ + α84 * z₄ + α85 * z₅ + - α86 * z₆ + α87 * z₇ - @.. broadcast = false tmp = uprev + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ - end - nlsolver.z = z₈ - - nlsolver.c = 1 - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast = false u = tmp + γ * z₈ - if integrator.f isa SplitFunction - f2(k8, u, p, t + dt) - k8 .*= dt - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - @.. broadcast = false u = uprev + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + a87 * z₇ + - γ * z₈ + eb3 * k3 + eb4 * k4 + eb5 * k5 + eb6 * k6 + - eb7 * k7 + eb8 * k8 - end - - ################################### Finalize - - if integrator.opts.adaptive - if integrator.f isa SplitFunction - @.. broadcast = false tmp = btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ + - ebtilde3 * k3 + ebtilde4 * k4 + ebtilde5 * k5 + - ebtilde6 * k6 + ebtilde7 * k7 + ebtilde8 * k8 - else - @.. broadcast = false tmp = btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ - end - - if isnewton(nlsolver) && alg.smooth_est # From Shampine - est = nlsolver.cache.dz - - linres = dolinsolve( - integrator, nlsolver.cache.linsolve; b = _vec(tmp), - linu = _vec(est) - ) - - integrator.stats.nsolve += 1 - else - est = tmp - end - calculate_residuals!( - atmp, est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - if integrator.f isa SplitFunction integrator.f(integrator.fsallast, u, p, t + dt) else - @.. broadcast = false integrator.fsallast = z₈ / dt + @..integrator.fsallast = z₄ / dt end end diff --git a/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl b/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl index 978120e3b86..a039e07a03f 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl @@ -825,343 +825,3 @@ function alg_cache( return Hairer4Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, atmp, nlsolver, tab) end -@cache mutable struct ESDIRK54I8L2SACache{uType, rateType, uNoUnitsType, Tab, N} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - z₈::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::ESDIRK54I8L2SA, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK54I8L2SATableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = zero(u) - z₈ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - return ESDIRK54I8L2SACache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, atmp, nlsolver, - tab - ) -end - -mutable struct ESDIRK54I8L2SAConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::ESDIRK54I8L2SA, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK54I8L2SATableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - return ESDIRK54I8L2SAConstantCache(nlsolver, tab) -end - -@cache mutable struct ESDIRK436L2SA2Cache{uType, rateType, uNoUnitsType, Tab, N} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::ESDIRK436L2SA2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK436L2SA2Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - return ESDIRK436L2SA2Cache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, atmp, nlsolver, - tab - ) -end - -mutable struct ESDIRK436L2SA2ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::ESDIRK436L2SA2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK436L2SA2Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - return ESDIRK436L2SA2ConstantCache(nlsolver, tab) -end - -@cache mutable struct ESDIRK437L2SACache{uType, rateType, uNoUnitsType, Tab, N} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::ESDIRK437L2SA, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK437L2SATableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - return ESDIRK437L2SACache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, atmp, nlsolver, - tab - ) -end - -mutable struct ESDIRK437L2SAConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::ESDIRK437L2SA, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK437L2SATableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - return ESDIRK437L2SAConstantCache(nlsolver, tab) -end - -@cache mutable struct ESDIRK547L2SA2Cache{uType, rateType, uNoUnitsType, Tab, N} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::ESDIRK547L2SA2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, - ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK547L2SA2Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - return ESDIRK547L2SA2Cache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, atmp, nlsolver, - tab - ) -end - -mutable struct ESDIRK547L2SA2ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::ESDIRK547L2SA2, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK547L2SA2Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - return ESDIRK547L2SA2ConstantCache(nlsolver, tab) -end - -@cache mutable struct ESDIRK659L2SACache{uType, rateType, uNoUnitsType, Tab, N} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - z₆::uType - z₇::uType - z₈::uType - z₉::uType - atmp::uNoUnitsType - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::ESDIRK659L2SA, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, - dt, reltol, - p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK659L2SATableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = zero(u) - z₆ = zero(u) - z₇ = zero(u) - z₈ = zero(u) - z₉ = nlsolver.z - atmp = similar(u, uEltypeNoUnits) - recursivefill!(atmp, false) - - return ESDIRK659L2SACache( - u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, z₉, atmp, - nlsolver, tab - ) -end - -mutable struct ESDIRK659L2SAConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::ESDIRK659L2SA, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where - {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ESDIRK659L2SATableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - return ESDIRK659L2SAConstantCache(nlsolver, tab) -end diff --git a/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl index 283b8817a77..2fb65469656 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl @@ -2165,1068 +2165,3 @@ end @.. broadcast = false integrator.fsallast = z₅ / dt end -@muladd function perform_step!( - integrator, cache::ESDIRK54I8L2SAConstantCache, - repeat_step = false - ) - (; t, dt, uprev, u, f, p) = integrator - (; - γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - a81, a82, a83, a84, a85, a86, a87, - c3, c4, c5, c6, c7, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7, btilde8, - ) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - # TODO: Add extrapolation for guess - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = zero(z₁) - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = zero(z₂) - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = zero(z₃) - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = z₅ = zero(z₄) - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = z₆ = zero(z₅) - - nlsolver.tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = z₇ = zero(z₆) - - nlsolver.tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - nlsolver.z = z₈ = zero(z₇) - - nlsolver.tmp = uprev + a81 * z₁ + a82 * z₂ + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + - a87 * z₇ - nlsolver.c = 1 - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₈ - - ################################### Finalize - - if integrator.opts.adaptive - est = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ - atmp = calculate_residuals( - est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₈ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK54I8L2SACache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, atmp, nlsolver) = cache - (; tmp) = nlsolver - (; - γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - a81, a82, a83, a84, a85, a86, a87, - c3, c4, c5, c6, c7, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7, btilde8, - ) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast = false z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(u)) - nlsolver.z = z₂ - - @.. broadcast = false tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - nlsolver.z = fill!(z₃, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - nlsolver.z = fill!(z₄, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = fill!(z₅, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = fill!(z₆, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = fill!(z₇, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + - a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - nlsolver.z = fill!(z₈, zero(eltype(u))) - - @.. broadcast = false nlsolver.tmp = uprev + a81 * z₁ + a82 * z₂ + a83 * z₃ + a84 * z₄ + - a85 * z₅ + a86 * z₆ + a87 * z₇ - nlsolver.c = oneunit(nlsolver.c) - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast = false u = tmp + γ * z₈ - - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast = false tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ - calculate_residuals!( - atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast = false integrator.fsallast = z₈ / dt - return -end - -@muladd function perform_step!( - integrator, cache::ESDIRK436L2SA2ConstantCache, - repeat_step = false - ) - (; t, dt, uprev, u, f, p) = integrator - (; - γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - c3, c4, c5, c6, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, - ) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - # TODO: Add extrapolation for guess - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = zero(z₁) - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = zero(z₂) - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = zero(z₃) - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = z₅ = zero(z₄) - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = z₆ = zero(z₅) - - nlsolver.tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₆ - - ################################### Finalize - - if integrator.opts.adaptive - est = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ - atmp = calculate_residuals( - est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₆ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK436L2SA2Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, atmp, nlsolver) = cache - (; tmp) = nlsolver - (; - γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - c3, c4, c5, c6, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, - ) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast = false z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(u)) - nlsolver.z = z₂ - - @.. broadcast = false tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - nlsolver.z = fill!(z₃, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - nlsolver.z = fill!(z₄, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = fill!(z₅, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = fill!(z₆, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast = false u = tmp + γ * z₆ - - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast = false tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + btilde6 * z₆ - calculate_residuals!( - atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast = false integrator.fsallast = z₆ / dt - return -end - -@muladd function perform_step!( - integrator, cache::ESDIRK437L2SAConstantCache, - repeat_step = false - ) - (; t, dt, uprev, u, f, p) = integrator - (; - γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - c3, c4, c5, c6, c7, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7, - ) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - # TODO: Add extrapolation for guess - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = zero(z₁) - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = zero(z₂) - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = zero(z₃) - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = z₅ = zero(z₄) - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = z₆ = zero(z₅) - - nlsolver.tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = z₇ = zero(z₆) - - nlsolver.tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₇ - - ################################### Finalize - - if integrator.opts.adaptive - est = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ - atmp = calculate_residuals( - est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₇ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK437L2SACache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, atmp, nlsolver) = cache - (; tmp) = nlsolver - (; - γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - c3, c4, c5, c6, c7, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7, - ) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast = false z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(u)) - nlsolver.z = z₂ - - @.. broadcast = false tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - nlsolver.z = fill!(z₃, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - nlsolver.z = fill!(z₄, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = fill!(z₅, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = fill!(z₆, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = fill!(z₇, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + - a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast = false u = tmp + γ * z₇ - - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast = false tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ - calculate_residuals!( - atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast = false integrator.fsallast = z₇ / dt - return -end - -@muladd function perform_step!( - integrator, cache::ESDIRK547L2SA2ConstantCache, - repeat_step = false - ) - (; t, dt, uprev, u, f, p) = integrator - (; - γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - c3, c4, c5, c6, c7, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7, - ) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - # TODO: Add extrapolation for guess - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = zero(z₁) - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = zero(z₂) - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = zero(z₃) - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = z₅ = zero(z₄) - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = z₆ = zero(z₅) - - nlsolver.tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = z₇ = zero(z₆) - - nlsolver.tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₇ - - ################################### Finalize - - if integrator.opts.adaptive - est = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ - atmp = calculate_residuals( - est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₇ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK547L2SA2Cache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, atmp, nlsolver) = cache - (; tmp) = nlsolver - (; - γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - c3, c4, c5, c6, c7, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7, - ) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast = false z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(u)) - nlsolver.z = z₂ - - @.. broadcast = false tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - nlsolver.z = fill!(z₃, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - nlsolver.z = fill!(z₄, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = fill!(z₅, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = fill!(z₆, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = fill!(z₇, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + - a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast = false u = tmp + γ * z₇ - - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast = false tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + btilde6 * z₆ + btilde7 * z₇ - calculate_residuals!( - atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast = false integrator.fsallast = z₇ / dt - return -end - -@muladd function perform_step!( - integrator, cache::ESDIRK659L2SAConstantCache, - repeat_step = false - ) - (; t, dt, uprev, u, f, p) = integrator - (; - γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - a81, a82, a83, a84, a85, a86, a87, - a94, a95, a96, a97, a98, - c3, c4, c5, c6, c7, c8, c9, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7, btilde8, btilde9, - ) = cache.tab - nlsolver = cache.nlsolver - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - # TODO: Add extrapolation for guess - - ##### Step 1 - - z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation choice - nlsolver.z = z₂ = zero(z₁) - - nlsolver.tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 3 - - nlsolver.z = z₃ = zero(z₂) - - nlsolver.tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - nlsolver.z = z₄ = zero(z₃) - - nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = z₅ = zero(z₄) - - nlsolver.tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = z₆ = zero(z₅) - - nlsolver.tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = z₇ = zero(z₆) - - nlsolver.tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - nlsolver.z = z₈ = zero(z₇) - - nlsolver.tmp = uprev + a81 * z₁ + a82 * z₂ + a83 * z₃ + a84 * z₄ + a85 * z₅ + a86 * z₆ + - a87 * z₇ - nlsolver.c = c8 - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 9 - nlsolver.z = z₉ = zero(z₈) - - nlsolver.tmp = uprev + a94 * z₄ + a95 * z₅ + a96 * z₆ + a97 * z₇ + a98 * z₈ - nlsolver.c = c9 - z₉ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - u = nlsolver.tmp + γ * z₉ - - ################################### Finalize - - if integrator.opts.adaptive - est = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ + btilde9 * z₉ - atmp = calculate_residuals( - est, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - integrator.fsallast = z₉ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.u = u - return -end - -@muladd function perform_step!(integrator, cache::ESDIRK659L2SACache, repeat_step = false) - (; t, dt, uprev, u, f, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, z₆, z₇, z₈, z₉, atmp, nlsolver) = cache - (; tmp) = nlsolver - (; - γ, - a31, a32, - a41, a42, a43, - a51, a52, a53, a54, - a61, a62, a63, a64, a65, - a71, a72, a73, a74, a75, a76, - a81, a82, a83, a84, a85, a86, a87, - a94, a95, a96, a97, a98, - c3, c4, c5, c6, c7, c8, c9, - btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7, btilde8, btilde9, - ) = cache.tab - alg = unwrap_alg(integrator, true) - - # precalculations - γdt = γ * dt - markfirststage!(nlsolver) - - ##### Step 1 - - @.. broadcast = false z₁ = dt * integrator.fsalfirst - - ##### Step 2 - - # TODO: Add extrapolation for guess - z₂ .= zero(eltype(u)) - nlsolver.z = z₂ - - @.. broadcast = false tmp = uprev + γ * z₁ - nlsolver.c = 2γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - ################################## Solve Step 3 - - nlsolver.z = fill!(z₃, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a31 * z₁ + a32 * z₂ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 4 - - # Use constant z prediction - nlsolver.z = fill!(z₄, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 5 - - nlsolver.z = fill!(z₅, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a51 * z₁ + a52 * z₂ + a53 * z₃ + a54 * z₄ - nlsolver.c = c5 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 6 - - nlsolver.z = fill!(z₆, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a61 * z₁ + a62 * z₂ + a63 * z₃ + a64 * z₄ + a65 * z₅ - nlsolver.c = c6 - z₆ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 7 - - nlsolver.z = fill!(z₇, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a71 * z₁ + a72 * z₂ + a73 * z₃ + a74 * z₄ + a75 * z₅ + - a76 * z₆ - nlsolver.c = c7 - z₇ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 8 - - nlsolver.z = fill!(z₈, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a81 * z₁ + a82 * z₂ + a83 * z₃ + a84 * z₄ + a85 * z₅ + - a86 * z₆ + a87 * z₇ - nlsolver.c = c8 - z₈ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - ################################## Solve Step 9 - - nlsolver.z = fill!(z₉, zero(eltype(u))) - - @.. broadcast = false tmp = uprev + a94 * z₄ + a95 * z₅ + a96 * z₆ + a97 * z₇ + a98 * z₈ - nlsolver.c = c9 - z₉ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - @.. broadcast = false u = tmp + γ * z₉ - ################################### Finalize - - if integrator.opts.adaptive - @.. broadcast = false tmp = btilde1 * z₁ + btilde2 * z₂ + btilde3 * z₃ + btilde4 * z₄ + - btilde5 * z₅ + - btilde6 * z₆ + btilde7 * z₇ + btilde8 * z₈ + btilde9 * z₉ - calculate_residuals!( - atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end - - @.. broadcast = false integrator.fsallast = z₉ / dt - return -end