Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
get_fsalfirstlast, generic_solver_docstring, _bool_to_ADType,
_process_AD_choice,
_ode_interpolant, _ode_interpolant!, has_stiff_interpolation,
_ode_addsteps!, DerivativeOrderNotPossibleError
_ode_addsteps!, DerivativeOrderNotPossibleError, set_discontinuity
using OrdinaryDiffEqSDIRK: ImplicitEulerConstantCache, ImplicitEulerCache

using TruncatedStacktraces: @truncate_stacktrace
Expand Down
7 changes: 7 additions & 0 deletions lib/OrdinaryDiffEqBDF/src/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ function bdf_step_reject_controller!(integrator, EEst1)
h = integrator.dt
integrator.cache.consfailcnt += 1
integrator.cache.nconsteps = 0

disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache)
if disco_dt != -1
integrator.dt = disco_dt
return integrator.dt
end

if integrator.cache.consfailcnt > 1
h = h / 2
end
Expand Down
1 change: 1 addition & 0 deletions lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ include("cache_utils.jl")
include("initialize_dae.jl")

include("perform_step/composite_perform_step.jl")
include("disco.jl")

include("dense/generic_dense.jl")

Expand Down
71 changes: 71 additions & 0 deletions lib/OrdinaryDiffEqCore/src/disco.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
function set_discontinuity(u, uprev, integrator, cache)
breakpointθ = find_discontinuity(u, uprev, integrator, cache)
dt = integrator.dt
t = integrator.t
if !isnan(breakpointθ) && 1e-6 < breakpointθ < 1.0
#println("Discontinuity detected at t = ", t + breakpointθ * dt)
return breakpointθ * dt
end
return -1
end

function find_discontinuity(u, uprev, integrator, cache)
cb = integrator.opts.callback
cb === nothing && return -1
isempty(cb.continuous_callbacks) && return -1
p = integrator.p
t = integrator.t
dt = integrator.dt
save_idxs = integrator.opts.save_idxs
k = integrator.k
cache = integrator.cache
differential_vars = integrator.differential_vars
θlo = zero(dt)
θhi = one(dt)
bracket = [θlo, θhi]
breakpointθ = -one(dt)
idx = 1
for i in cb.continuous_callbacks
if (!(i.is_discontinuity))
continue
end
disco_prob = integrator.disco_probs[idx]
disco_zero = disco_prob.f.f
disco_zero.dt = dt
disco_zero.uprev = uprev
disco_zero.u = u
disco_zero.k = k
disco_zero.cache = cache
disco_zero.differential_vars = differential_vars
disco_zero.idxs = save_idxs
if (i isa VectorContinuousCallback)
len_cb = i.len
out_prev = similar(u)
out_curr = similar(u)
i.condition(out_prev, uprev, t, integrator)
i.condition(out_curr, u, t + dt, integrator)
for j in 1:len_cb
if (out_prev[j] * out_curr[j] < zero(out_prev[j]))
disco_zero.ind = j
sol = solve(disco_prob; bracket = bracket)
tmp = sol[]
if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ))
breakpointθ = tmp
end
end
end
else
out_prev = i.condition(uprev, t, integrator)
out_curr = i.condition(u, t + dt, integrator)
if (out_prev * out_curr < zero(out_prev))
sol = solve(disco_prob; bracket = bracket)
tmp = sol[]
if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ))
breakpointθ = tmp
end
end
end
idx += 1
end
breakpointθ
end
41 changes: 41 additions & 0 deletions lib/OrdinaryDiffEqCore/src/integrators/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ end

function step_reject_controller!(integrator, controller::IController, alg)
(; qold) = integrator
disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache)
if disco_dt != -1
integrator.dt = disco_dt
return integrator.dt
end
return integrator.dt = qold
end

Expand Down Expand Up @@ -271,6 +276,11 @@ end

function step_reject_controller!(integrator, cache::IControllerCache, alg)
@assert cache.dtreject ≈ integrator.qold "Controller cache went out of sync with time stepping logic."
disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache)
if disco_dt != -1
integrator.dt = disco_dt
return integrator.dt
end
return integrator.dt = cache.dtreject # TODO this does not look right.
end

Expand Down Expand Up @@ -351,6 +361,11 @@ end
function step_reject_controller!(integrator, controller::PIController, alg)
(; q11) = integrator
(; qmin, gamma) = integrator.opts
disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache)
if disco_dt != -1
integrator.dt = disco_dt
return integrator.dt
end
return integrator.dt /= min(inv(qmin), q11 / gamma)
end

Expand Down Expand Up @@ -457,6 +472,11 @@ end
function step_reject_controller!(integrator, cache::PIControllerCache, alg)
(; controller, q11) = cache
(; qmin, gamma) = controller
disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache)
if disco_dt != -1
integrator.dt = disco_dt
return integrator.dt
end
return integrator.dt /= min(inv(qmin), q11 / gamma)
end

Expand Down Expand Up @@ -633,6 +653,11 @@ function step_accept_controller!(integrator, controller::PIDController, alg, dt_
end

function step_reject_controller!(integrator, controller::PIDController, alg)
disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache)
if disco_dt != -1
integrator.dt = disco_dt
return integrator.dt
end
return integrator.dt *= integrator.qold
end

Expand Down Expand Up @@ -764,6 +789,11 @@ function step_accept_controller!(integrator, cache::PIDControllerCache, alg, dt_
end

function step_reject_controller!(integrator, cache::PIDControllerCache, alg)
disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache)
if disco_dt != -1
integrator.dt = disco_dt
return integrator.dt
end
return integrator.dt *= cache.dt_factor
end

