Skip to content

Commit b3183de

Browse files
feat: add pass which removes equations incident on a single variable
1 parent 9797c14 commit b3183de

3 files changed

Lines changed: 96 additions & 0 deletions

File tree

src/systems/alias_elimination.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,91 @@ function find_perfect_aliases!(
148148
return aliases
149149
end
150150

151+
"""
152+
$TYPEDSIGNATURES
153+
154+
Analytically remove any equations in `state` which are only incident on a single variable.
155+
"""
156+
function remove_constant_variables!(state::TearingState; allow_parameter::Bool = true, kwargs...)
157+
StateSelection.complete!(state.structure)
158+
(; additional_observed, original_eqs, fullvars, structure, sys) = state
159+
(; graph, var_to_diff) = structure
160+
eqs = collect(equations(state))
161+
eqs_to_rm = Int[]
162+
vars_to_rm = Int[]
163+
vars_to_rm_set = BitSet()
164+
eqs_to_rm_set = BitSet()
165+
fullvars_set = Set{SymbolicT}(fullvars)
166+
subs = Dict{SymbolicT, SymbolicT}()
167+
eqs_to_substitute = Set{Int}()
168+
param_der_subber = SU.Substituter{false}(state.param_derivative_map)
169+
170+
snbors = Int[]
171+
for _ in 1:4
172+
removed_eq = false
173+
for ieq in 𝑠vertices(graph)
174+
empty!(snbors)
175+
append!(snbors, 𝑠neighbors(graph, ieq))
176+
setdiff!(snbors, vars_to_rm_set)
177+
length(snbors) == 1 || continue
178+
ivar = first(snbors)
179+
eq = eqs[ieq]
180+
var = fullvars[ivar]
181+
lex = Symbolics.LinearExpander(var; strict = true)
182+
a, b, islin = lex(eq.rhs)
183+
islin || continue
184+
# `allow_symbolic = true` since we know this equation (directly or indirectly) only
185+
# depends on `var`. Any variables present in it can only be ones we've already
186+
# eliminated in `vars_to_rm`.
187+
if !MTKTearing._check_allow_symbolic_parameter(
188+
state, a, true, allow_parameter; fullvars_set
189+
)
190+
continue
191+
end
192+
removed_eq = true
193+
push!(eqs_to_rm, ieq)
194+
push!(eqs_to_rm_set, ieq)
195+
push!(vars_to_rm, ivar)
196+
push!(vars_to_rm_set, ivar)
197+
# `a` typically is faster to negate, since it usually is a constant or small expression
198+
rhs = b / -a
199+
push!(additional_observed, var ~ rhs)
200+
subs[var] = rhs
201+
union!(eqs_to_substitute, 𝑑neighbors(graph, ivar))
202+
203+
v = var_to_diff[ivar]
204+
while v isa Int
205+
# We're only looking at equations incident on a single variable, so `rhs` will only ever
206+
# involve parameters. The derivative is thus going to be zero.
207+
rhs = param_der_subber(Symbolics.derivative(rhs, get_iv(sys)::SymbolicT; throw_no_derivative = true))
208+
subs[fullvars[v]] = rhs
209+
push!(additional_observed, default_toterm(fullvars[v]) ~ rhs)
210+
union!(eqs_to_substitute, 𝑑neighbors(graph, v))
211+
push!(vars_to_rm, v)
212+
push!(vars_to_rm_set, v)
213+
v = var_to_diff[v]
214+
end
215+
end
216+
217+
removed_eq || break
218+
end
219+
220+
subber = SU.Substituter{false}(subs)
221+
for ieq in eqs_to_substitute
222+
ieq in eqs_to_rm_set && continue
223+
eqs[ieq] = subber(eqs[ieq])
224+
original_eqs[ieq] = subber(original_eqs[ieq])
225+
end
226+
227+
@set! sys.eqs = eqs
228+
state.sys = sys
229+
old_to_new_eq, old_to_new_var = StateSelection.rm_eqs_vars!(
230+
state, eqs_to_rm, vars_to_rm
231+
)
232+
233+
return length(eqs_to_rm)
234+
end
235+
151236
function alias_elimination!(state::TearingState; fully_determined = true,
152237
print_underconstrained_variables = false, kwargs...)
153238
StateSelection.complete!(state.structure)

src/systems/systemstructure.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ function _mtkcompile!(
239239
state = ModelingToolkit.inputs_to_parameters!(state, discrete_inputs, OrderedSet{SymbolicT}())
240240
state = ModelingToolkit.inputs_to_parameters!(state, inputs, outputs)
241241
eliminate_perfect_aliases!(state)
242+
remove_constant_variables!(state; kwargs...)
242243
StateSelection.trivial_tearing!(state)
243244
sys, mm = ModelingToolkit.alias_elimination!(state; fully_determined, kwargs...)
244245
if check_consistency

test/reduction.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,3 +353,13 @@ ss = mtkcompile(sys)
353353
@mtkcompile sys = System([D(x) ~ 2x, y ~ x], t; state_priorities = [y => 10])
354354
@test isequal(only(unknowns(sys)), y)
355355
end
356+
357+
@testset "Constant equations are removed" begin
358+
@variables x(t) y(t) z(t)
359+
@named sys = System([0 ~ 2x + 3t + 4, 0 ~ x * y + 2, 0 ~ D(x) + D(z) + 2z], t)
360+
ts = TearingState(sys)
361+
ModelingToolkit.remove_constant_variables!(ts)
362+
dx = ModelingToolkit.default_toterm(unwrap(D(x)))
363+
@test isequal(ts.additional_observed, [x ~ (3t + 4) / -2, dx ~ (-3//2), y ~ 2 / (-x)])
364+
@test isequal(equations(ts), [0 ~ D(z) + 2z - 3/2])
365+
end

0 commit comments

Comments
 (0)