diff --git a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl index e6e9a07d884..e5a54f67737 100644 --- a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl +++ b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl @@ -23,7 +23,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, get_fsalfirstlast, generic_solver_docstring, _bool_to_ADType, _process_AD_choice, _ode_interpolant, _ode_interpolant!, has_stiff_interpolation, - _ode_addsteps!, DerivativeOrderNotPossibleError + _ode_addsteps!, DerivativeOrderNotPossibleError, set_discontinuity using OrdinaryDiffEqSDIRK: ImplicitEulerConstantCache, ImplicitEulerCache using TruncatedStacktraces: @truncate_stacktrace diff --git a/lib/OrdinaryDiffEqBDF/src/controllers.jl b/lib/OrdinaryDiffEqBDF/src/controllers.jl index 596184a80f8..638c1e9a97c 100644 --- a/lib/OrdinaryDiffEqBDF/src/controllers.jl +++ b/lib/OrdinaryDiffEqBDF/src/controllers.jl @@ -112,6 +112,13 @@ function bdf_step_reject_controller!(integrator, EEst1) h = integrator.dt integrator.cache.consfailcnt += 1 integrator.cache.nconsteps = 0 + + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end + if integrator.cache.consfailcnt > 1 h = h / 2 end diff --git a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl index 16311defff1..a663782a316 100644 --- a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl +++ b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl @@ -161,6 +161,7 @@ include("cache_utils.jl") include("initialize_dae.jl") include("perform_step/composite_perform_step.jl") +include("disco.jl") include("dense/generic_dense.jl") diff --git a/lib/OrdinaryDiffEqCore/src/disco.jl b/lib/OrdinaryDiffEqCore/src/disco.jl new file mode 100644 index 00000000000..9f7191874e7 --- /dev/null +++ b/lib/OrdinaryDiffEqCore/src/disco.jl @@ -0,0 +1,71 @@ +function set_discontinuity(u, uprev, integrator, cache) + breakpointθ = find_discontinuity(u, uprev, integrator, cache) + dt = integrator.dt + t = integrator.t + if !isnan(breakpointθ) && 1e-6 < breakpointθ < 1.0 + #println("Discontinuity detected at t = ", t + breakpointθ * dt) + return breakpointθ * dt + end + return -1 +end + +function find_discontinuity(u, uprev, integrator, cache) + cb = integrator.opts.callback + cb === nothing && return -1 + isempty(cb.continuous_callbacks) && return -1 + p = integrator.p + t = integrator.t + dt = integrator.dt + save_idxs = integrator.opts.save_idxs + k = integrator.k + cache = integrator.cache + differential_vars = integrator.differential_vars + θlo = zero(dt) + θhi = one(dt) + bracket = [θlo, θhi] + breakpointθ = -one(dt) + idx = 1 + for i in cb.continuous_callbacks + if (!(i.is_discontinuity)) + continue + end + disco_prob = integrator.disco_probs[idx] + disco_zero = disco_prob.f.f + disco_zero.dt = dt + disco_zero.uprev = uprev + disco_zero.u = u + disco_zero.k = k + disco_zero.cache = cache + disco_zero.differential_vars = differential_vars + disco_zero.idxs = save_idxs + if (i isa VectorContinuousCallback) + len_cb = i.len + out_prev = similar(u) + out_curr = similar(u) + i.condition(out_prev, uprev, t, integrator) + i.condition(out_curr, u, t + dt, integrator) + for j in 1:len_cb + if (out_prev[j] * out_curr[j] < zero(out_prev[j])) + disco_zero.ind = j + sol = solve(disco_prob; bracket = bracket) + tmp = sol[] + if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) + breakpointθ = tmp + end + end + end + else + out_prev = i.condition(uprev, t, integrator) + out_curr = i.condition(u, t + dt, integrator) + if (out_prev * out_curr < zero(out_prev)) + sol = solve(disco_prob; bracket = bracket) + tmp = sol[] + if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ)) + breakpointθ = tmp + end + end + end + idx += 1 + end + breakpointθ +end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index 4eb691880d4..c8054b6c30c 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -196,6 +196,11 @@ end function step_reject_controller!(integrator, controller::IController, alg) (; qold) = integrator + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt = qold end @@ -271,6 +276,11 @@ end function step_reject_controller!(integrator, cache::IControllerCache, alg) @assert cache.dtreject ≈ integrator.qold "Controller cache went out of sync with time stepping logic." + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt = cache.dtreject # TODO this does not look right. end @@ -351,6 +361,11 @@ end function step_reject_controller!(integrator, controller::PIController, alg) (; q11) = integrator (; qmin, gamma) = integrator.opts + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt /= min(inv(qmin), q11 / gamma) end @@ -457,6 +472,11 @@ end function step_reject_controller!(integrator, cache::PIControllerCache, alg) (; controller, q11) = cache (; qmin, gamma) = controller + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt /= min(inv(qmin), q11 / gamma) end @@ -633,6 +653,11 @@ function step_accept_controller!(integrator, controller::PIDController, alg, dt_ end function step_reject_controller!(integrator, controller::PIDController, alg) + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt *= integrator.qold end @@ -764,6 +789,11 @@ function step_accept_controller!(integrator, cache::PIDControllerCache, alg, dt_ end function step_reject_controller!(integrator, cache::PIDControllerCache, alg) + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end return integrator.dt *= cache.dt_factor end @@ -877,6 +907,12 @@ end function step_reject_controller!(integrator, controller::PredictiveController, alg) (; dt, success_iter, qold) = integrator + disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache) + if disco_dt != -1 + integrator.dt = disco_dt + return integrator.dt + end + return integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold end @@ -988,6 +1024,11 @@ end function step_reject_controller!(integrator, cache::PredictiveControllerCache, alg) (; dt, success_iter) = integrator (; qold) = cache + if (integrator.disco_dt_set) + println("using fixed dt from discontinuity handling") + integrator.disco_dt_set = false + return integrator.dt + end return integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl index 8302ee706b6..4d988c21a57 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl @@ -144,7 +144,7 @@ end end end -function u_modified!(integrator::ODEIntegrator, bool::Bool) +function SciMLBase.u_modified!(integrator::ODEIntegrator, bool::Bool) return integrator.u_modified = bool end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/type.jl b/lib/OrdinaryDiffEqCore/src/integrators/type.jl index 05c5c594ab6..766073b248e 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/type.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/type.jl @@ -152,6 +152,7 @@ mutable struct ODEIntegrator{ dtcache::tType dtchangeable::Bool dtpropose::tType + disco_dt_set::Bool tdir::tdirType eigen_est::eigenType EEst::EEstT @@ -191,6 +192,8 @@ mutable struct ODEIntegrator{ fsalfirst::FSALType fsallast::FSALType rng::RNGType + #disco_prob::IntervalNonlinearProblem + disco_probs::Vector{IntervalNonlinearProblem} W::WType P::PType sqdt::SqdtType diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index 78319cd02b1..c1d00c70cce 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -16,6 +16,32 @@ determine_controller_datatype(u, internalnorm, ts::Tuple{<:Number, <:Number}) = determine_controller_datatype(u::AbstractVector{<:Number}, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(DiffEqBase.value(internalnorm(u, ts[1]))), typeof(DiffEqBase.value(internalnorm(u, ts[2]))), eltype(float.(DiffEqBase.value(ts)))) determine_controller_datatype(u, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(float(DiffEqBase.value(ts[1]))), typeof(float(DiffEqBase.value(ts[2])))) # This seems to be an assumption implicitly taken somewhere +mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsType, callbackType, outType} + #integrator_ref::IntegratorType + u₁::uType + callback::callbackType + dt::tType + uprev::uType + u::uType + k::kType + cache::CacheType + idxs::idxsType + differential_vars::varsType + ind::Int + out::outType +end + +function (z::zero_func_struct)(θ, p) + ode_interpolant!(z.u₁, θ, z.dt, z.uprev, z.u, z.k, z.cache, z.idxs, Val{0}, z.differential_vars) + return zero_condition(z.callback, z.out, z.u₁, z.dt + θ * z.dt, z, z.ind) +end + +@inline zero_condition(cb::ContinuousCallback, out::Nothing, u, t, z, ind) = cb.condition(u, t, z) +@inline function zero_condition(cb::VectorContinuousCallback, out, u, t, z, ind) + cb.condition(out, u, t, z) + return out[ind] +end + function SciMLBase.__init( prob::Union{ SciMLBase.AbstractODEProblem, @@ -57,6 +83,7 @@ function _ode_init( save_everystep = isempty(saveat), save_on = true, save_discretes = true, + disco_dt_set = false, save_start = save_everystep || isempty(saveat) || saveat isa Number || prob.tspan[1] in saveat, save_end = nothing, @@ -105,6 +132,7 @@ function _ode_init( alias = ODEAliasSpecifier(), initializealg = DefaultInit(), rng = nothing, + disco_probs = nothing, # SDE/RODE fields: accepted here so that SDE packages can delegate to # _ode_init and construct an ODEIntegrator with noise populated. save_noise = false, @@ -733,6 +761,24 @@ function _ode_init( _rng = rng === nothing ? Random.default_rng() : rng + num_probs = 0 + for i in callbacks_internal.continuous_callbacks + if i.is_discontinuity + num_probs += 1 + end + end + disco_probs = Vector{IntervalNonlinearProblem}(undef, num_probs) + idx = 1 + for i in callbacks_internal.continuous_callbacks + if i.is_discontinuity + u₁ = similar(u) + out = i isa VectorContinuousCallback ? similar(u) : nothing + zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars, 1, out) + disco_probs[idx] = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p) + idx += 1 + end + end + integrator = ODEIntegrator{ typeof(_alg), isinplace(prob), uType, typeof(du), tType, typeof(p), @@ -744,12 +790,12 @@ function _ode_init( typeof(initializealg), typeof(differential_vars), typeof(controller_cache), typeof(_rng), typeof(W), typeof(P), typeof(sqdt), - typeof(noise), typeof(c), typeof(rate_constants), + typeof(noise), typeof(c), typeof(rate_constants) }( sol, u, du, k, t, tType(_dt), f, p, uprev, uprev2, duprev, tprev, _alg, dtcache, dtchangeable, - dtpropose, tdir, eigen_est, EEst, + dtpropose, disco_dt_set, tdir, eigen_est, EEst, # TODO vvv remove these QT(qoldinit), q11, erracc, dtacc, @@ -767,7 +813,7 @@ function _ode_init( isout, reeval_fsal, u_modified, reinitialize, isdae, opts, stats, initializealg, differential_vars, - fsalfirst, fsallast, _rng, + fsalfirst, fsallast, _rng, disco_probs, W, P, sqdt, noise, c, rate_constants, QT(1) ) diff --git a/lib/OrdinaryDiffEqCore/test/disco_tests.jl b/lib/OrdinaryDiffEqCore/test/disco_tests.jl new file mode 100644 index 00000000000..99374969fcb --- /dev/null +++ b/lib/OrdinaryDiffEqCore/test/disco_tests.jl @@ -0,0 +1,292 @@ +using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra +using OrdinaryDiffEqTsit5, OrdinaryDiffEqRosenbrock +using Logging +global_logger(ConsoleLogger(stderr, Logging.Error)) + +#TEST 1: SIMPLE DISCONTINUITY +#test example discontinuous at u = 1 +f(u, p, t) = u[1] < 1 ? [2u[1]] : [-3u[1] + 5] +u0 = [0.1] +tspan = (0.0, 1.5) +prob = ODEProblem(f, u0, tspan) + +#define callback +condition(u, t, integrator) = u[1] - 1 +function affect!(integrator) + #println("fired callback at t=$(integrator.t), u=$(integrator.u[1])") + integrator.u[1] += 10 +end +cb = ContinuousCallback(condition, affect!; is_discontinuity = true) +cb2 = ContinuousCallback(condition, affect!; is_discontinuity = false) + +sol_disco_radau = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) +# 298.084 μs (8108 allocations: 257.11 KiB) +sol_no_disco_radau = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) +# 356.708 μs (10024 allocations: 312.08 KiB) + +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) +# 418.584 μs (16472 allocations: 576.75 KiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) +# 440.375 μs (17875 allocations: 622.09 KiB) + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) +# 59.542 μs (7248 allocations: 233.67 KiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) +# 46.500 μs (7129 allocations: 226.22 KiB) + +#TEST 2: TWO DISCONTINUITIES +#two discontinuity functions +function f(u, p, t) + if u[1] < 1 + [2u[1]] # region 1: grows to hit u = 1 + elseif u[1] < 2 + [u[1] + 0.2] # region 2: continues increasing to hit u = 2 + else + [-4u[1] + 12] + end +end + +u0 = [0.1] +tspan = (0.0, 2.5) +prob = ODEProblem(f, u0, tspan) + +#define callbacks +condition1(u, t, integrator) = u[1] - 1 +function affect1!(integrator) + #println("Callback 1 fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb1 = ContinuousCallback(condition1, affect1!; is_discontinuity = true) +cb1f = ContinuousCallback(condition1, affect1!; is_discontinuity = false) + +condition2(u, t, integrator) = u[1] - 2 +function affect2!(integrator) + #println("Callback 2 fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb2 = ContinuousCallback(condition2, affect2!; is_discontinuity = true) +cb2f = ContinuousCallback(condition2, affect2!; is_discontinuity = false) +cb = CallbackSet(cb1, cb2) +cb2 = CallbackSet(cb1f, cb2f) + +#disco solve +sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-6) +# 1.503 ms (41672 allocations: 1.27 MiB) +sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-6) +# 1.306 ms (37092 allocations: 1.13 MiB) + +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-6) +# 1.164 ms (44318 allocations: 1.52 MiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-6) +# 1.306 ms (51713 allocations: 1.76 MiB) + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-6) +# 279.792 μs (34573 allocations: 1.07 MiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-6) +# 266.167 μs (39024 allocations: 1.21 MiB) + + +#TEST 3: EXPONENTIAL DISCONTINUITY +# multiple exponential regions with sharp transitions +function f_multi_exp!(du, u, p, t) + if u[1] < 0.3 + du[1] = 3 * exp(3 * u[1]) # very steep exponential + elseif u[1] < 0.8 + du[1] = exp(u[1]) # slower exponential + else + du[1] = u[1] # linear + end +end + +u0_multi = [0.05] +tspan_multi = (0.0, 1.5) +prob_multi = ODEProblem(f_multi_exp!, u0_multi, tspan_multi) + +#define callbacks +cond_multi_1(u, t, integrator) = u[1] - 0.3 +function affect_multi_1!(integrator) + #println("Multi-exponential discontinuity 1 callback fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb_multi_1 = ContinuousCallback(cond_multi_1, affect_multi_1!; is_discontinuity = true) +cb_multi_1f = ContinuousCallback(cond_multi_1, affect_multi_1!; is_discontinuity = false) + +cond_multi_2(u, t, integrator) = u[1] - 0.8 +function affect_multi_2!(integrator) + #println("Multi-exponential discontinuity 2 callback fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb_multi_2 = ContinuousCallback(cond_multi_2, affect_multi_2!; is_discontinuity = true) +cb_multi_2f = ContinuousCallback(cond_multi_2, affect_multi_2!; is_discontinuity = false) +cb_multi = CallbackSet(cb_multi_1, cb_multi_2) +cb_multi2 = CallbackSet(cb_multi_1f, cb_multi_2f) + +#disco solve +sol_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +# 175.625 μs (1871 allocations: 81.55 KiB) +sol_no_disco = solve(prob_multi, RadauIIA5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) +# 142.875 μs (1244 allocations: 59.17 KiB) + +sol_disco_rosenbrock = solve(prob_multi, Rodas5P(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +# 295.834 μs (2216 allocations: 90.70 KiB) +sol_no_disco_rosenbrock = solve(prob_multi, Rodas5P(); callback=cb_multi2, reltol=1e-7, abstol=1e-9) +# 253.709 μs (1380 allocations: 74.28 KiB) + +sol_disco_tsit5 = solve(prob_multi, Tsit5(); callback=cb_multi, reltol=1e-7, abstol=1e-9) +# 127.375 μs (1953 allocations: 87.49 KiB) +sol_no_disco_tsit5 = solve(prob_multi, Tsit5(); callback=cb_multi2, reltol = 1e-7, abstol = 1e-9) +# 95.250 μs (1499 allocations: 73.62 KiB) + +#TEST 4: STIFF DISCONTINUITY +# very stiff discontinuous system +function f_stiff_disc!(du, u, p, t) + λ = p[1] # stiffness parameter + if u[1] < 0.5 + du[1] = -λ * u[1] + λ * exp(-t) # stiff decay with forcing + else + du[1] = u[1] + end +end + +u0_stiff = [0.1] +tspan_stiff = (0.0, 3.0) +prob_stiff = ODEProblem(f_stiff_disc!, u0_stiff, tspan_stiff, [100.0]) + +#define callback +cond_stiff(u, t, integrator) = u[1] - 0.5 +function affect_stiff!(integrator) + #println("Stiff discontinuity callback fired at t=$(integrator.t), u=$(integrator.u[1])") +end +cb_stiff = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = true) +cb_stiff_f = ContinuousCallback(cond_stiff, affect_stiff!; is_discontinuity = false) + +#disco solve +sol_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +# 149.167 μs (1819 allocations: 75.19 KiB) +sol_no_disco = solve(prob_stiff, RadauIIA5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) +# 138.125 μs (1565 allocations: 64.09 KiB) + +sol_disco_rosenbrock = solve(prob_stiff, Rodas5P(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +# 204.833 μs (1517 allocations: 59.33 KiB) +sol_no_disco_rosenbrock = solve(prob_stiff, Rodas5P(); callback=cb_stiff_f, reltol=1e-9, abstol=1e-11) +# 156.500 μs (1047 allocations: 44.59 KiB) + +sol_disco_tsit5 = solve(prob_stiff, Tsit5(); callback=cb_stiff, reltol=1e-9, abstol=1e-11) +# 93.833 μs (2040 allocations: 80.59 KiB) +sol_no_disco_tsit5 = solve(prob_stiff, Tsit5(); callback=cb_stiff_f, reltol = 1e-9, abstol = 1e-11) +# 82.750 μs (1898 allocations: 72.12 KiB) + +#TEST 5: DISCONTINUOUS DAE +# discontinuous DAE with mass matrix +# System: M * du/dt = f(u, p, t) +# du[1]/dt = u[2] - u[1] +# 0 = u[1] + u[2] - 1 (algebraic constraint) +function f_dae_disc!(du, u, p, t) + if u[1] < 0.5 + du[1] = 2 * u[2] - u[1] + du[2] = u[1] + u[2] - 1 # algebraic constraint + else + du[1] = -u[1] + u[2] + du[2] = u[1] + u[2] - 1 # algebraic constraint + end +end + +u0_dae = [0.2, 0.8] # consistent with constraint u[1] + u[2] = 1 +tspan_dae = (0.0, 2.0) + +M_dae = [1.0 0.0; 0.0 0.0] + +f_dae_func = ODEFunction(f_dae_disc!; mass_matrix=M_dae) +prob_dae = ODEProblem(f_dae_func, u0_dae, tspan_dae) + +cond_dae(u, t, integrator) = u[1] - 0.5 +function affect_dae!(integrator) + #println("DAE discontinuity callback fired at t=$(integrator.t), u=$(integrator.u)") +end +cb_dae = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = true) +cb_daef = ContinuousCallback(cond_dae, affect_dae!; is_discontinuity = false) + +radau_disco = solve(prob_dae, RadauIIA5(); callback=cb_dae, reltol=1e-8, abstol=1e-10) +# 88.542 μs (870 allocations: 41.86 KiB) +radau_no_disco = solve(prob_dae, RadauIIA5(); callback=cb_daef, reltol=1e-8, abstol=1e-10) +# 73.000 μs (673 allocations: 32.05 KiB) + +sol_disco_rosenbrock = solve(prob_dae, Rodas5P(); callback=cb_dae, reltol=1e-8, abstol=1e-10) +# 312.167 μs (1200 allocations: 48.73 KiB) +sol_no_disco_rosenbrock = solve(prob_dae, Rodas5P(); callback=cb_daef, reltol=1e-8, abstol=1e-10) +# 256.792 μs (672 allocations: 32.56 KiB) + +#TEST 6: VECTOR CALLBACK +function f!(du, u, p, t) + du[1] = -u[1] + du[2] = 0.2*u[1] - 0.1*u[2] +end + +u0 = [3.0, 0.0] +tspan = (0.0, 10.0) +prob = ODEProblem(f!, u0, tspan) + +# Two event surfaces: u[1] == 2.0 and u[1] == 1.0 +function condition!(out, u, t, integrator) + out[1] = u[1] - 2.0 + out[2] = u[1] - 1.0 +end + +# Discontinuous update to the state when an event fires +function affect!(integrator, idx) + if idx == 1 + # when u[1] crosses 2, kick u[2] up (jump discontinuity) + integrator.u[2] += 5.0 + elseif idx == 2 + # when u[1] crosses 1, reset u[2] + integrator.u[2] = 0.0 + end +end + +cb = VectorContinuousCallback(condition!, affect!, 2; is_discontinuity = true) +cb2 = VectorContinuousCallback(condition!, affect!, 2; is_discontinuity = false) + +sol_disco = solve(prob, RadauIIA5(); callback = cb) +# 49.125 μs (664 allocations: 32.89 KiB) +sol_no_disco = solve(prob, RadauIIA5(); callback = cb2) +# 37.375 μs (531 allocations: 25.23 KiB) + +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb) +# 57.333 μs (592 allocations: 31.23 KiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2) +# 44.250 μs (476 allocations: 23.73 KiB) + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb) +# 37.833 μs (673 allocations: 31.80 KiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2) +# 24.958 μs (557 allocations: 24.23 KiB) + +#TEST 7 +function f!(du, u, p, t) + x1, x2 = u + du[1] = x2 + if x2 < 0.0 + du[2] = -x1 + 1.0 + else + du[2] = -x1 - 1.0 + end +end + +u = [1.5, 0.8] +tspan = (0.0, 2.0) +prob = ODEProblem(f!, u, tspan) + +cond(u, t, integrator) = u[2] +affect!(integrator) = nothing + +cb = ContinuousCallback(cond, affect!; is_discontinuity = true) +cb2 = ContinuousCallback(cond, affect!; is_discontinuity = false) + +sol_disco = solve(prob, RadauIIA5(); callback = cb, reltol = 1e-8, abstol = 1e-10) +sol_no_disco = solve(prob, RadauIIA5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) + +sol_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb, reltol = 1e-8, abstol = 1e-10) +# 240.291 μs (1821 allocations: 71.56 KiB) +sol_no_disco_rosenbrock = solve(prob, Rodas5P(); callback = cb2, reltol = 1e-8, abstol = 1e-10) +# 184.625 μs (1029 allocations: 49.23 KiB) + +sol_disco_tsit5 = solve(prob, Tsit5(); callback = cb, reltol = 1e-8, abstol = 1e-10) +# 79.791 μs (1678 allocations: 73.85 KiB) +sol_no_disco_tsit5 = solve(prob, Tsit5(); callback = cb2, reltol = 1e-8, abstol = 1e-10) +# 55.958 μs (1259 allocations: 57.04 KiB) \ No newline at end of file