Skip to content

Commit f9a6be1

Browse files
feat: [AI] also support passing a vector of time points
1 parent 2e6a489 commit f9a6be1

2 files changed

Lines changed: 83 additions & 5 deletions

File tree

lib/ModelingToolkitBase/test/analysis_points.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,35 @@ if @isdefined(ModelingToolkit)
220220
@test matrices_op.B matrices_ref.B
221221
@test matrices_op.C matrices_ref.C
222222
@test matrices_op.D matrices_ref.D
223+
224+
# Vector of time points: linearization_function is built once and reused.
225+
ts = [0.0, 0.5, 1.0]
226+
mats_vec, _, extras_vec = linearize(
227+
sys, sys.plant_input, sys.plant_output;
228+
op = ModelingToolkit.LinearizationOpPoint(sol, ts)
229+
)
230+
@test length(mats_vec) == 3
231+
@test length(extras_vec) == 3
232+
# The system is linear so all operating points yield the same A,B,C,D.
233+
for mats_t in mats_vec
234+
@test mats_t.A matrices_ref.A
235+
@test mats_t.B matrices_ref.B
236+
@test mats_t.C matrices_ref.C
237+
@test mats_t.D matrices_ref.D
238+
end
239+
# Two-arg form: linearize(ssys, lin_fun; op=LinearizationOpPoint(sol, ts))
240+
lin_fun, ssys_lin = linearization_function(sys, sys.plant_input, sys.plant_output)
241+
mats_vec2, extras_vec2 = linearize(
242+
ssys_lin, lin_fun;
243+
op = ModelingToolkit.LinearizationOpPoint(sol, ts)
244+
)
245+
@test length(mats_vec2) == 3
246+
for (m1, m2) in zip(mats_vec, mats_vec2)
247+
@test m1.A m2.A
248+
@test m1.B m2.B
249+
@test m1.C m2.C
250+
@test m1.D m2.D
251+
end
223252
end
224253

225254
@testset "Complicated model" begin

src/linearization.jl

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
"""
22
$(TYPEDEF)
33
4-
Wraps an `ODESolution` and a time point `t`. When passed as `op` to [`linearize`](@ref),
5-
an operating point is constructed from the values of differential state variables and
6-
parameters of `sol` evaluated at time `t`. Algebraic variables are not set and will be
7-
determined by the initialization algorithm.
4+
Wraps an `ODESolution` and a time point (or vector of time points) `t`. When passed as
5+
`op` to [`linearize`](@ref), an operating point is constructed from the values of
6+
differential state variables and parameters of `sol` evaluated at `t`. Algebraic
7+
variables are not set and will be determined by the initialization algorithm.
8+
9+
When `t` is an `AbstractVector`, [`linearize`](@ref) calls [`linearization_function`](@ref)
10+
once and evaluates the linearization at each time point, returning vectors of matrices and
11+
extras.
812
913
# Fields
1014
@@ -16,7 +20,7 @@ struct LinearizationOpPoint{S <: SciMLBase.AbstractODESolution, T}
1620
"""
1721
sol::S
1822
"""
19-
The time point at which to evaluate the solution.
23+
The time point (or vector of time points) at which to evaluate the solution.
2024
"""
2125
t::T
2226
end
@@ -37,6 +41,29 @@ function _build_op_from_solution(op::LinearizationOpPoint)
3741
return result
3842
end
3943

44+
function _build_op_from_solution(op::LinearizationOpPoint{S, <:AbstractVector}) where {S}
45+
sol_sys = MTKBase.indp_to_system(op.sol)
46+
eqs = equations(sol_sys)
47+
sts = unknowns(sol_sys)
48+
# Find differential equation indices and extract parameters once — both are
49+
# time-independent, so we only do this work once regardless of how many time
50+
# points are requested.
51+
diff_idxs = findall(isdiffeq, eqs)
52+
param_vals = Dict{SymbolicT, SymbolicT}()
53+
for p in parameters(sol_sys)
54+
param_vals[p] = getp(op.sol, p)(op.sol)
55+
end
56+
# Interpolate once per time point to build the per-point operating-point dict.
57+
return map(op.t) do ti
58+
u = op.sol(ti)
59+
result = copy(param_vals)
60+
for i in diff_idxs
61+
result[sts[i]] = u[i]
62+
end
63+
result
64+
end
65+
end
66+
4067
"""
4168
lin_fun, simplified_sys = linearization_function(sys::AbstractSystem, inputs, outputs; simplify = false, initialize = true, initialization_solver_alg = nothing, kwargs...)
4269
@@ -809,6 +836,13 @@ function linearize(
809836
op = Dict(), allow_input_derivatives = false,
810837
p = DiffEqBase.NullParameters()
811838
)
839+
if op isa LinearizationOpPoint && op.t isa AbstractVector
840+
ops = _build_op_from_solution(op)
841+
results = map(zip(ops, op.t)) do (op_i, ti)
842+
linearize(sys, lin_fun; t = ti, op = op_i, allow_input_derivatives, p)
843+
end
844+
return first.(results), last.(results)
845+
end
812846
if op isa LinearizationOpPoint
813847
t = op.t
814848
op = _build_op_from_solution(op)
@@ -840,6 +874,21 @@ function linearize(
840874
zero_dummy_der = false,
841875
kwargs...
842876
)
877+
if op isa LinearizationOpPoint && op.t isa AbstractVector
878+
ops = _build_op_from_solution(op)
879+
ts = op.t
880+
# Build the linearization function once using the first operating point, then
881+
# reuse it for all subsequent time points — this avoids redundant `mtkcompile`
882+
# and Jacobian preparation work.
883+
lin_fun, ssys = linearization_function(
884+
sys, inputs, outputs;
885+
zero_dummy_der, op = ops[1], t = ts[1], kwargs...
886+
)
887+
results = map(zip(ops, ts)) do (op_i, ti)
888+
linearize(ssys, lin_fun; op = op_i, t = ti, allow_input_derivatives)
889+
end
890+
return first.(results), ssys, last.(results)
891+
end
843892
if op isa LinearizationOpPoint
844893
t = op.t
845894
op = _build_op_from_solution(op)

0 commit comments

Comments
 (0)