Skip to content

feat: more robust inputs/outputs handling #3795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Symbolics: get_variables
Return all variables that mare marked as inputs. See also [`unbound_inputs`](@ref)
See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref)
"""
inputs(sys) = [filter(isinput, unknowns(sys)); filter(isinput, parameters(sys))]
inputs(sys) = collect(get_inputs(sys))

"""
outputs(sys)
Expand All @@ -14,13 +14,7 @@ Return all variables that mare marked as outputs. See also [`unbound_outputs`](@
See also [`bound_outputs`](@ref), [`unbound_outputs`](@ref)
"""
function outputs(sys)
o = observed(sys)
rhss = [eq.rhs for eq in o]
lhss = [eq.lhs for eq in o]
unique([filter(isoutput, unknowns(sys))
filter(isoutput, parameters(sys))
filter(x -> iscall(x) && isoutput(x), rhss) # observed can return equations with complicated expressions, we are only looking for single Terms
filter(x -> iscall(x) && isoutput(x), lhss)])
return collect(get_outputs(sys))
end

"""
Expand Down Expand Up @@ -316,6 +310,8 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
ps = parameters(sys)

@set! sys.ps = [ps; new_parameters]
@set! sys.inputs = Set{BasicSymbolic}(filter(isinput, fullvars))
@set! sys.outputs = Set{BasicSymbolic}(filter(isoutput, fullvars))
@set! state.sys = sys
@set! state.fullvars = Vector{BasicSymbolic}(new_fullvars)
@set! state.structure = structure
Expand Down
2 changes: 2 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,8 @@ const SYS_PROPS = [:eqs
:parent
:is_dde
:tstops
:inputs
:outputs
:index_cache
:isscheduled
:costs
Expand Down
37 changes: 32 additions & 5 deletions src/systems/system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,16 @@ struct System <: AbstractSystem
"""
tstops::Vector{Any}
"""
$INTERNAL_FIELD_WARNING
The list of input variables of the system.
"""
inputs::Set{BasicSymbolic}
"""
$INTERNAL_FIELD_WARNING
The list of output variables of the system.
"""
outputs::Set{BasicSymbolic}
"""
The `TearingState` of the system post-simplification with `mtkcompile`.
"""
tearing_state::Any
Expand Down Expand Up @@ -255,8 +265,9 @@ struct System <: AbstractSystem
brownians, iv, observed, parameter_dependencies, var_to_name, name, description,
defaults, guesses, systems, initialization_eqs, continuous_events, discrete_events,
connector_type, assertions = Dict{BasicSymbolic, String}(),
metadata = MetadataT(), gui_metadata = nothing,
is_dde = false, tstops = [], tearing_state = nothing, namespacing = true,
metadata = MetadataT(), gui_metadata = nothing, is_dde = false, tstops = [],
inputs = Set{BasicSymbolic}(), outputs = Set{BasicSymbolic}(),
tearing_state = nothing, namespacing = true,
complete = false, index_cache = nothing, ignored_connections = nothing,
preface = nothing, parent = nothing, initializesystem = nothing,
is_initializesystem = false, is_discrete = false, isscheduled = false,
Expand Down Expand Up @@ -296,7 +307,8 @@ struct System <: AbstractSystem
observed, parameter_dependencies, var_to_name, name, description, defaults,
guesses, systems, initialization_eqs, continuous_events, discrete_events,
connector_type, assertions, metadata, gui_metadata, is_dde,
tstops, tearing_state, namespacing, complete, index_cache, ignored_connections,
tstops, inputs, outputs, tearing_state, namespacing,
complete, index_cache, ignored_connections,
preface, parent, initializesystem, is_initializesystem, is_discrete,
isscheduled, schedule)
end
Expand Down Expand Up @@ -367,15 +379,27 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];

defaults = anydict(defaults)
guesses = anydict(guesses)
inputs = Set{BasicSymbolic}()
outputs = Set{BasicSymbolic}()
var_to_name = anydict()

let defaults = discover_from_metadata ? defaults : Dict(),
guesses = discover_from_metadata ? guesses : Dict()
guesses = discover_from_metadata ? guesses : Dict(),
inputs = discover_from_metadata ? inputs : Set(),
outputs = discover_from_metadata ? outputs : Set()

process_variables!(var_to_name, defaults, guesses, dvs)
process_variables!(var_to_name, defaults, guesses, ps)
process_variables!(var_to_name, defaults, guesses, [eq.lhs for eq in observed])
process_variables!(var_to_name, defaults, guesses, [eq.rhs for eq in observed])

for var in dvs
if isinput(var)
push!(inputs, var)
elseif isoutput(var)
push!(outputs, var)
end
end
end
filter!(!(isnothing ∘ last), defaults)
filter!(!(isnothing ∘ last), guesses)
Expand Down Expand Up @@ -416,7 +440,8 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
costs, consolidate, dvs, ps, brownians, iv, observed, Equation[],
var_to_name, name, description, defaults, guesses, systems, initialization_eqs,
continuous_events, discrete_events, connector_type, assertions, metadata, gui_metadata, is_dde,
tstops, tearing_state, true, false, nothing, ignored_connections, preface, parent,
tstops, inputs, outputs, tearing_state, true, false,
nothing, ignored_connections, preface, parent,
initializesystem, is_initializesystem, is_discrete; checks)
end

Expand Down Expand Up @@ -1141,6 +1166,8 @@ function Base.isapprox(sysa::System, sysb::System)
isequal(get_metadata(sysa), get_metadata(sysb)) &&
isequal(get_is_dde(sysa), get_is_dde(sysb)) &&
issetequal(get_tstops(sysa), get_tstops(sysb)) &&
issetequal(get_inputs(sysa), get_inputs(sysb)) &&
issetequal(get_outputs(sysa), get_outputs(sysb)) &&
safe_issetequal(get_ignored_connections(sysa), get_ignored_connections(sysb)) &&
isequal(get_is_initializesystem(sysa), get_is_initializesystem(sysb)) &&
isequal(get_is_discrete(sysa), get_is_discrete(sysb)) &&
Expand Down
17 changes: 17 additions & 0 deletions test/input_output_handling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -464,3 +464,20 @@ end
x = [1.0]
@test_nowarn f[1](x, u, p, 0.0)
end

@testset "Observed inputs and outputs" begin
@variables x(t) y(t) [input = true] z(t) [output = true]
eqs = [D(x) ~ x + y + z
y ~ z]
@named sys = System(eqs, t)
@test issetequal(ModelingToolkit.inputs(sys), [y])
@test issetequal(ModelingToolkit.outputs(sys), [z])

ss1 = mtkcompile(sys, inputs = [y], outputs = [z])
@test issetequal(ModelingToolkit.inputs(ss1), [y])
@test issetequal(ModelingToolkit.outputs(ss1), [z])

ss2 = mtkcompile(sys, inputs = [z], outputs = [y])
@test issetequal(ModelingToolkit.inputs(ss2), [z])
@test issetequal(ModelingToolkit.outputs(ss2), [y])
end
Loading