From 2aa615c9094916750ad5b48e95053e4e7abcb222 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 23 Feb 2026 15:07:41 -0600 Subject: [PATCH 01/15] add disco to core --- .../src/OrdinaryDiffEqCore.jl | 1 + lib/OrdinaryDiffEqCore/src/disco.jl | 77 ++++++ .../src/integrators/controllers.jl | 6 + .../src/integrators/type.jl | 1 + lib/OrdinaryDiffEqCore/src/solve.jl | 3 +- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 256 ++++++++++++++++++ 6 files changed, 343 insertions(+), 1 deletion(-) create mode 100644 lib/OrdinaryDiffEqCore/src/disco.jl create mode 100644 lib/OrdinaryDiffEqCore/test/disco_tests.jl diff --git a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl index fbfd8a743b7..29acc5f53c2 100644 --- a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl +++ b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl @@ -162,6 +162,7 @@ include("cache_utils.jl") include("initialize_dae.jl") include("perform_step/composite_perform_step.jl") +include("disco.jl") include("dense/generic_dense.jl") diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl new file mode 100644 index 00000000000..80bdfe143c4 --- /dev/null +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -0,0 +1,77 @@ +function set_discontinuity(u, uprev, integrator, cache) #need to pick algs to test + breakpointθ = find_discontinuity(u, uprev, integrator, cache) + dt = integrator.dt + t = integrator.t + if !isnan(breakpointθ) && 1e-6 < breakpointθ < 1.0 + #println("Discontinuity detected at t = ", t + breakpointθ * dt) + integrator.dt = breakpointθ * dt + integrator.disco_dt_set = true + end +end + +function find_discontinuity(u, uprev, integrator, cache) + cb = integrator.opts.callback + cb === nothing && return -1 + isempty(cb.continuous_callbacks) && return -1 + + disco_exists = false; + for i in cb.continuous_callbacks + if (i.is_discontinuity) + disco_exists = true + break + end + end + !disco_exists && return -1 + p = integrator.p + t = integrator.t + dt = integrator.dt + breakpointθ = -one(dt) + prob = nothing + for i in cb.continuous_callbacks + if (!(i.is_discontinuity)) + continue + end + out_prev = nothing + out_curr = nothing + is_inplace = DiffEqBase.isinplace(i.condition, 4) + if is_inplace + out_prev = similar(u) + i.condition(out_prev, uprev, t, integrator) + out_curr = similar(u) + i.condition(out_curr, u, t + dt, integrator) + is_inplace = true + else + out_prev = i.condition(uprev, t, integrator) + out_curr = i.condition(u, t + dt, integrator) + is_inplace = false + end + for (idx, (f0, f1)) in enumerate(zip(out_prev, out_curr)) + if (f0 * f1 < zero(f0)) + function zero_func(θ, p) + u₁ = similar(u) + _ode_interpolant!(u₁, θ, dt, uprev, u, integrator.k, cache, + nothing, Val{0}, nothing) + + if is_inplace + out = similar(u) + i.condition(out, u₁, t + θ * dt, integrator) + else + out = i.condition(u₁, t + θ * dt, integrator) + end + out[idx] + end + if prob === nothing + prob = IntervalNonlinearProblem(zero_func, [zero(dt), one(dt)], p) + else + prob = remake(prob; f=zero_func) + end + sol = solve(prob; bracket=[zero(dt), one(dt)]) + tmp = sol[] + if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) + breakpointθ = tmp + end + end + end + end + breakpointθ +end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index 5ad0824b66c..6325a9bb17c 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -948,6 +948,12 @@ end function step_reject_controller!(integrator, cache::PredictiveControllerCache, alg) (; dt, success_iter) = integrator (; qold) = cache + + if (integrator.disco_dt_set) + println("using fixed dt from discontinuity handling") + integrator.disco_dt_set = false + return integrator.dt + end return integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/type.jl b/lib/OrdinaryDiffEqCore/src/integrators/type.jl index d97636952bd..6234de5143b 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/type.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/type.jl @@ -108,6 +108,7 @@ mutable struct ODEIntegrator{ dtcache::tType dtchangeable::Bool dtpropose::tType + disco_dt_set::Bool tdir::tdirType eigen_est::eigenType EEst::EEstT diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index 976402705f8..b9ce19f7622 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -32,6 +32,7 @@ function SciMLBase.__init( save_everystep = isempty(saveat), save_on = true, save_discretes = true, + disco_dt_set = false, save_start = save_everystep || isempty(saveat) || saveat isa Number || prob.tspan[1] in saveat, save_end = nothing, @@ -667,7 +668,7 @@ function SciMLBase.__init( sol, u, du, k, t, tType(_dt), f, p, uprev, uprev2, duprev, tprev, _alg, dtcache, dtchangeable, - dtpropose, tdir, eigen_est, EEst, + dtpropose, disco_dt_set, tdir, eigen_est, EEst, # TODO vvv remove these QT(qoldinit), q11, erracc, dtacc, diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl new file mode 100644 index 00000000000..1a26db54328 --- /dev/null +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -0,0 +1,256 @@ +using OrdinaryDiffEqCore +using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra +using OrdinaryDiffEqRosenbrock +using OrdinaryDiffEqBDF + +#test example discontinuous at u = 1 +f(u, p, t) = u[1] < 1 ? [2u[1]] : [-3u[1] + 5] +u0 = [0.1] +tspan = (0.0, 1.5) +prob = ODEProblem(f, u0, tspan) + +#define callback +condition(u, t, integrator) = u[1] - 1 +function affect!(integrator) + integrator.u[1] += 10 + println("Callback fired at t = ", integrator.t) +end +cb = ContinuousCallback(condition, affect!; is_discontinuity = true) + +#disco solve +sol_disco = solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) +#fixed order solve +sol_no_disco = solve(prob, RadauIIA5(is_disco = false); callback = cb, reltol = 1e-6) + +rodas_no_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) + +rodas_disco = solve(prob, Rodas5P(is_disco = true); callback = cb, reltol = 1e-6) + +bdf_no_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) + +bdf_disco = solve(prob, FBDF(is_disco = true); callback = cb, reltol = 1e-6) + +#two discontinuity functions +function f(u, p, t) + if u[1] < 1 + [2u[1]] # region 1: grows to hit u = 1 + elseif u[1] < 2 + [u[1] + 0.2] # region 2: continues increasing to hit u = 2 + else + [-4u[1] + 12] # region 3: after 2, moves toward u ≈ 3 + end +end + +u0 = [0.1] +tspan = (0.0, 2.5) +prob = ODEProblem(f, u0, tspan) + +#define callbacks +condition1(u, t, integrator) = u[1] - 1 +function affect1!(integrator) + #println("Callback 1 fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb1 = ContinuousCallback(condition1, affect1!; is_discontinuity = true) + +condition2(u, t, integrator) = u[1] - 2 +function affect2!(integrator) + #println("Callback 2 fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb2 = ContinuousCallback(condition2, affect2!; is_discontinuity = true) +cb = CallbackSet(cb1, cb2) + +#disco solve +sol_disco = solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) +#fixed order solve +sol_no_disco = solve(prob, RadauIIA5(is_disco = false); callback = cb, reltol = 1e-6) + +rodas_no_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) + +rodas_disco = solve(prob, Rodas5P(is_disco = true); callback = cb, reltol = 1e-6) + +bdf_no_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) + +bdf_disco = solve(prob, FBDF(is_disco = true); callback = cb, reltol = 1e-6) + + +# multiple exponential regions with sharp transitions +function f_multi_exp!(du, u, p, t) + if u[1] < 0.3 + du[1] = 3 * exp(3 * u[1]) # very steep exponential + elseif u[1] < 0.8 + du[1] = exp(u[1]) # slower exponential + else + du[1] = u[1] # linear + end +end + +u0_multi = [0.05] +tspan_multi = (0.0, 1.5) +prob_multi = ODEProblem(f_multi_exp!, u0_multi, tspan_multi) + +#define callbacks +cond_multi_1(u, t, integrator) = u[1] - 0.3 +function affect_multi_1!(integrator) + println("Multi-exponential discontinuity 1 callback fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb_multi_1 = ContinuousCallback(cond_multi_1, affect_multi_1!; is_discontinuity = true) + +cond_multi_2(u, t, integrator) = u[1] - 0.8 +function affect_multi_2!(integrator) + println("Multi-exponential discontinuity 2 callback fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb_multi_2 = ContinuousCallback(cond_multi_2, affect_multi_2!; is_discontinuity = true) +cb_multi = CallbackSet(cb_multi_1, cb_multi_2) + +#disco solve +sol_disco = solve(prob_multi, RadauIIA5(is_disco = true); callback=cb_multi, reltol=1e-7, abstol=1e-9) +#fixed order solve +sol_no_disco = solve(prob_multi, RadauIIA5(is_disco = false); callback=cb_multi, reltol = 1e-7, abstol = 1e-9) + +# 2D system with exponential coupling and discontinuity +function f_2d_exp!(du, u, p, t) + if u[1] + u[2] < 1.0 + du[1] = 2 * exp(u[1]) - u[2] + du[2] = -3 * u[1] + 4 * exp(u[2]) + else + du[1] = u[1] + du[2] = u[2] + end +end + +u0_2d = [0.1, 0.2] +tspan_2d = (0.0, 2.0) +prob_2d = ODEProblem(f_2d_exp!, u0_2d, tspan_2d) + +#define callback +cond_2d(u, t, integrator) = u[1] + u[2] - 1.0 +function affect_2d!(integrator) + println("2D exponential discontinuity callback fired at t=$(integrator.t), u=$(integrator.u)") + @test 0.98 < integrator.u[1] + integrator.u[2] < 1.02 +end +cb_2d = ContinuousCallback(cond_2d, affect_2d!; is_discontinuity = true) + +#disco solve +sol_disco = solve(prob_2d, RadauIIA5(is_disco = true); callback=cb_2d, reltol=1e-8, abstol=1e-10) +#fixed order solve +sol_no_disco = solve(prob_2d, RadauIIA5(is_disco = false); callback=cb_2d, reltol = 1e-8, abstol = 1e-10) + +# very stiff discontinuous system +function f_stiff_disc!(du, u, p, t) + λ = p[1] # stiffness parameter + if u[1] < 0.5 + du[1] = -λ * u[1] + λ * exp(-t) # stiff decay with forcing + else + du[1] = u[1] + end +end + +u0_stiff = [0.1] +tspan_stiff = (0.0, 3.0) +prob_stiff = ODEProblem(f_stiff_disc!, u0_stiff, tspan_stiff, [100.0]) + +#define callback +cond_stiff(u, t, integrator) = u[1] - 0.5 +function affect_stiff!(integrator) + println("Stiff discontinuity callback fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb_stiff = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = true) + +#disco solve +sol_disco = solve(prob_stiff, RadauIIA5(is_disco = true); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +#fixed order solve +sol_no_disco = solve(prob_stiff, RadauIIA5(is_disco = false); callback=cb_stiff, reltol = 1e-9, abstol = 1e-11) + +# multiple discontinuities in very small range (1e-6 apart, 5 discontinuities) +function f_many_disc!(du, u, p, t) + du[1] = u[1] + 1 # simple linear growth +end + +u0_many = [0.0] +tspan_many = (0.0, 1.0) +prob_many = ODEProblem(f_many_disc!, u0_many, tspan_many) + +# create 5 discontinuities spaced 1e-6 apart +disc_values = [0.1 + i * 1e-6 for i = 0:4] + +# define callbacks for each discontinuity +cbs_many = [] +for (i, disc_val) in enumerate(disc_values) + local cond_func(u, t, integrator) = u[1] - disc_val + function affect_func!(integrator) + println("Dense discontinuity $i fired at t=$(integrator.t), u=$(integrator.u[1])") + end + push!(cbs_many, ContinuousCallback(cond_func, affect_func!; is_discontinuity = true)) +end +cb_many = CallbackSet(cbs_many...) + +#disco solve +sol_disco = solve(prob_many, RadauIIA5(is_disco = true); callback=cb_many, reltol=1e-10, abstol=1e-12) +#fixed order solve +sol_no_disco = solve(prob_many, RadauIIA5(is_disco = false); callback=cb_many, reltol=1e-10, abstol=1e-12) + +# discontinuous DAE with mass matrix +# System: M * du/dt = f(u, p, t) +# du[1]/dt = u[2] - u[1] +# 0 = u[1] + u[2] - 1 (algebraic constraint) +function f_dae_disc!(du, u, p, t) + if u[1] < 0.5 + du[1] = 2 * u[2] - u[1] + du[2] = u[1] + u[2] - 1 # algebraic constraint + else + du[1] = -u[1] + u[2] + du[2] = u[1] + u[2] - 1 # algebraic constraint + end +end + +u0_dae = [0.2, 0.8] # consistent with constraint u[1] + u[2] = 1 +tspan_dae = (0.0, 2.0) + +M_dae = [1.0 0.0; 0.0 0.0] + +f_dae_func = ODEFunction(f_dae_disc!; mass_matrix=M_dae) +prob_dae = ODEProblem(f_dae_func, u0_dae, tspan_dae) + +cond_dae(u, t, integrator) = u[1] - 0.5 +function affect_dae!(integrator) + #println("DAE discontinuity callback fired at t=$(integrator.t), u=$(integrator.u)") +end +cb_dae = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = true) + + +radau_no_disco = solve(prob_dae, RadauIIA5(is_disco = false); callback=cb_dae, reltol=1e-8, abstol=1e-10) + #83.500 μs (769 allocations: 35.72 KiB) +radau_disco = solve(prob_dae, RadauIIA5(is_disco = true); callback=cb_dae, reltol=1e-8, abstol=1e-10) + # 119.417 μs (1273 allocations: 55.42 KiB) +rodas_no_disco = solve(prob_dae, Rodas5P(); callback = cb_dae, reltol = 1e-6) +#= SciMLBase.DEStats +Number of function 1 evaluations: 312 +Number of function 2 evaluations: 0 +Number of W matrix evaluations: 34 +Number of linear solves: 272 +Number of Jacobians created: 19 +Number of nonlinear solver iterations: 0 +Number of nonlinear solver convergence failures: 0 +Number of fixed-point solver iterations: 0 +Number of fixed-point solver convergence failures: 0 +Number of rootfind condition calls: 213 +Number of accepted steps: 19 +Number of rejected steps: 15 =# +# 98.167 μs (550 allocations: 26.92 KiB) +rodas_disco = solve(prob_dae, Rodas5P(is_disco = true); callback = cb_dae, reltol = 1e-6) +#= SciMLBase.DEStats +Number of function 1 evaluations: 312 +Number of function 2 evaluations: 0 +Number of W matrix evaluations: 34 +Number of linear solves: 272 +Number of Jacobians created: 19 +Number of nonlinear solver iterations: 0 +Number of nonlinear solver convergence failures: 0 +Number of fixed-point solver iterations: 0 +Number of fixed-point solver convergence failures: 0 +Number of rootfind condition calls: 213 +Number of accepted steps: 19 +Number of rejected steps: 15 =# +# 97.541 μs (550 allocations: 26.92 KiB) +bdf_no_disco = solve(prob_dae, FBDF(); callback = cb_dae, reltol = 1e-6) +bdf_disco = solve(prob_dae, FBDF(is_disco = true); callback = cb_dae, reltol = 1e-6) \ No newline at end of file From f485bdc59b13261a5eb7e958e13be9f6eea3b8bf Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 23 Feb 2026 15:17:16 -0600 Subject: [PATCH 02/15] radau version --- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 5 ++--- lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl | 2 +- lib/OrdinaryDiffEqFIRK/src/algorithms.jl | 6 ++++-- lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl | 8 ++++++++ 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 1a26db54328..3b27868de8c 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,4 +1,3 @@ -using OrdinaryDiffEqCore using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra using OrdinaryDiffEqRosenbrock using OrdinaryDiffEqBDF @@ -13,15 +12,15 @@ prob = ODEProblem(f, u0, tspan) condition(u, t, integrator) = u[1] - 1 function affect!(integrator) integrator.u[1] += 10 - println("Callback fired at t = ", integrator.t) end cb = ContinuousCallback(condition, affect!; is_discontinuity = true) #disco solve sol_disco = solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) +# 291.292 μs (8449 allocations: 266.47 KiB) #fixed order solve sol_no_disco = solve(prob, RadauIIA5(is_disco = false); callback = cb, reltol = 1e-6) - +# 335.417 μs (10008 allocations: 311.08 KiB) rodas_no_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) rodas_disco = solve(prob, Rodas5P(is_disco = true); callback = cb, reltol = 1e-6) diff --git a/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl b/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl index a0cd04b69bf..d77dc456c45 100644 --- a/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl +++ b/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl @@ -18,7 +18,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, fac_default_gamma, get_current_adaptive_order, get_fsalfirstlast, isfirk, generic_solver_docstring, _bool_to_ADType, - _process_AD_choice, LinearAliasSpecifier + _process_AD_choice, LinearAliasSpecifier, set_discontinuity using MuladdMacro, DiffEqBase, RecursiveArrayTools, Polyester isfirk, generic_solver_docstring using SciMLOperators: AbstractSciMLOperator diff --git a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl index aa433958379..3ee4ecaa422 100644 --- a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl +++ b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl @@ -92,6 +92,7 @@ struct RadauIIA5{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <: controller::Symbol step_limiter!::StepLimiter autodiff::AD + is_disco::Bool end function RadauIIA5(; @@ -102,7 +103,7 @@ function RadauIIA5(; extrapolant = :dense, fast_convergence_cutoff = 1 // 5, new_W_γdt_cutoff = 1 // 5, controller = :Predictive, κ = nothing, maxiters = 10, smooth_est = true, - step_limiter! = trivial_limiter! + step_limiter! = trivial_limiter!, is_disco = false ) AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) @@ -122,7 +123,8 @@ function RadauIIA5(; new_W_γdt_cutoff, controller, step_limiter!, - AD_choice + AD_choice, + is_disco ) end diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl index 499b345d13c..ce535649b5a 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl @@ -675,6 +675,10 @@ end integrator.k[4] = z2 integrator.k[5] = z3 end + else + if alg.is_disco + set_discontinuity(u, uprev, integrator, cache) + end end integrator.fsallast = f(u, p, t + dt) @@ -952,6 +956,10 @@ end integrator.k[4] .= z2 integrator.k[5] .= z3 end + else + if alg.is_disco + set_discontinuity(u, uprev, integrator, cache) + end end f(fsallast, u, p, t + dt) From cbd2b1c789a9391b1ffab6760ecbf6e9221eb1d2 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 23 Feb 2026 15:50:14 -0600 Subject: [PATCH 03/15] bdf and rodas --- lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl | 2 +- lib/OrdinaryDiffEqBDF/src/algorithms.jl | 7 ++++--- lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl | 7 +++++++ lib/OrdinaryDiffEqCore/test/disco_tests.jl | 12 +++++++----- .../src/OrdinaryDiffEqRosenbrock.jl | 2 +- lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl | 7 ++++--- .../src/rosenbrock_perform_step.jl | 7 +++++++ 7 files changed, 31 insertions(+), 13 deletions(-) diff --git a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl index f6e404018da..50c5eb33fb8 100644 --- a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl +++ b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl @@ -23,7 +23,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, get_fsalfirstlast, generic_solver_docstring, _bool_to_ADType, _process_AD_choice, _ode_interpolant, _ode_interpolant!, has_stiff_interpolation, - _ode_addsteps!, DerivativeOrderNotPossibleError + _ode_addsteps!, DerivativeOrderNotPossibleError, set_discontinuity using OrdinaryDiffEqSDIRK: ImplicitEulerConstantCache, ImplicitEulerCache using TruncatedStacktraces: @truncate_stacktrace diff --git a/lib/OrdinaryDiffEqBDF/src/algorithms.jl b/lib/OrdinaryDiffEqBDF/src/algorithms.jl index ad0aba5f193..6aa128a6592 100644 --- a/lib/OrdinaryDiffEqBDF/src/algorithms.jl +++ b/lib/OrdinaryDiffEqBDF/src/algorithms.jl @@ -575,6 +575,7 @@ struct FBDF{MO, CS, AD, F, F2, P, FDT, ST, CJ, K, T, StepLimiter} <: controller::Symbol step_limiter!::StepLimiter autodiff::AD + is_disco::Bool end function FBDF(; @@ -583,7 +584,7 @@ function FBDF(; diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing, tol = nothing, - extrapolant = :linear, controller = :Standard, step_limiter! = trivial_limiter! + extrapolant = :linear, controller = :Standard, step_limiter! = trivial_limiter!, is_disco = false ) where {MO} AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) @@ -594,7 +595,7 @@ function FBDF(; typeof(κ), typeof(tol), typeof(step_limiter!), }( max_order, linsolve, nlsolve, precs, κ, tol, extrapolant, - controller, step_limiter!, AD_choice + controller, step_limiter!, AD_choice, is_disco ) end @@ -841,4 +842,4 @@ function DFBDF(; ) end -@truncate_stacktrace DFBDF +@truncate_stacktrace DFBDF \ No newline at end of file diff --git a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl index 59deace1b22..7e9cd3f3031 100644 --- a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl +++ b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl @@ -1276,6 +1276,9 @@ function perform_step!( integrator.opts.reltol, integrator.opts.internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) + if (integrator.EEst > one(integrator.EEst) && integrator.alg.is_disco) + set_discontinuity(u, uprev, integrator, cache) + end terk = estimate_terk(integrator, cache, k + 1, Val(max_order), u) fd_weights = calc_finite_difference_weights(ts_tmp, tdt, k, Val(max_order)) @@ -1483,6 +1486,10 @@ function perform_step!( internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) + if (integrator.EEst > one(integrator.EEst) && integrator.alg.is_disco) + set_discontinuity(u, uprev, integrator, cache) + end + estimate_terk!(integrator, cache, k + 1, Val(max_order)) calculate_residuals!( atmp, _vec(terk_tmp), _vec(uprev), _vec(u), abstol, reltol, diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 3b27868de8c..0f69ed6d3db 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -15,20 +15,22 @@ function affect!(integrator) end cb = ContinuousCallback(condition, affect!; is_discontinuity = true) -#disco solve sol_disco = solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) # 291.292 μs (8449 allocations: 266.47 KiB) -#fixed order solve sol_no_disco = solve(prob, RadauIIA5(is_disco = false); callback = cb, reltol = 1e-6) # 335.417 μs (10008 allocations: 311.08 KiB) rodas_no_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) - +# 410.291 μs (16828 allocations: 594.05 KiB) rodas_disco = solve(prob, Rodas5P(is_disco = true); callback = cb, reltol = 1e-6) - +# 483.792 μs (17729 allocations: 639.31 KiB) bdf_no_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) - +# 245.917 μs (20703 allocations: 665.16 KiB) bdf_disco = solve(prob, FBDF(is_disco = true); callback = cb, reltol = 1e-6) +# 269.333 μs (20477 allocations: 663.80 KiB) +@profview for i in 1:1000 + solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) +end #two discontinuity functions function f(u, p, t) if u[1] < 1 diff --git a/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl b/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl index 9ef1b32e2a7..7017e379fad 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl @@ -13,7 +13,7 @@ import OrdinaryDiffEqCore: alg_order, alg_adaptive_order, isWmethod, isfsal, _un calculate_residuals, has_stiff_interpolation, ODEIntegrator, resize_non_user_cache!, _ode_addsteps!, full_cache, DerivativeOrderNotPossibleError, _bool_to_ADType, - _process_AD_choice, LinearAliasSpecifier, copyat_or_push! + _process_AD_choice, LinearAliasSpecifier, copyat_or_push!, set_discontinuity using MuladdMacro, FastBroadcast, RecursiveArrayTools import MacroTools: namify using MacroTools: @capture diff --git a/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl b/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl index cc5a11dedd3..c08af7f8f8c 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl @@ -26,13 +26,14 @@ for (Alg, desc, refs, is_W) in [ step_limiter!::StepLimiter stage_limiter!::StageLimiter autodiff::AD + is_disco::Bool end function $Alg(; chunk_size = Val{0}(), autodiff = AutoForwardDiff(), standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, step_limiter! = trivial_limiter!, - stage_limiter! = trivial_limiter! + stage_limiter! = trivial_limiter!, is_disco = false ) AD_choice, chunk_size, diff_type = _process_AD_choice( autodiff, chunk_size, diff_type @@ -41,10 +42,10 @@ for (Alg, desc, refs, is_W) in [ _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac), typeof(step_limiter!), - typeof(stage_limiter!), + typeof(stage_limiter!) }( linsolve, precs, step_limiter!, - stage_limiter!, AD_choice + stage_limiter!, AD_choice, is_disco ) end end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index 0ad836ac446..038c06fa01c 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -1384,6 +1384,10 @@ end integrator.opts.reltol, integrator.opts.internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) + + if integrator.EEst > one(eltype(integrator.EEst)) && integrator.alg.is_disco + set_discontinuity(u, uprev, integrator, cache) + end end if integrator.opts.calck @@ -1524,6 +1528,9 @@ end integrator.opts.reltol, integrator.opts.internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) + if integrator.EEst > one(eltype(integrator.EEst)) && integrator.alg.is_disco + set_discontinuity(u, uprev, integrator, cache) + end end if integrator.opts.calck From aff9da0d0c899f67cbc4dd7bdaf86df4fd28c853 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Sun, 1 Mar 2026 19:38:57 -0600 Subject: [PATCH 04/15] refactor disco into controllers --- .../src/OrdinaryDiffEqBDF.jl | 2 +- lib/OrdinaryDiffEqBDF/src/algorithms.jl | 5 +-- lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl | 6 --- lib/OrdinaryDiffEqCore/src/disco.jl | 18 ++------- .../src/integrators/controllers.jl | 37 ++++++++++++++++++- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 29 ++++++++++----- .../src/OrdinaryDiffEqFIRK.jl | 2 +- lib/OrdinaryDiffEqFIRK/src/algorithms.jl | 4 +- .../src/firk_perform_step.jl | 8 ---- .../src/OrdinaryDiffEqRosenbrock.jl | 2 +- .../src/algorithms.jl | 5 +-- .../src/rosenbrock_perform_step.jl | 7 ---- 12 files changed, 68 insertions(+), 57 deletions(-) diff --git a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl index 50c5eb33fb8..f6e404018da 100644 --- a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl +++ b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl @@ -23,7 +23,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, get_fsalfirstlast, generic_solver_docstring, _bool_to_ADType, _process_AD_choice, _ode_interpolant, _ode_interpolant!, has_stiff_interpolation, - _ode_addsteps!, DerivativeOrderNotPossibleError, set_discontinuity + _ode_addsteps!, DerivativeOrderNotPossibleError using OrdinaryDiffEqSDIRK: ImplicitEulerConstantCache, ImplicitEulerCache using TruncatedStacktraces: @truncate_stacktrace diff --git a/lib/OrdinaryDiffEqBDF/src/algorithms.jl b/lib/OrdinaryDiffEqBDF/src/algorithms.jl index 6aa128a6592..d6433237802 100644 --- a/lib/OrdinaryDiffEqBDF/src/algorithms.jl +++ b/lib/OrdinaryDiffEqBDF/src/algorithms.jl @@ -575,7 +575,6 @@ struct FBDF{MO, CS, AD, F, F2, P, FDT, ST, CJ, K, T, StepLimiter} <: controller::Symbol step_limiter!::StepLimiter autodiff::AD - is_disco::Bool end function FBDF(; @@ -584,7 +583,7 @@ function FBDF(; diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing, tol = nothing, - extrapolant = :linear, controller = :Standard, step_limiter! = trivial_limiter!, is_disco = false + extrapolant = :linear, controller = :Standard, step_limiter! = trivial_limiter! ) where {MO} AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) @@ -595,7 +594,7 @@ function FBDF(; typeof(κ), typeof(tol), typeof(step_limiter!), }( max_order, linsolve, nlsolve, precs, κ, tol, extrapolant, - controller, step_limiter!, AD_choice, is_disco + controller, step_limiter!, AD_choice ) end diff --git a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl index 7e9cd3f3031..bee908fe82d 100644 --- a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl +++ b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl @@ -1276,9 +1276,6 @@ function perform_step!( integrator.opts.reltol, integrator.opts.internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) - if (integrator.EEst > one(integrator.EEst) && integrator.alg.is_disco) - set_discontinuity(u, uprev, integrator, cache) - end terk = estimate_terk(integrator, cache, k + 1, Val(max_order), u) fd_weights = calc_finite_difference_weights(ts_tmp, tdt, k, Val(max_order)) @@ -1486,9 +1483,6 @@ function perform_step!( internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) - if (integrator.EEst > one(integrator.EEst) && integrator.alg.is_disco) - set_discontinuity(u, uprev, integrator, cache) - end estimate_terk!(integrator, cache, k + 1, Val(max_order)) calculate_residuals!( diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index 80bdfe143c4..20698ade0f1 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -4,24 +4,16 @@ function set_discontinuity(u, uprev, integrator, cache) #need to pick algs to te t = integrator.t if !isnan(breakpointθ) && 1e-6 < breakpointθ < 1.0 #println("Discontinuity detected at t = ", t + breakpointθ * dt) - integrator.dt = breakpointθ * dt - integrator.disco_dt_set = true + return breakpointθ * dt end + return -1 end function find_discontinuity(u, uprev, integrator, cache) + println("Finding discontinuity...") cb = integrator.opts.callback cb === nothing && return -1 isempty(cb.continuous_callbacks) && return -1 - - disco_exists = false; - for i in cb.continuous_callbacks - if (i.is_discontinuity) - disco_exists = true - break - end - end - !disco_exists && return -1 p = integrator.p t = integrator.t dt = integrator.dt @@ -49,9 +41,7 @@ function find_discontinuity(u, uprev, integrator, cache) if (f0 * f1 < zero(f0)) function zero_func(θ, p) u₁ = similar(u) - _ode_interpolant!(u₁, θ, dt, uprev, u, integrator.k, cache, - nothing, Val{0}, nothing) - + ode_interpolant!(u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) if is_inplace out = similar(u) i.condition(out, u₁, t + θ * dt, integrator) diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index 6325a9bb17c..a9f0cb211c8 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -169,6 +169,11 @@ end function step_reject_controller!(integrator, controller::IController, alg) (; qold) = integrator + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt = qold end @@ -241,6 +246,11 @@ end function step_reject_controller!(integrator, cache::IControllerCache, alg) @assert cache.dtreject ≈ integrator.qold "Controller cache went out of sync with time stepping logic." + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt = cache.dtreject # TODO this does not look right. end @@ -320,6 +330,11 @@ end function step_reject_controller!(integrator, controller::PIController, alg) (; q11) = integrator (; qmin, gamma) = integrator.opts + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt /= min(inv(qmin), q11 / gamma) end @@ -423,6 +438,11 @@ end function step_reject_controller!(integrator, cache::PIControllerCache, alg) (; controller, q11) = cache (; qmin, gamma) = controller + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt /= min(inv(qmin), q11 / gamma) end @@ -599,6 +619,11 @@ function step_accept_controller!(integrator, controller::PIDController, alg, dt_ end function step_reject_controller!(integrator, controller::PIDController, alg) + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt *= integrator.qold end @@ -730,6 +755,11 @@ function step_accept_controller!(integrator, cache::PIDControllerCache, alg, dt_ end function step_reject_controller!(integrator, cache::PIDControllerCache, alg) + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt *= cache.dt_factor end @@ -841,6 +871,12 @@ end function step_reject_controller!(integrator, controller::PredictiveController, alg) (; dt, success_iter, qold) = integrator + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end + return integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold end @@ -948,7 +984,6 @@ end function step_reject_controller!(integrator, cache::PredictiveControllerCache, alg) (; dt, success_iter) = integrator (; qold) = cache - if (integrator.disco_dt_set) println("using fixed dt from discontinuity handling") integrator.disco_dt_set = false diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 0f69ed6d3db..5f3b80afa7c 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,6 +1,5 @@ using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra -using OrdinaryDiffEqRosenbrock -using OrdinaryDiffEqBDF +using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner #test example discontinuous at u = 1 f(u, p, t) = u[1] < 1 ? [2u[1]] : [-3u[1] + 5] @@ -14,23 +13,35 @@ function affect!(integrator) integrator.u[1] += 10 end cb = ContinuousCallback(condition, affect!; is_discontinuity = true) +cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) -sol_disco = solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) +sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) # 291.292 μs (8449 allocations: 266.47 KiB) -sol_no_disco = solve(prob, RadauIIA5(is_disco = false); callback = cb, reltol = 1e-6) +sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 335.417 μs (10008 allocations: 311.08 KiB) -rodas_no_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) + +rodas_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) # 410.291 μs (16828 allocations: 594.05 KiB) -rodas_disco = solve(prob, Rodas5P(is_disco = true); callback = cb, reltol = 1e-6) +rodas_no_disco = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) # 483.792 μs (17729 allocations: 639.31 KiB) -bdf_no_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) + +bdf_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) # 245.917 μs (20703 allocations: 665.16 KiB) -bdf_disco = solve(prob, FBDF(is_disco = true); callback = cb, reltol = 1e-6) +bdf_no_disco = solve(prob, FBDF(); callback = cb2, reltol = 1e-6) # 269.333 μs (20477 allocations: 663.80 KiB) +tsit_disco = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) +# same either way for some reason? check about this +tsit_no_disco = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) + +vern_disco = solve(prob, Vern7(); callback = cb, reltol = 1e-6) +# 111.125 μs (15629 allocations: 493.66 KiB) +vern_no_disco = solve(prob, Vern7(); callback = cb2, reltol = 1e-6) +# 83.666 μs (13326 allocations: 420.31 KiB) @profview for i in 1:1000 - solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) + solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) end + #two discontinuity functions function f(u, p, t) if u[1] < 1 diff --git a/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl b/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl index d77dc456c45..a0cd04b69bf 100644 --- a/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl +++ b/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl @@ -18,7 +18,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, fac_default_gamma, get_current_adaptive_order, get_fsalfirstlast, isfirk, generic_solver_docstring, _bool_to_ADType, - _process_AD_choice, LinearAliasSpecifier, set_discontinuity + _process_AD_choice, LinearAliasSpecifier using MuladdMacro, DiffEqBase, RecursiveArrayTools, Polyester isfirk, generic_solver_docstring using SciMLOperators: AbstractSciMLOperator diff --git a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl index 3ee4ecaa422..2231de529dd 100644 --- a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl +++ b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl @@ -92,7 +92,6 @@ struct RadauIIA5{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <: controller::Symbol step_limiter!::StepLimiter autodiff::AD - is_disco::Bool end function RadauIIA5(; @@ -103,7 +102,7 @@ function RadauIIA5(; extrapolant = :dense, fast_convergence_cutoff = 1 // 5, new_W_γdt_cutoff = 1 // 5, controller = :Predictive, κ = nothing, maxiters = 10, smooth_est = true, - step_limiter! = trivial_limiter!, is_disco = false + step_limiter! = trivial_limiter! ) AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) @@ -124,7 +123,6 @@ function RadauIIA5(; controller, step_limiter!, AD_choice, - is_disco ) end diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl index ce535649b5a..499b345d13c 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl @@ -675,10 +675,6 @@ end integrator.k[4] = z2 integrator.k[5] = z3 end - else - if alg.is_disco - set_discontinuity(u, uprev, integrator, cache) - end end integrator.fsallast = f(u, p, t + dt) @@ -956,10 +952,6 @@ end integrator.k[4] .= z2 integrator.k[5] .= z3 end - else - if alg.is_disco - set_discontinuity(u, uprev, integrator, cache) - end end f(fsallast, u, p, t + dt) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl b/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl index 7017e379fad..9ef1b32e2a7 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl @@ -13,7 +13,7 @@ import OrdinaryDiffEqCore: alg_order, alg_adaptive_order, isWmethod, isfsal, _un calculate_residuals, has_stiff_interpolation, ODEIntegrator, resize_non_user_cache!, _ode_addsteps!, full_cache, DerivativeOrderNotPossibleError, _bool_to_ADType, - _process_AD_choice, LinearAliasSpecifier, copyat_or_push!, set_discontinuity + _process_AD_choice, LinearAliasSpecifier, copyat_or_push! using MuladdMacro, FastBroadcast, RecursiveArrayTools import MacroTools: namify using MacroTools: @capture diff --git a/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl b/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl index c08af7f8f8c..3e2374e59ca 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl @@ -26,14 +26,13 @@ for (Alg, desc, refs, is_W) in [ step_limiter!::StepLimiter stage_limiter!::StageLimiter autodiff::AD - is_disco::Bool end function $Alg(; chunk_size = Val{0}(), autodiff = AutoForwardDiff(), standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, step_limiter! = trivial_limiter!, - stage_limiter! = trivial_limiter!, is_disco = false + stage_limiter! = trivial_limiter! ) AD_choice, chunk_size, diff_type = _process_AD_choice( autodiff, chunk_size, diff_type @@ -45,7 +44,7 @@ for (Alg, desc, refs, is_W) in [ typeof(stage_limiter!) }( linsolve, precs, step_limiter!, - stage_limiter!, AD_choice, is_disco + stage_limiter!, AD_choice ) end end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index 038c06fa01c..0ad836ac446 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -1384,10 +1384,6 @@ end integrator.opts.reltol, integrator.opts.internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) - - if integrator.EEst > one(eltype(integrator.EEst)) && integrator.alg.is_disco - set_discontinuity(u, uprev, integrator, cache) - end end if integrator.opts.calck @@ -1528,9 +1524,6 @@ end integrator.opts.reltol, integrator.opts.internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) - if integrator.EEst > one(eltype(integrator.EEst)) && integrator.alg.is_disco - set_discontinuity(u, uprev, integrator, cache) - end end if integrator.opts.calck From 9d306b7cf28f2d70c6fa187ccb4723b11083f03a Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Tue, 3 Mar 2026 11:38:14 -0600 Subject: [PATCH 05/15] add disco to BDF controller --- lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl | 2 +- lib/OrdinaryDiffEqBDF/src/controllers.jl | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl index f6e404018da..50c5eb33fb8 100644 --- a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl +++ b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl @@ -23,7 +23,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, get_fsalfirstlast, generic_solver_docstring, _bool_to_ADType, _process_AD_choice, _ode_interpolant, _ode_interpolant!, has_stiff_interpolation, - _ode_addsteps!, DerivativeOrderNotPossibleError + _ode_addsteps!, DerivativeOrderNotPossibleError, set_discontinuity using OrdinaryDiffEqSDIRK: ImplicitEulerConstantCache, ImplicitEulerCache using TruncatedStacktraces: @truncate_stacktrace diff --git a/lib/OrdinaryDiffEqBDF/src/controllers.jl b/lib/OrdinaryDiffEqBDF/src/controllers.jl index 3be5968b186..a14a1672d38 100644 --- a/lib/OrdinaryDiffEqBDF/src/controllers.jl +++ b/lib/OrdinaryDiffEqBDF/src/controllers.jl @@ -112,6 +112,13 @@ function bdf_step_reject_controller!(integrator, EEst1) h = integrator.dt integrator.cache.consfailcnt += 1 integrator.cache.nconsteps = 0 + + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end + if integrator.cache.consfailcnt > 1 h = h / 2 end From 2158f3b9bb2ce9aadb623f53a92ee69a95c2dd21 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Tue, 3 Mar 2026 11:54:04 -0600 Subject: [PATCH 06/15] fix small edits --- lib/OrdinaryDiffEqBDF/src/algorithms.jl | 2 +- lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl | 1 - lib/OrdinaryDiffEqFIRK/src/algorithms.jl | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/OrdinaryDiffEqBDF/src/algorithms.jl b/lib/OrdinaryDiffEqBDF/src/algorithms.jl index d6433237802..ad0aba5f193 100644 --- a/lib/OrdinaryDiffEqBDF/src/algorithms.jl +++ b/lib/OrdinaryDiffEqBDF/src/algorithms.jl @@ -841,4 +841,4 @@ function DFBDF(; ) end -@truncate_stacktrace DFBDF \ No newline at end of file +@truncate_stacktrace DFBDF diff --git a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl index bee908fe82d..59deace1b22 100644 --- a/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl +++ b/lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl @@ -1483,7 +1483,6 @@ function perform_step!( internalnorm, t ) integrator.EEst = integrator.opts.internalnorm(atmp, t) - estimate_terk!(integrator, cache, k + 1, Val(max_order)) calculate_residuals!( atmp, _vec(terk_tmp), _vec(uprev), _vec(u), abstol, reltol, diff --git a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl index 2231de529dd..aa433958379 100644 --- a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl +++ b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl @@ -122,7 +122,7 @@ function RadauIIA5(; new_W_γdt_cutoff, controller, step_limiter!, - AD_choice, + AD_choice ) end From 2b9a5097a75769e5376f8bc20a37177bdfd23993 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Tue, 3 Mar 2026 11:58:45 -0600 Subject: [PATCH 07/15] Update algorithms.jl --- lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl b/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl index 3e2374e59ca..cc5a11dedd3 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl @@ -41,7 +41,7 @@ for (Alg, desc, refs, is_W) in [ _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac), typeof(step_limiter!), - typeof(stage_limiter!) + typeof(stage_limiter!), }( linsolve, precs, step_limiter!, stage_limiter!, AD_choice From d8c1a3be88e319b318194ee6dce35e772b7a724f Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 16 Mar 2026 16:51:27 -0400 Subject: [PATCH 08/15] update disco scheme by caching problems in integrator --- lib/OrdinaryDiffEqCore/src/disco.jl | 53 +++-- .../src/integrators/type.jl | 1 + lib/OrdinaryDiffEqCore/src/solve.jl | 23 ++- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 186 ++++++++---------- 4 files changed, 130 insertions(+), 133 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index 20698ade0f1..f402034a376 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -1,4 +1,4 @@ -function set_discontinuity(u, uprev, integrator, cache) #need to pick algs to test +function set_discontinuity(u, uprev, integrator, cache) breakpointθ = find_discontinuity(u, uprev, integrator, cache) dt = integrator.dt t = integrator.t @@ -10,7 +10,6 @@ function set_discontinuity(u, uprev, integrator, cache) #need to pick algs to te end function find_discontinuity(u, uprev, integrator, cache) - println("Finding discontinuity...") cb = integrator.opts.callback cb === nothing && return -1 isempty(cb.continuous_callbacks) && return -1 @@ -18,49 +17,45 @@ function find_discontinuity(u, uprev, integrator, cache) t = integrator.t dt = integrator.dt breakpointθ = -one(dt) - prob = nothing + idx = 1 for i in cb.continuous_callbacks if (!(i.is_discontinuity)) continue end - out_prev = nothing - out_curr = nothing - is_inplace = DiffEqBase.isinplace(i.condition, 4) - if is_inplace + if (i isa VectorContinuousCallback) out_prev = similar(u) + out_curr = similar(u) i.condition(out_prev, uprev, t, integrator) - out_curr = similar(u) i.condition(out_curr, u, t + dt, integrator) - is_inplace = true - else - out_prev = i.condition(uprev, t, integrator) - out_curr = i.condition(u, t + dt, integrator) - is_inplace = false - end - for (idx, (f0, f1)) in enumerate(zip(out_prev, out_curr)) - if (f0 * f1 < zero(f0)) - function zero_func(θ, p) + for (ind, (f0, f1)) in enumerate(zip(out_prev, out_curr)) + if (f0 * f1 < zero(f0)) u₁ = similar(u) - ode_interpolant!(u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) - if is_inplace - out = similar(u) - i.condition(out, u₁, t + θ * dt, integrator) - else - out = i.condition(u₁, t + θ * dt, integrator) + out = similar(u) + function zero_func(θ, p) + ode_interpolant!(u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) + i.condition(out, u₁, t + θ * integrator.dt, integrator) + out[ind] end - out[idx] - end - if prob === nothing prob = IntervalNonlinearProblem(zero_func, [zero(dt), one(dt)], p) - else - prob = remake(prob; f=zero_func) + sol = solve(prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) + tmp = sol[] + if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) + breakpointθ = tmp + end end - sol = solve(prob; bracket=[zero(dt), one(dt)]) + end + else + out_prev = i.condition(uprev, t, integrator) + out_curr = i.condition(u, t + dt, integrator) + if (out_prev * out_curr < zero(out_prev)) + prob = integrator.disco_probs[idx] + sol = solve(prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) tmp = sol[] if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) breakpointθ = tmp end end + idx += 1 end end breakpointθ diff --git a/lib/OrdinaryDiffEqCore/src/integrators/type.jl b/lib/OrdinaryDiffEqCore/src/integrators/type.jl index 6234de5143b..4f4fbc62d1a 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/type.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/type.jl @@ -146,4 +146,5 @@ mutable struct ODEIntegrator{ fsalfirst::FSALType fsallast::FSALType rng::RNGType + disco_probs::Vector{IntervalNonlinearProblem} end diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index b9ce19f7622..a1bdae6b409 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -81,6 +81,7 @@ function SciMLBase.__init( alias = ODEAliasSpecifier(), initializealg = DefaultInit(), rng = nothing, + disco_probs = nothing, kwargs... ) if prob isa SciMLBase.AbstractDAEProblem && alg isa OrdinaryDiffEqAlgorithm @@ -653,6 +654,26 @@ function SciMLBase.__init( fsalfirst, fsallast = get_fsalfirstlast(cache, rate_prototype) _rng = rng === nothing ? Random.default_rng() : rng + num_cb = 0 + for i in callbacks_internal.continuous_callbacks + num_cb += 1 + end + disco_probs = Vector{IntervalNonlinearProblem}(undef, num_cb) + idx = 1 + for (ind, i) in enumerate(callbacks_internal.continuous_callbacks) + if i.is_discontinuity && !(i isa VectorContinuousCallback) + #VCC problems handled in disco itself + u₁ = similar(u) + function zero_func(θ, p) + ode_interpolant!(u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) + out = i.condition(u₁, t + θ * integrator.dt, integrator) + out + end + disco_prob = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p) + disco_probs[idx] = disco_prob + end + idx+=1 + end integrator = ODEIntegrator{ typeof(_alg), isinplace(prob), uType, typeof(du), @@ -686,7 +707,7 @@ function SciMLBase.__init( isout, reeval_fsal, u_modified, reinitiailize, isdae, opts, stats, initializealg, differential_vars, - fsalfirst, fsallast, _rng + fsalfirst, fsallast, _rng, disco_probs ) if initialize_integrator diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 5f3b80afa7c..ba56111e12b 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,6 +1,7 @@ using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner +#TEST 1: SIMPLE DISCONTINUITY #test example discontinuous at u = 1 f(u, p, t) = u[1] < 1 ? [2u[1]] : [-3u[1] + 5] u0 = [0.1] @@ -10,38 +11,22 @@ prob = ODEProblem(f, u0, tspan) #define callback condition(u, t, integrator) = u[1] - 1 function affect!(integrator) + #println("fired callback at t=$(integrator.t), u=$(integrator.u[1])") integrator.u[1] += 10 end cb = ContinuousCallback(condition, affect!; is_discontinuity = true) cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 291.292 μs (8449 allocations: 266.47 KiB) +# 277.833 μs (8033 allocations: 251.14 KiB) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) -# 335.417 μs (10008 allocations: 311.08 KiB) - -rodas_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) -# 410.291 μs (16828 allocations: 594.05 KiB) -rodas_no_disco = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) -# 483.792 μs (17729 allocations: 639.31 KiB) - -bdf_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) -# 245.917 μs (20703 allocations: 665.16 KiB) -bdf_no_disco = solve(prob, FBDF(); callback = cb2, reltol = 1e-6) -# 269.333 μs (20477 allocations: 663.80 KiB) - -tsit_disco = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) -# same either way for some reason? check about this -tsit_no_disco = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) - -vern_disco = solve(prob, Vern7(); callback = cb, reltol = 1e-6) -# 111.125 μs (15629 allocations: 493.66 KiB) -vern_no_disco = solve(prob, Vern7(); callback = cb2, reltol = 1e-6) -# 83.666 μs (13326 allocations: 420.31 KiB) +# 343.041 μs (10008 allocations: 311.02 KiB) + @profview for i in 1:1000 solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) end +#TEST 2: TWO DISCONTINUITIES #two discontinuity functions function f(u, p, t) if u[1] < 1 @@ -49,7 +34,7 @@ function f(u, p, t) elseif u[1] < 2 [u[1] + 0.2] # region 2: continues increasing to hit u = 2 else - [-4u[1] + 12] # region 3: after 2, moves toward u ≈ 3 + [-4u[1] + 12] end end @@ -63,28 +48,27 @@ function affect1!(integrator) #println("Callback 1 fired at t=$(integrator.t), u=$(integrator.u[1])") end cb1 = ContinuousCallback(condition1, affect1!; is_discontinuity = true) +cb1f = ContinuousCallback(condition1, affect1!; is_discontinuity = false) condition2(u, t, integrator) = u[1] - 2 function affect2!(integrator) #println("Callback 2 fired at t=$(integrator.t), u=$(integrator.u[1])") end cb2 = ContinuousCallback(condition2, affect2!; is_discontinuity = true) +cb2f = ContinuousCallback(condition2, affect2!; is_discontinuity = false) cb = CallbackSet(cb1, cb2) +cb2 = CallbackSet(cb1f, cb2f) #disco solve -sol_disco = solve(prob, RadauIIA5(is_disco = true); callback = cb, reltol = 1e-6) +sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) +# 1.664 ms (43703 allocations: 1.35 MiB) #fixed order solve -sol_no_disco = solve(prob, RadauIIA5(is_disco = false); callback = cb, reltol = 1e-6) - -rodas_no_disco = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) - -rodas_disco = solve(prob, Rodas5P(is_disco = true); callback = cb, reltol = 1e-6) - -bdf_no_disco = solve(prob, FBDF(); callback = cb, reltol = 1e-6) +sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) +# 1.266 ms (37019 allocations: 1.12 MiB) -bdf_disco = solve(prob, FBDF(is_disco = true); callback = cb, reltol = 1e-6) +#TEST 3: EXPONENTIAL DISCONTINUITY # multiple exponential regions with sharp transitions function f_multi_exp!(du, u, p, t) if u[1] < 0.3 @@ -103,50 +87,32 @@ prob_multi = ODEProblem(f_multi_exp!, u0_multi, tspan_multi) #define callbacks cond_multi_1(u, t, integrator) = u[1] - 0.3 function affect_multi_1!(integrator) - println("Multi-exponential discontinuity 1 callback fired at t=$(integrator.t), u=$(integrator.u[1])") + #println("Multi-exponential discontinuity 1 callback fired at t=$(integrator.t), u=$(integrator.u[1])") end cb_multi_1 = ContinuousCallback(cond_multi_1, affect_multi_1!; is_discontinuity = true) +cb_multi_1f = ContinuousCallback(cond_multi_1, affect_multi_1!; is_discontinuity = false) cond_multi_2(u, t, integrator) = u[1] - 0.8 function affect_multi_2!(integrator) - println("Multi-exponential discontinuity 2 callback fired at t=$(integrator.t), u=$(integrator.u[1])") + #println("Multi-exponential discontinuity 2 callback fired at t=$(integrator.t), u=$(integrator.u[1])") end cb_multi_2 = ContinuousCallback(cond_multi_2, affect_multi_2!; is_discontinuity = true) +cb_multi_2f = ContinuousCallback(cond_multi_2, affect_multi_2!; is_discontinuity = false) cb_multi = CallbackSet(cb_multi_1, cb_multi_2) +cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) #disco solve -sol_disco = solve(prob_multi, RadauIIA5(is_disco = true); callback=cb_multi, reltol=1e-7, abstol=1e-9) +sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +# 202.834 μs (2770 allocations: 93.23 KiB) #fixed order solve -sol_no_disco = solve(prob_multi, RadauIIA5(is_disco = false); callback=cb_multi, reltol = 1e-7, abstol = 1e-9) - -# 2D system with exponential coupling and discontinuity -function f_2d_exp!(du, u, p, t) - if u[1] + u[2] < 1.0 - du[1] = 2 * exp(u[1]) - u[2] - du[2] = -3 * u[1] + 4 * exp(u[2]) - else - du[1] = u[1] - du[2] = u[2] - end -end - -u0_2d = [0.1, 0.2] -tspan_2d = (0.0, 2.0) -prob_2d = ODEProblem(f_2d_exp!, u0_2d, tspan_2d) +sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) +# 122.875 μs (1136 allocations: 54.52 KiB) -#define callback -cond_2d(u, t, integrator) = u[1] + u[2] - 1.0 -function affect_2d!(integrator) - println("2D exponential discontinuity callback fired at t=$(integrator.t), u=$(integrator.u)") - @test 0.98 < integrator.u[1] + integrator.u[2] < 1.02 +@profview for i in 1:1000 + solve(prob_multi, RadauIIA5(); callback = cb_multi, reltol = 1e-6) end -cb_2d = ContinuousCallback(cond_2d, affect_2d!; is_discontinuity = true) - -#disco solve -sol_disco = solve(prob_2d, RadauIIA5(is_disco = true); callback=cb_2d, reltol=1e-8, abstol=1e-10) -#fixed order solve -sol_no_disco = solve(prob_2d, RadauIIA5(is_disco = false); callback=cb_2d, reltol = 1e-8, abstol = 1e-10) +#TEST 4: STIFF DISCONTINUITY # very stiff discontinuous system function f_stiff_disc!(du, u, p, t) λ = p[1] # stiffness parameter @@ -164,15 +130,20 @@ prob_stiff = ODEProblem(f_stiff_disc!, u0_stiff, tspan_stiff, [100.0]) #define callback cond_stiff(u, t, integrator) = u[1] - 0.5 function affect_stiff!(integrator) - println("Stiff discontinuity callback fired at t=$(integrator.t), u=$(integrator.u[1])") + #println("Stiff discontinuity callback fired at t=$(integrator.t), u=$(integrator.u[1])") end cb_stiff = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = true) +cb_stiff_f = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = false) #disco solve -sol_disco = solve(prob_stiff, RadauIIA5(is_disco = true); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +sol_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +# 131.875 μs (1956 allocations: 74.03 KiB) #fixed order solve -sol_no_disco = solve(prob_stiff, RadauIIA5(is_disco = false); callback=cb_stiff, reltol = 1e-9, abstol = 1e-11) +sol_no_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) +# 119.417 μs (1480 allocations: 59.55 KiB) + +#TEST 5: MULTIPLE DISCONTINUITIES IN SMALL RANGE # multiple discontinuities in very small range (1e-6 apart, 5 discontinuities) function f_many_disc!(du, u, p, t) du[1] = u[1] + 1 # simple linear growth @@ -187,20 +158,26 @@ disc_values = [0.1 + i * 1e-6 for i = 0:4] # define callbacks for each discontinuity cbs_many = [] +cbs_many_f = [] for (i, disc_val) in enumerate(disc_values) local cond_func(u, t, integrator) = u[1] - disc_val function affect_func!(integrator) - println("Dense discontinuity $i fired at t=$(integrator.t), u=$(integrator.u[1])") + #println("Dense discontinuity $i fired at t=$(integrator.t), u=$(integrator.u[1])") end push!(cbs_many, ContinuousCallback(cond_func, affect_func!; is_discontinuity = true)) + push!(cbs_many_f, ContinuousCallback(cond_func, affect_func!; is_discontinuity = false)) end cb_many = CallbackSet(cbs_many...) +cb_many_f = CallbackSet(cbs_many_f...) #disco solve -sol_disco = solve(prob_many, RadauIIA5(is_disco = true); callback=cb_many, reltol=1e-10, abstol=1e-12) +sol_disco = solve(prob_many, RadauIIA5(); callback=cb_many, reltol=1e-10, abstol=1e-12) +# 111.333 μs (907 allocations: 36.94 KiB) #fixed order solve -sol_no_disco = solve(prob_many, RadauIIA5(is_disco = false); callback=cb_many, reltol=1e-10, abstol=1e-12) +sol_no_disco = solve(prob_many, RadauIIA5(); callback=cb_many_f, reltol=1e-10, abstol=1e-12) +# 111.666 μs (907 allocations: 36.94 KiB) +#TEST 6: DISCONTINUOUS DAE # discontinuous DAE with mass matrix # System: M * du/dt = f(u, p, t) # du[1]/dt = u[2] - u[1] @@ -228,41 +205,44 @@ function affect_dae!(integrator) #println("DAE discontinuity callback fired at t=$(integrator.t), u=$(integrator.u)") end cb_dae = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = true) +cb_daef = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = false) + +radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) +# 83.500 μs (769 allocations: 35.72 KiB) +radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol=1e-10) +# 101.542 μs (1230 allocations: 48.16 KiB) + +#TEST 7: VECTOR CALLBACK +function f!(du, u, p, t) + du[1] = -u[1] + du[2] = 0.2*u[1] - 0.1*u[2] +end + +u0 = [3.0, 0.0] +tspan = (0.0, 10.0) +prob = ODEProblem(f!, u0, tspan) + +# Two event surfaces: u[1] == 2.0 and u[1] == 1.0 +function condition!(out, u, t, integrator) + out[1] = u[1] - 2.0 + out[2] = u[1] - 1.0 +end + +# Discontinuous update to the state when an event fires +function affect!(integrator, idx) + if idx == 1 + # when u[1] crosses 2, kick u[2] up (jump discontinuity) + integrator.u[2] += 5.0 + elseif idx == 2 + # when u[1] crosses 1, reset u[2] + integrator.u[2] = 0.0 + end +end +cb = VectorContinuousCallback(condition!, affect!, 2;) +cb2 = VectorContinuousCallback(condition!, affect!, 2; is_discontinuity = false) -radau_no_disco = solve(prob_dae, RadauIIA5(is_disco = false); callback=cb_dae, reltol=1e-8, abstol=1e-10) - #83.500 μs (769 allocations: 35.72 KiB) -radau_disco = solve(prob_dae, RadauIIA5(is_disco = true); callback=cb_dae, reltol=1e-8, abstol=1e-10) - # 119.417 μs (1273 allocations: 55.42 KiB) -rodas_no_disco = solve(prob_dae, Rodas5P(); callback = cb_dae, reltol = 1e-6) -#= SciMLBase.DEStats -Number of function 1 evaluations: 312 -Number of function 2 evaluations: 0 -Number of W matrix evaluations: 34 -Number of linear solves: 272 -Number of Jacobians created: 19 -Number of nonlinear solver iterations: 0 -Number of nonlinear solver convergence failures: 0 -Number of fixed-point solver iterations: 0 -Number of fixed-point solver convergence failures: 0 -Number of rootfind condition calls: 213 -Number of accepted steps: 19 -Number of rejected steps: 15 =# -# 98.167 μs (550 allocations: 26.92 KiB) -rodas_disco = solve(prob_dae, Rodas5P(is_disco = true); callback = cb_dae, reltol = 1e-6) -#= SciMLBase.DEStats -Number of function 1 evaluations: 312 -Number of function 2 evaluations: 0 -Number of W matrix evaluations: 34 -Number of linear solves: 272 -Number of Jacobians created: 19 -Number of nonlinear solver iterations: 0 -Number of nonlinear solver convergence failures: 0 -Number of fixed-point solver iterations: 0 -Number of fixed-point solver convergence failures: 0 -Number of rootfind condition calls: 213 -Number of accepted steps: 19 -Number of rejected steps: 15 =# -# 97.541 μs (550 allocations: 26.92 KiB) -bdf_no_disco = solve(prob_dae, FBDF(); callback = cb_dae, reltol = 1e-6) -bdf_disco = solve(prob_dae, FBDF(is_disco = true); callback = cb_dae, reltol = 1e-6) \ No newline at end of file +sol_disco = solve(prob, RadauIIA5(); callback = cb) +# 62.041 μs (849 allocations: 41.64 KiB) +sol_no_disco = solve(prob, RadauIIA5(); callback = cb2) +# 37.375 μs (531 allocations: 25.23 KiB) \ No newline at end of file From a293d6b557846a71222cb2ebfbf98ad80b64726e Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 16 Mar 2026 16:56:05 -0400 Subject: [PATCH 09/15] small optimization --- lib/OrdinaryDiffEqCore/src/solve.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index a1bdae6b409..73eade112ba 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -654,11 +654,13 @@ function SciMLBase.__init( fsalfirst, fsallast = get_fsalfirstlast(cache, rate_prototype) _rng = rng === nothing ? Random.default_rng() : rng - num_cb = 0 + disco_cb_num = 0 for i in callbacks_internal.continuous_callbacks - num_cb += 1 + if i.is_discontinuity + disco_cb_num += 1 + end end - disco_probs = Vector{IntervalNonlinearProblem}(undef, num_cb) + disco_probs = Vector{IntervalNonlinearProblem}(undef, disco_cb_num) idx = 1 for (ind, i) in enumerate(callbacks_internal.continuous_callbacks) if i.is_discontinuity && !(i isa VectorContinuousCallback) From 6aa6dfa1522f8d4ed171f01d1841bc442473d224 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Sat, 28 Mar 2026 15:48:48 -0500 Subject: [PATCH 10/15] update disco to new approach --- lib/OrdinaryDiffEqCore/src/disco.jl | 15 ++++- .../src/integrators/type.jl | 1 + lib/OrdinaryDiffEqCore/src/solve.jl | 61 ++++++++++++++----- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 4 ++ 4 files changed, 64 insertions(+), 17 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index f402034a376..79e7eb72fa0 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -48,12 +48,21 @@ function find_discontinuity(u, uprev, integrator, cache) out_prev = i.condition(uprev, t, integrator) out_curr = i.condition(u, t + dt, integrator) if (out_prev * out_curr < zero(out_prev)) - prob = integrator.disco_probs[idx] - sol = solve(prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) + disco_prob = integrator.disco_probs[idx] + #disco_prob = integrator.disco_prob + disco_prob.f.f.dt = integrator.dt + disco_prob.f.f.uprev = uprev + disco_prob.f.f.u = u + disco_prob.f.f.k = integrator.k + disco_prob.f.f.cache = integrator.cache + disco_prob.f.f.differential_vars = integrator.differential_vars + disco_prob.f.f.idxs = integrator.opts.save_idxs + #disco_prob.f.f.callback = i + sol = solve(disco_prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) tmp = sol[] if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) breakpointθ = tmp - end + end end idx += 1 end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/type.jl b/lib/OrdinaryDiffEqCore/src/integrators/type.jl index 4f4fbc62d1a..c28e213e000 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/type.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/type.jl @@ -146,5 +146,6 @@ mutable struct ODEIntegrator{ fsalfirst::FSALType fsallast::FSALType rng::RNGType + #disco_prob::IntervalNonlinearProblem disco_probs::Vector{IntervalNonlinearProblem} end diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index 73eade112ba..da56cfddf6d 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -16,6 +16,27 @@ determine_controller_datatype(u, internalnorm, ts::Tuple{<:Number, <:Number}) = determine_controller_datatype(u::AbstractVector{<:Number}, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(DiffEqBase.value(internalnorm(u, ts[1]))), typeof(DiffEqBase.value(internalnorm(u, ts[2]))), eltype(float.(DiffEqBase.value(ts)))) determine_controller_datatype(u, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(float(DiffEqBase.value(ts[1]))), typeof(float(DiffEqBase.value(ts[2])))) # This seems to be an assumption implicitly taken somewhere +mutable struct zero_func_struct{uType, tType, rateType, CacheType} + #integrator_ref::IntegratorType + u₁::uType + callback::ContinuousCallback + dt::tType + uprev::uType + u::uType + k::Vector{rateType} + cache::CacheType + idxs::Union{Nothing, Vector{Int}} + differential_vars::Union{Nothing, Vector{Bool}} +end + +function (z::zero_func_struct)(θ, p) + #integrator = z.integrator_ref[]::ODEIntegrator + ode_interpolant!(z.u₁, θ, z.dt, z.uprev, z.u, z.k, z.cache, z.idxs, Val{0}, z.differential_vars) + #ode_interpolant!(z.u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) + out = z.callback.condition(z.u₁, z.dt + θ * z.dt, z) + out +end + function SciMLBase.__init( prob::Union{ SciMLBase.AbstractODEProblem, @@ -654,29 +675,38 @@ function SciMLBase.__init( fsalfirst, fsallast = get_fsalfirstlast(cache, rate_prototype) _rng = rng === nothing ? Random.default_rng() : rng - disco_cb_num = 0 + num_probs = 0 + integrator_ref = Ref{Union{DEIntegrator, Nothing}}(nothing) for i in callbacks_internal.continuous_callbacks - if i.is_discontinuity - disco_cb_num += 1 + if !(i isa VectorContinuousCallback) && i.is_discontinuity + num_probs += 1 end end - disco_probs = Vector{IntervalNonlinearProblem}(undef, disco_cb_num) + + disco_probs = Vector{IntervalNonlinearProblem}(undef, num_probs) idx = 1 - for (ind, i) in enumerate(callbacks_internal.continuous_callbacks) + for i in callbacks_internal.continuous_callbacks + if i.is_discontinuity && !(i isa VectorContinuousCallback) + u₁ = similar(u) + zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars) + disco_probs[idx] = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p) + idx += 1 + end + end + #= + disco_prob = nothing + integrator_ref = Ref{Union{DEIntegrator, Nothing}}(nothing) + for i in callbacks_internal.continuous_callbacks if i.is_discontinuity && !(i isa VectorContinuousCallback) #VCC problems handled in disco itself u₁ = similar(u) - function zero_func(θ, p) - ode_interpolant!(u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) - out = i.condition(u₁, t + θ * integrator.dt, integrator) - out - end + #zero_func = zero_func_struct(integrator_ref, u₁, i) + zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars) disco_prob = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p) - disco_probs[idx] = disco_prob + break end - idx+=1 - end - + end + =# integrator = ODEIntegrator{ typeof(_alg), isinplace(prob), uType, typeof(du), tType, typeof(p), @@ -711,6 +741,9 @@ function SciMLBase.__init( opts, stats, initializealg, differential_vars, fsalfirst, fsallast, _rng, disco_probs ) + #if (num_probs > 0) + integrator_ref[] = integrator + #end if initialize_integrator if isdae || SciMLBase.has_initializeprob(prob.f) || diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index ba56111e12b..db7903f2e2e 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,5 +1,7 @@ using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner +using Logging +global_logger(ConsoleLogger(stderr, Logging.Error)) #TEST 1: SIMPLE DISCONTINUITY #test example discontinuous at u = 1 @@ -19,6 +21,7 @@ cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) # 277.833 μs (8033 allocations: 251.14 KiB) +# curr update: 287.417 μs (8240 allocations: 258.56 KiB) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 343.041 μs (10008 allocations: 311.02 KiB) @@ -104,6 +107,7 @@ cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) #disco solve sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) # 202.834 μs (2770 allocations: 93.23 KiB) +# curr update: 238.416 μs (4426 allocations: 119.88 KiB) #fixed order solve sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) # 122.875 μs (1136 allocations: 54.52 KiB) From b3fa9e81e4039db94e3ff1c50981b50b6e8f8fcf Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Sun, 29 Mar 2026 22:17:45 -0500 Subject: [PATCH 11/15] update benchmarks --- lib/OrdinaryDiffEqCore/src/disco.jl | 2 +- lib/OrdinaryDiffEqCore/src/solve.jl | 2 +- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 33 +++++++++++++--------- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index 79e7eb72fa0..29bf26c090a 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -58,7 +58,7 @@ function find_discontinuity(u, uprev, integrator, cache) disco_prob.f.f.differential_vars = integrator.differential_vars disco_prob.f.f.idxs = integrator.opts.save_idxs #disco_prob.f.f.callback = i - sol = solve(disco_prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) + sol = solve(disco_prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) tmp = sol[] if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) breakpointθ = tmp diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index da56cfddf6d..287e8aa8985 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -26,7 +26,7 @@ mutable struct zero_func_struct{uType, tType, rateType, CacheType} k::Vector{rateType} cache::CacheType idxs::Union{Nothing, Vector{Int}} - differential_vars::Union{Nothing, Vector{Bool}} + differential_vars::Union{Nothing, Vector{Bool}, BitVector} end function (z::zero_func_struct)(θ, p) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index db7903f2e2e..1476e4a3946 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,5 +1,5 @@ using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra -using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner +using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner using Logging global_logger(ConsoleLogger(stderr, Logging.Error)) @@ -20,11 +20,9 @@ cb = ContinuousCallback(condition, affect!; is_discontinuity = true) cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 277.833 μs (8033 allocations: 251.14 KiB) -# curr update: 287.417 μs (8240 allocations: 258.56 KiB) +# 286.125 μs (8207 allocations: 258.09 KiB) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) -# 343.041 μs (10008 allocations: 311.02 KiB) - +# 340.292 μs (10009 allocations: 311.05 KiB) @profview for i in 1:1000 solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) end @@ -64,11 +62,14 @@ cb2 = CallbackSet(cb1f, cb2f) #disco solve sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 1.664 ms (43703 allocations: 1.35 MiB) +# 1.548 ms (46763 allocations: 1.35 MiB) #fixed order solve sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) -# 1.266 ms (37019 allocations: 1.12 MiB) +# 1.264 ms (37026 allocations: 1.13 MiB) +@profview for i in 1:1000 + solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) +end #TEST 3: EXPONENTIAL DISCONTINUITY @@ -106,11 +107,10 @@ cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) #disco solve sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) -# 202.834 μs (2770 allocations: 93.23 KiB) -# curr update: 238.416 μs (4426 allocations: 119.88 KiB) +# 195.666 μs (3834 allocations: 110.72 KiB) #fixed order solve sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) -# 122.875 μs (1136 allocations: 54.52 KiB) +# 125.583 μs (1134 allocations: 54.56 KiB) @profview for i in 1:1000 solve(prob_multi, RadauIIA5(); callback = cb_multi, reltol = 1e-6) @@ -141,11 +141,14 @@ cb_stiff_f = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = fa #disco solve sol_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) -# 131.875 μs (1956 allocations: 74.03 KiB) +# 131.375 μs (2181 allocations: 78.84 KiB) #fixed order solve sol_no_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) # 119.417 μs (1480 allocations: 59.55 KiB) +@profview for i in 1:1000 + solve(prob_stiff, RadauIIA5(); callback = cb_stiff, reltol = 1e-9, abstol = 1e-11) +end #TEST 5: MULTIPLE DISCONTINUITIES IN SMALL RANGE # multiple discontinuities in very small range (1e-6 apart, 5 discontinuities) @@ -176,11 +179,15 @@ cb_many_f = CallbackSet(cbs_many_f...) #disco solve sol_disco = solve(prob_many, RadauIIA5(); callback=cb_many, reltol=1e-10, abstol=1e-12) -# 111.333 μs (907 allocations: 36.94 KiB) +# 169.541 μs (1479 allocations: 73.98 KiB) #fixed order solve sol_no_disco = solve(prob_many, RadauIIA5(); callback=cb_many_f, reltol=1e-10, abstol=1e-12) # 111.666 μs (907 allocations: 36.94 KiB) +@profview for i in 1:1000 + solve(prob_many, RadauIIA5(); callback = cb_many, reltol = 1e-10, abstol = 1e-12) +end + #TEST 6: DISCONTINUOUS DAE # discontinuous DAE with mass matrix # System: M * du/dt = f(u, p, t) @@ -214,7 +221,7 @@ cb_daef = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = false) radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) # 83.500 μs (769 allocations: 35.72 KiB) radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol=1e-10) -# 101.542 μs (1230 allocations: 48.16 KiB) +# 104.292 μs (1494 allocations: 53.25 KiB) #TEST 7: VECTOR CALLBACK function f!(du, u, p, t) From af7c3a50f7b91aab9a1ad0c6fa512178e6aab618 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 30 Mar 2026 19:11:37 -0500 Subject: [PATCH 12/15] further optimizations --- lib/OrdinaryDiffEqCore/src/disco.jl | 24 ++++++++++++++-------- lib/OrdinaryDiffEqCore/src/solve.jl | 10 ++++----- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 6 +++--- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index 29bf26c090a..b8abb3bb633 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -16,6 +16,13 @@ function find_discontinuity(u, uprev, integrator, cache) p = integrator.p t = integrator.t dt = integrator.dt + save_idxs = integrator.opts.save_idxs + k = integrator.k + cache = integrator.cache + differential_vars = integrator.differential_vars + θlo = zero(dt) + θhi = one(dt) + bracket = [θlo, θhi] breakpointθ = -one(dt) idx = 1 for i in cb.continuous_callbacks @@ -50,15 +57,16 @@ function find_discontinuity(u, uprev, integrator, cache) if (out_prev * out_curr < zero(out_prev)) disco_prob = integrator.disco_probs[idx] #disco_prob = integrator.disco_prob - disco_prob.f.f.dt = integrator.dt - disco_prob.f.f.uprev = uprev - disco_prob.f.f.u = u - disco_prob.f.f.k = integrator.k - disco_prob.f.f.cache = integrator.cache - disco_prob.f.f.differential_vars = integrator.differential_vars - disco_prob.f.f.idxs = integrator.opts.save_idxs + disco_zero = disco_prob.f.f + disco_zero.dt = dt + disco_zero.uprev = uprev + disco_zero.u = u + disco_zero.k = k + disco_zero.cache = cache + disco_zero.differential_vars = differential_vars + disco_zero.idxs = save_idxs #disco_prob.f.f.callback = i - sol = solve(disco_prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) + sol = solve(disco_prob; bracket = bracket, abstol = 0, reltol = 0) tmp = sol[] if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) breakpointθ = tmp diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index 287e8aa8985..220f7e50d6e 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -16,17 +16,17 @@ determine_controller_datatype(u, internalnorm, ts::Tuple{<:Number, <:Number}) = determine_controller_datatype(u::AbstractVector{<:Number}, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(DiffEqBase.value(internalnorm(u, ts[1]))), typeof(DiffEqBase.value(internalnorm(u, ts[2]))), eltype(float.(DiffEqBase.value(ts)))) determine_controller_datatype(u, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(float(DiffEqBase.value(ts[1]))), typeof(float(DiffEqBase.value(ts[2])))) # This seems to be an assumption implicitly taken somewhere -mutable struct zero_func_struct{uType, tType, rateType, CacheType} +mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsType, callbackType} #integrator_ref::IntegratorType u₁::uType - callback::ContinuousCallback + callback::callbackType dt::tType uprev::uType u::uType - k::Vector{rateType} + k::kType cache::CacheType - idxs::Union{Nothing, Vector{Int}} - differential_vars::Union{Nothing, Vector{Bool}, BitVector} + idxs::idxsType + differential_vars::varsType end function (z::zero_func_struct)(θ, p) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 1476e4a3946..54b9f900a42 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -20,7 +20,7 @@ cb = ContinuousCallback(condition, affect!; is_discontinuity = true) cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 286.125 μs (8207 allocations: 258.09 KiB) +# 283.292 μs (8113 allocations: 256.59 KiB) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 340.292 μs (10009 allocations: 311.05 KiB) @profview for i in 1:1000 @@ -62,7 +62,7 @@ cb2 = CallbackSet(cb1f, cb2f) #disco solve sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 1.548 ms (46763 allocations: 1.35 MiB) +# 1.460 ms (41491 allocations: 1.26 MiB) #fixed order solve sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 1.264 ms (37026 allocations: 1.13 MiB) @@ -107,7 +107,7 @@ cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) #disco solve sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) -# 195.666 μs (3834 allocations: 110.72 KiB) +# 159.125 μs (1819 allocations: 79.06 KiB) #fixed order solve sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) # 125.583 μs (1134 allocations: 54.56 KiB) From 3bd834fb7b706f1aa0a41f9281f55c7bbb501539 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Tue, 31 Mar 2026 10:26:31 -0500 Subject: [PATCH 13/15] fix benchmarks and merge issues --- lib/OrdinaryDiffEqCore/src/solve.jl | 1 - lib/OrdinaryDiffEqCore/test/disco_tests.jl | 28 +++++++++++----------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index a2f8a5efdb3..644672abdb6 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -822,7 +822,6 @@ function _ode_init( u_modified, reinitialize, isdae, opts, stats, initializealg, differential_vars, fsalfirst, fsallast, _rng, disco_probs, - fsalfirst, fsallast, _rng, W, P, sqdt, noise, c, rate_constants, QT(1) ) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 54b9f900a42..246be2015f1 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -20,9 +20,9 @@ cb = ContinuousCallback(condition, affect!; is_discontinuity = true) cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 283.292 μs (8113 allocations: 256.59 KiB) +# 298.084 μs (8108 allocations: 257.11 KiB) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) -# 340.292 μs (10009 allocations: 311.05 KiB) +# 356.708 μs (10024 allocations: 312.08 KiB) @profview for i in 1:1000 solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) end @@ -62,10 +62,10 @@ cb2 = CallbackSet(cb1f, cb2f) #disco solve sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -# 1.460 ms (41491 allocations: 1.26 MiB) +# 1.503 ms (41672 allocations: 1.27 MiB) #fixed order solve sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) -# 1.264 ms (37026 allocations: 1.13 MiB) +# 1.306 ms (37092 allocations: 1.13 MiB) @profview for i in 1:1000 solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) @@ -107,10 +107,10 @@ cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) #disco solve sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) -# 159.125 μs (1819 allocations: 79.06 KiB) +# 175.625 μs (1871 allocations: 81.55 KiB) #fixed order solve sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) -# 125.583 μs (1134 allocations: 54.56 KiB) +# 142.875 μs (1244 allocations: 59.17 KiB) @profview for i in 1:1000 solve(prob_multi, RadauIIA5(); callback = cb_multi, reltol = 1e-6) @@ -141,10 +141,10 @@ cb_stiff_f = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = fa #disco solve sol_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) -# 131.375 μs (2181 allocations: 78.84 KiB) +# 149.167 μs (1819 allocations: 75.19 KiB) #fixed order solve sol_no_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) -# 119.417 μs (1480 allocations: 59.55 KiB) +# 138.125 μs (1565 allocations: 64.09 KiB) @profview for i in 1:1000 solve(prob_stiff, RadauIIA5(); callback = cb_stiff, reltol = 1e-9, abstol = 1e-11) @@ -179,10 +179,10 @@ cb_many_f = CallbackSet(cbs_many_f...) #disco solve sol_disco = solve(prob_many, RadauIIA5(); callback=cb_many, reltol=1e-10, abstol=1e-12) -# 169.541 μs (1479 allocations: 73.98 KiB) +# 182.541 μs (1489 allocations: 73.64 KiB) #fixed order solve sol_no_disco = solve(prob_many, RadauIIA5(); callback=cb_many_f, reltol=1e-10, abstol=1e-12) -# 111.666 μs (907 allocations: 36.94 KiB) +# 121.292 μs (923 allocations: 36.78 KiB) @profview for i in 1:1000 solve(prob_many, RadauIIA5(); callback = cb_many, reltol = 1e-10, abstol = 1e-12) @@ -218,11 +218,11 @@ end cb_dae = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = true) cb_daef = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = false) -radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) -# 83.500 μs (769 allocations: 35.72 KiB) radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol=1e-10) -# 104.292 μs (1494 allocations: 53.25 KiB) - +# 88.542 μs (870 allocations: 41.86 KiB) +radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) +# 73.000 μs (673 allocations: 32.05 KiB) + #TEST 7: VECTOR CALLBACK function f!(du, u, p, t) du[1] = -u[1] From 90a9bca6eec6dcb619fc7db2bd7ea42f17f56e33 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Mon, 6 Apr 2026 19:25:15 -0500 Subject: [PATCH 14/15] update vector continuous callback handling --- lib/OrdinaryDiffEqCore/src/disco.jl | 44 ++++------ .../src/integrators/integrator_interface.jl | 2 +- lib/OrdinaryDiffEqCore/src/solve.jl | 45 ++++------ lib/OrdinaryDiffEqCore/test/disco_tests.jl | 82 +++++++++---------- 4 files changed, 74 insertions(+), 99 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl index b8abb3bb633..9f7191874e7 100644 --- a/lib/OrdinaryDiffEqCore/src/disco.jl +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -29,51 +29,43 @@ function find_discontinuity(u, uprev, integrator, cache) if (!(i.is_discontinuity)) continue end + disco_prob = integrator.disco_probs[idx] + disco_zero = disco_prob.f.f + disco_zero.dt = dt + disco_zero.uprev = uprev + disco_zero.u = u + disco_zero.k = k + disco_zero.cache = cache + disco_zero.differential_vars = differential_vars + disco_zero.idxs = save_idxs if (i isa VectorContinuousCallback) + len_cb = i.len out_prev = similar(u) - out_curr = similar(u) + out_curr = similar(u) i.condition(out_prev, uprev, t, integrator) i.condition(out_curr, u, t + dt, integrator) - for (ind, (f0, f1)) in enumerate(zip(out_prev, out_curr)) - if (f0 * f1 < zero(f0)) - u₁ = similar(u) - out = similar(u) - function zero_func(θ, p) - ode_interpolant!(u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) - i.condition(out, u₁, t + θ * integrator.dt, integrator) - out[ind] - end - prob = IntervalNonlinearProblem(zero_func, [zero(dt), one(dt)], p) - sol = solve(prob; bracket=[zero(dt), one(dt)], abstol = 0, reltol = 0) + for j in 1:len_cb + if (out_prev[j] * out_curr[j] < zero(out_prev[j])) + disco_zero.ind = j + sol = solve(disco_prob; bracket = bracket) tmp = sol[] if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) breakpointθ = tmp - end + end end end else out_prev = i.condition(uprev, t, integrator) out_curr = i.condition(u, t + dt, integrator) if (out_prev * out_curr < zero(out_prev)) - disco_prob = integrator.disco_probs[idx] - #disco_prob = integrator.disco_prob - disco_zero = disco_prob.f.f - disco_zero.dt = dt - disco_zero.uprev = uprev - disco_zero.u = u - disco_zero.k = k - disco_zero.cache = cache - disco_zero.differential_vars = differential_vars - disco_zero.idxs = save_idxs - #disco_prob.f.f.callback = i - sol = solve(disco_prob; bracket = bracket, abstol = 0, reltol = 0) + sol = solve(disco_prob; bracket = bracket) tmp = sol[] if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) breakpointθ = tmp end end - idx += 1 end + idx += 1 end breakpointθ end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl index 8302ee706b6..4d988c21a57 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl @@ -144,7 +144,7 @@ end end end -function u_modified!(integrator::ODEIntegrator, bool::Bool) +function SciMLBase.u_modified!(integrator::ODEIntegrator, bool::Bool) return integrator.u_modified = bool end diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index 644672abdb6..c1d00c70cce 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -16,7 +16,7 @@ determine_controller_datatype(u, internalnorm, ts::Tuple{<:Number, <:Number}) = determine_controller_datatype(u::AbstractVector{<:Number}, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(DiffEqBase.value(internalnorm(u, ts[1]))), typeof(DiffEqBase.value(internalnorm(u, ts[2]))), eltype(float.(DiffEqBase.value(ts)))) determine_controller_datatype(u, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(float(DiffEqBase.value(ts[1]))), typeof(float(DiffEqBase.value(ts[2])))) # This seems to be an assumption implicitly taken somewhere -mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsType, callbackType} +mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsType, callbackType, outType} #integrator_ref::IntegratorType u₁::uType callback::callbackType @@ -27,14 +27,19 @@ mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsTy cache::CacheType idxs::idxsType differential_vars::varsType + ind::Int + out::outType end function (z::zero_func_struct)(θ, p) - #integrator = z.integrator_ref[]::ODEIntegrator ode_interpolant!(z.u₁, θ, z.dt, z.uprev, z.u, z.k, z.cache, z.idxs, Val{0}, z.differential_vars) - #ode_interpolant!(z.u₁, θ, integrator, integrator.opts.save_idxs, Val{0}) - out = z.callback.condition(z.u₁, z.dt + θ * z.dt, z) - out + return zero_condition(z.callback, z.out, z.u₁, z.dt + θ * z.dt, z, z.ind) +end + +@inline zero_condition(cb::ContinuousCallback, out::Nothing, u, t, z, ind) = cb.condition(u, t, z) +@inline function zero_condition(cb::VectorContinuousCallback, out, u, t, z, ind) + cb.condition(out, u, t, z) + return out[ind] end function SciMLBase.__init( @@ -755,38 +760,25 @@ function _ode_init( get_fsalfirstlast(cache, rate_prototype) _rng = rng === nothing ? Random.default_rng() : rng + num_probs = 0 - integrator_ref = Ref{Union{DEIntegrator, Nothing}}(nothing) for i in callbacks_internal.continuous_callbacks - if !(i isa VectorContinuousCallback) && i.is_discontinuity + if i.is_discontinuity num_probs += 1 end end - disco_probs = Vector{IntervalNonlinearProblem}(undef, num_probs) idx = 1 for i in callbacks_internal.continuous_callbacks - if i.is_discontinuity && !(i isa VectorContinuousCallback) + if i.is_discontinuity u₁ = similar(u) - zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars) + out = i isa VectorContinuousCallback ? similar(u) : nothing + zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars, 1, out) disco_probs[idx] = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p) idx += 1 end end - #= - disco_prob = nothing - integrator_ref = Ref{Union{DEIntegrator, Nothing}}(nothing) - for i in callbacks_internal.continuous_callbacks - if i.is_discontinuity && !(i isa VectorContinuousCallback) - #VCC problems handled in disco itself - u₁ = similar(u) - #zero_func = zero_func_struct(integrator_ref, u₁, i) - zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars) - disco_prob = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p) - break - end - end - =# + integrator = ODEIntegrator{ typeof(_alg), isinplace(prob), uType, typeof(du), tType, typeof(p), @@ -798,7 +790,7 @@ function _ode_init( typeof(initializealg), typeof(differential_vars), typeof(controller_cache), typeof(_rng), typeof(W), typeof(P), typeof(sqdt), - typeof(noise), typeof(c), typeof(rate_constants), + typeof(noise), typeof(c), typeof(rate_constants) }( sol, u, du, k, t, tType(_dt), f, p, uprev, uprev2, duprev, tprev, @@ -825,9 +817,6 @@ function _ode_init( W, P, sqdt, noise, c, rate_constants, QT(1) ) - #if (num_probs > 0) - integrator_ref[] = integrator - #end if initialize_integrator if isdae || SciMLBase.has_initializeprob(prob.f) || diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 246be2015f1..89ca91e34c3 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -150,45 +150,7 @@ sol_no_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9 solve(prob_stiff, RadauIIA5(); callback = cb_stiff, reltol = 1e-9, abstol = 1e-11) end -#TEST 5: MULTIPLE DISCONTINUITIES IN SMALL RANGE -# multiple discontinuities in very small range (1e-6 apart, 5 discontinuities) -function f_many_disc!(du, u, p, t) - du[1] = u[1] + 1 # simple linear growth -end - -u0_many = [0.0] -tspan_many = (0.0, 1.0) -prob_many = ODEProblem(f_many_disc!, u0_many, tspan_many) - -# create 5 discontinuities spaced 1e-6 apart -disc_values = [0.1 + i * 1e-6 for i = 0:4] - -# define callbacks for each discontinuity -cbs_many = [] -cbs_many_f = [] -for (i, disc_val) in enumerate(disc_values) - local cond_func(u, t, integrator) = u[1] - disc_val - function affect_func!(integrator) - #println("Dense discontinuity $i fired at t=$(integrator.t), u=$(integrator.u[1])") - end - push!(cbs_many, ContinuousCallback(cond_func, affect_func!; is_discontinuity = true)) - push!(cbs_many_f, ContinuousCallback(cond_func, affect_func!; is_discontinuity = false)) -end -cb_many = CallbackSet(cbs_many...) -cb_many_f = CallbackSet(cbs_many_f...) - -#disco solve -sol_disco = solve(prob_many, RadauIIA5(); callback=cb_many, reltol=1e-10, abstol=1e-12) -# 182.541 μs (1489 allocations: 73.64 KiB) -#fixed order solve -sol_no_disco = solve(prob_many, RadauIIA5(); callback=cb_many_f, reltol=1e-10, abstol=1e-12) -# 121.292 μs (923 allocations: 36.78 KiB) - -@profview for i in 1:1000 - solve(prob_many, RadauIIA5(); callback = cb_many, reltol = 1e-10, abstol = 1e-12) -end - -#TEST 6: DISCONTINUOUS DAE +#TEST 5: DISCONTINUOUS DAE # discontinuous DAE with mass matrix # System: M * du/dt = f(u, p, t) # du[1]/dt = u[2] - u[1] @@ -223,7 +185,7 @@ radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol= radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) # 73.000 μs (673 allocations: 32.05 KiB) -#TEST 7: VECTOR CALLBACK +#TEST 6: VECTOR CALLBACK function f!(du, u, p, t) du[1] = -u[1] du[2] = 0.2*u[1] - 0.1*u[2] @@ -250,10 +212,42 @@ function affect!(integrator, idx) end end -cb = VectorContinuousCallback(condition!, affect!, 2;) -cb2 = VectorContinuousCallback(condition!, affect!, 2; is_discontinuity = false) +cb = VectorContinuousCallback(condition!, affect!, 2; is_discontinuity = true) +cb2 = VectorContinuousCallback(condition!, affect!, 2; is_discontinuity = false) sol_disco = solve(prob, RadauIIA5(); callback = cb) -# 62.041 μs (849 allocations: 41.64 KiB) +# 49.125 μs (664 allocations: 32.89 KiB) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2) -# 37.375 μs (531 allocations: 25.23 KiB) \ No newline at end of file +# 37.375 μs (531 allocations: 25.23 KiB) + +@profview for i in 1:1000 + solve(prob, RadauIIA5(); callback = cb) +end + +#TEST 7 +function f!(du, u, p, t) + x1, x2 = u + du[1] = x2 + if x2 < 0.0 + du[2] = -x1 + 1.0 + else + du[2] = -x1 - 1.0 + end +end + +u = [1.5, 0.8] +tspan = (0.0, 2.0) +prob = ODEProblem(f!, u, tspan) + +cond(u, t, integrator) = u[2] +affect!(integrator) = nothing + +cb = ContinuousCallback(cond, affect!; is_discontinuity = true) +cb2 = ContinuousCallback(cond, affect!; is_discontinuity = false) + +sol_disco = solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10) +sol_no_disco = solve(prob, Tsit5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) + +@profview for i in 1:1000 + solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10) +end From 6902dcea6b098f0548bb484e2a97ce434b444d96 Mon Sep 17 00:00:00 2001 From: Shreyas-Ekanathan Date: Sat, 18 Apr 2026 15:39:27 -0500 Subject: [PATCH 15/15] update tests --- lib/OrdinaryDiffEqCore/test/disco_tests.jl | 91 +++++++++++++++------- 1 file changed, 65 insertions(+), 26 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl index 89ca91e34c3..99374969fcb 100644 --- a/lib/OrdinaryDiffEqCore/test/disco_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -1,5 +1,5 @@ using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra -using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner +using OrdinaryDiffEqTsit5, OrdinaryDiffEqRosenbrock using Logging global_logger(ConsoleLogger(stderr, Logging.Error)) @@ -19,13 +19,20 @@ end cb = ContinuousCallback(condition, affect!; is_discontinuity = true) cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) -sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) +sol_disco_radau = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) # 298.084 μs (8108 allocations: 257.11 KiB) -sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) +sol_no_disco_radau = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 356.708 μs (10024 allocations: 312.08 KiB) -@profview for i in 1:1000 - solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -end + +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) +# 418.584 μs (16472 allocations: 576.75 KiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) +# 440.375 μs (17875 allocations: 622.09 KiB) + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) +# 59.542 μs (7248 allocations: 233.67 KiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) +# 46.500 μs (7129 allocations: 226.22 KiB) #TEST 2: TWO DISCONTINUITIES #two discontinuity functions @@ -63,13 +70,18 @@ cb2 = CallbackSet(cb1f, cb2f) #disco solve sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) # 1.503 ms (41672 allocations: 1.27 MiB) -#fixed order solve sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) # 1.306 ms (37092 allocations: 1.13 MiB) -@profview for i in 1:1000 - solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) -end +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) +# 1.164 ms (44318 allocations: 1.52 MiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) +# 1.306 ms (51713 allocations: 1.76 MiB) + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) +# 279.792 μs (34573 allocations: 1.07 MiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) +# 266.167 μs (39024 allocations: 1.21 MiB) #TEST 3: EXPONENTIAL DISCONTINUITY @@ -108,13 +120,18 @@ cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) #disco solve sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) # 175.625 μs (1871 allocations: 81.55 KiB) -#fixed order solve sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) # 142.875 μs (1244 allocations: 59.17 KiB) -@profview for i in 1:1000 - solve(prob_multi, RadauIIA5(); callback = cb_multi, reltol = 1e-6) -end +sol_disco_rosenbrock = solve(prob_multi, Rodas5P(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +# 295.834 μs (2216 allocations: 90.70 KiB) +sol_no_disco_rosenbrock = solve(prob_multi, Rodas5P(); callback=cb_multi2, reltol=1e-7, abstol=1e-9) +# 253.709 μs (1380 allocations: 74.28 KiB) + +sol_disco_tsit5 = solve(prob_multi, Tsit5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +# 127.375 μs (1953 allocations: 87.49 KiB) +sol_no_disco_tsit5 = solve(prob_multi, Tsit5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) +# 95.250 μs (1499 allocations: 73.62 KiB) #TEST 4: STIFF DISCONTINUITY # very stiff discontinuous system @@ -142,13 +159,18 @@ cb_stiff_f = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = fa #disco solve sol_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) # 149.167 μs (1819 allocations: 75.19 KiB) -#fixed order solve sol_no_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) # 138.125 μs (1565 allocations: 64.09 KiB) -@profview for i in 1:1000 - solve(prob_stiff, RadauIIA5(); callback = cb_stiff, reltol = 1e-9, abstol = 1e-11) -end +sol_disco_rosenbrock = solve(prob_stiff, Rodas5P(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +# 204.833 μs (1517 allocations: 59.33 KiB) +sol_no_disco_rosenbrock = solve(prob_stiff, Rodas5P(); callback=cb_stiff_f, reltol=1e-9, abstol=1e-11) +# 156.500 μs (1047 allocations: 44.59 KiB) + +sol_disco_tsit5 = solve(prob_stiff, Tsit5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +# 93.833 μs (2040 allocations: 80.59 KiB) +sol_no_disco_tsit5 = solve(prob_stiff, Tsit5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) +# 82.750 μs (1898 allocations: 72.12 KiB) #TEST 5: DISCONTINUOUS DAE # discontinuous DAE with mass matrix @@ -185,6 +207,11 @@ radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol= radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) # 73.000 μs (673 allocations: 32.05 KiB) +sol_disco_rosenbrock = solve(prob_dae, Rodas5P(); callback=cb_dae, reltol=1e-8, abstol=1e-10) +# 312.167 μs (1200 allocations: 48.73 KiB) +sol_no_disco_rosenbrock = solve(prob_dae, Rodas5P(); callback=cb_daef, reltol=1e-8, abstol=1e-10) +# 256.792 μs (672 allocations: 32.56 KiB) + #TEST 6: VECTOR CALLBACK function f!(du, u, p, t) du[1] = -u[1] @@ -220,9 +247,15 @@ sol_disco = solve(prob, RadauIIA5(); callback = cb) sol_no_disco = solve(prob, RadauIIA5(); callback = cb2) # 37.375 μs (531 allocations: 25.23 KiB) -@profview for i in 1:1000 - solve(prob, RadauIIA5(); callback = cb) -end +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb) +# 57.333 μs (592 allocations: 31.23 KiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2) +# 44.250 μs (476 allocations: 23.73 KiB) + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb) +# 37.833 μs (673 allocations: 31.80 KiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2) +# 24.958 μs (557 allocations: 24.23 KiB) #TEST 7 function f!(du, u, p, t) @@ -245,9 +278,15 @@ affect!(integrator) = nothing cb = ContinuousCallback(cond, affect!; is_discontinuity = true) cb2 = ContinuousCallback(cond, affect!; is_discontinuity = false) -sol_disco = solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10) -sol_no_disco = solve(prob, Tsit5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) +sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-8, abstol = 1e-10) +sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) -@profview for i in 1:1000 - solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10) -end +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-8, abstol = 1e-10) +# 240.291 μs (1821 allocations: 71.56 KiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-8, abstol = 1e-10) +# 184.625 μs (1029 allocations: 49.23 KiB) + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10) +# 79.791 μs (1678 allocations: 73.85 KiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) +# 55.958 μs (1259 allocations: 57.04 KiB) \ No newline at end of file