Expand Down Expand Up @@ -877,6 +907,12 @@ end

function step_reject_controller!(integrator, controller::PredictiveController, alg)
(; dt, success_iter, qold) = integrator
disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache)
if disco_dt != -1
integrator.dt = disco_dt
return integrator.dt
end

return integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold
end

Expand Down Expand Up @@ -988,6 +1024,11 @@ end
function step_reject_controller!(integrator, cache::PredictiveControllerCache, alg)
(; dt, success_iter) = integrator
(; qold) = cache
if (integrator.disco_dt_set)
println("using fixed dt from discontinuity handling")
integrator.disco_dt_set = false
return integrator.dt
end
return integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ end
end
end

function u_modified!(integrator::ODEIntegrator, bool::Bool)
function SciMLBase.u_modified!(integrator::ODEIntegrator, bool::Bool)
return integrator.u_modified = bool
end

Expand Down
3 changes: 3 additions & 0 deletions lib/OrdinaryDiffEqCore/src/integrators/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ mutable struct ODEIntegrator{
dtcache::tType
dtchangeable::Bool
dtpropose::tType
disco_dt_set::Bool
tdir::tdirType
eigen_est::eigenType
EEst::EEstT
Expand Down Expand Up @@ -191,6 +192,8 @@ mutable struct ODEIntegrator{
fsalfirst::FSALType
fsallast::FSALType
rng::RNGType
#disco_prob::IntervalNonlinearProblem
disco_probs::Vector{IntervalNonlinearProblem}
W::WType
P::PType
sqdt::SqdtType
Expand Down
52 changes: 49 additions & 3 deletions lib/OrdinaryDiffEqCore/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,32 @@ determine_controller_datatype(u, internalnorm, ts::Tuple{<:Number, <:Number}) =
determine_controller_datatype(u::AbstractVector{<:Number}, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(DiffEqBase.value(internalnorm(u, ts[1]))), typeof(DiffEqBase.value(internalnorm(u, ts[2]))), eltype(float.(DiffEqBase.value(ts))))
determine_controller_datatype(u, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(float(DiffEqBase.value(ts[1]))), typeof(float(DiffEqBase.value(ts[2])))) # This seems to be an assumption implicitly taken somewhere

mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsType, callbackType, outType}
#integrator_ref::IntegratorType
u₁::uType
callback::callbackType
dt::tType
uprev::uType
u::uType
k::kType
cache::CacheType
idxs::idxsType
differential_vars::varsType
ind::Int
out::outType
end

function (z::zero_func_struct)(θ, p)
ode_interpolant!(z.u₁, θ, z.dt, z.uprev, z.u, z.k, z.cache, z.idxs, Val{0}, z.differential_vars)
return zero_condition(z.callback, z.out, z.u₁, z.dt + θ * z.dt, z, z.ind)
end

@inline zero_condition(cb::ContinuousCallback, out::Nothing, u, t, z, ind) = cb.condition(u, t, z)
@inline function zero_condition(cb::VectorContinuousCallback, out, u, t, z, ind)
cb.condition(out, u, t, z)
return out[ind]
end

function SciMLBase.__init(
prob::Union{
SciMLBase.AbstractODEProblem,
Expand Down Expand Up @@ -57,6 +83,7 @@ function _ode_init(
save_everystep = isempty(saveat),
save_on = true,
save_discretes = true,
disco_dt_set = false,
save_start = save_everystep || isempty(saveat) ||
saveat isa Number || prob.tspan[1] in saveat,
save_end = nothing,
Expand Down Expand Up @@ -105,6 +132,7 @@ function _ode_init(
alias = ODEAliasSpecifier(),
initializealg = DefaultInit(),
rng = nothing,
disco_probs = nothing,
# SDE/RODE fields: accepted here so that SDE packages can delegate to
# _ode_init and construct an ODEIntegrator with noise populated.
save_noise = false,
Expand Down Expand Up @@ -733,6 +761,24 @@ function _ode_init(

_rng = rng === nothing ? Random.default_rng() : rng

num_probs = 0
for i in callbacks_internal.continuous_callbacks
if i.is_discontinuity
num_probs += 1
end
end
disco_probs = Vector{IntervalNonlinearProblem}(undef, num_probs)
idx = 1
for i in callbacks_internal.continuous_callbacks
if i.is_discontinuity
u₁ = similar(u)
out = i isa VectorContinuousCallback ? similar(u) : nothing
zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars, 1, out)
disco_probs[idx] = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p)
idx += 1
end
end

integrator = ODEIntegrator{
typeof(_alg), isinplace(prob), uType, typeof(du),
tType, typeof(p),
Expand All @@ -744,12 +790,12 @@ function _ode_init(
typeof(initializealg), typeof(differential_vars),
typeof(controller_cache), typeof(_rng),
typeof(W), typeof(P), typeof(sqdt),
typeof(noise), typeof(c), typeof(rate_constants),
typeof(noise), typeof(c), typeof(rate_constants)
}(
sol, u, du, k, t, tType(_dt), f, p,
uprev, uprev2, duprev, tprev,
_alg, dtcache, dtchangeable,
dtpropose, tdir, eigen_est, EEst,
dtpropose, disco_dt_set, tdir, eigen_est, EEst,
# TODO vvv remove these
QT(qoldinit), q11,
erracc, dtacc,
Expand All @@ -767,7 +813,7 @@ function _ode_init(
isout, reeval_fsal,
u_modified, reinitialize, isdae,
opts, stats, initializealg, differential_vars,
fsalfirst, fsallast, _rng,
fsalfirst, fsallast, _rng, disco_probs,
W, P, sqdt,
noise, c, rate_constants, QT(1)
)
Expand Down
Loading