diff --git a/docs/src/api.md b/docs/src/api.md index 8e5c64886..9a90ec5cb 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -331,7 +331,7 @@ get_num_produce set_num_produce!! increment_num_produce!! reset_num_produce!! -setorder! +setorder!! set_retained_vns_del! ``` @@ -358,7 +358,7 @@ DynamicPPL provides the following default accumulators. ```@docs LogPriorAccumulator LogLikelihoodAccumulator -NumProduceAccumulator +VariableOrderAccumulator ``` ### Common API diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 7527c8be2..d210d34e8 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -50,7 +50,7 @@ export AbstractVarInfo, AbstractAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, - NumProduceAccumulator, + VariableOrderAccumulator, push!!, empty!!, subset, @@ -73,7 +73,7 @@ export AbstractVarInfo, is_flagged, set_flag!, unset_flag!, - setorder!, + setorder!!, istrans, link, link!!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 4917a4892..eb75662ca 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -374,6 +374,24 @@ function resetlogp!!(vi::AbstractVarInfo) return vi end +""" + setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer) + +Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe +statements run before sampling `vn`. +""" +function setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer) + return map_accumulator!!(acc -> (acc.order[vn] = index; acc), vi, Val(:VariableOrder)) +end + +""" + getorder(vi::VarInfo, vn::VarName) + +Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements +run before sampling `vn`. +""" +getorder(vi::AbstractVarInfo, vn::VarName) = getacc(vi, Val(:VariableOrder)).order[vn] + # Variables and their realizations. @doc """ keys(vi::AbstractVarInfo) @@ -972,14 +990,22 @@ end Return the `num_produce` of `vi`. """ -get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:NumProduce)).num +get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:VariableOrder)).num_produce """ set_num_produce!!(vi::AbstractVarInfo, n::Int) Set the `num_produce` field of `vi` to `n`. """ -set_num_produce!!(vi::AbstractVarInfo, n::Int) = setacc!!(vi, NumProduceAccumulator(n)) +function set_num_produce!!(vi::AbstractVarInfo, n::Integer) + if hasacc(vi, Val(:VariableOrder)) + acc = getacc(vi, Val(:VariableOrder)) + acc = VariableOrderAccumulator(n, acc.order) + else + acc = VariableOrderAccumulator(n) + end + return setacc!!(vi, acc) +end """ increment_num_produce!!(vi::AbstractVarInfo) @@ -987,14 +1013,14 @@ set_num_produce!!(vi::AbstractVarInfo, n::Int) = setacc!!(vi, NumProduceAccumula Add 1 to `num_produce` in `vi`. """ increment_num_produce!!(vi::AbstractVarInfo) = - map_accumulator!!(increment, vi, Val(:NumProduce)) + map_accumulator!!(increment, vi, Val(:VariableOrder)) """ reset_num_produce!!(vi::AbstractVarInfo) Reset the value of `num_produce` in `vi` to 0. 
""" -reset_num_produce!!(vi::AbstractVarInfo) = map_accumulator!!(zero, vi, Val(:NumProduce)) +reset_num_produce!!(vi::AbstractVarInfo) = set_num_produce!!(vi, zero(get_num_produce(vi))) """ from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) diff --git a/src/accumulators.jl b/src/accumulators.jl index 10a988ae5..0b5fb09ec 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -13,6 +13,7 @@ An accumulator type `T <: AbstractAccumulator` must implement the following meth - `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` - `accumulate_observe!!(acc::T, right, left, vn)` - `accumulate_assume!!(acc::T, val, logjac, vn, right)` +- `Base.copy(acc::T)` To be able to work with multi-threading, it should also implement: - `split(acc::T)` @@ -136,6 +137,9 @@ function Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname} @inline return haskey(at.nt, accname) end Base.keys(at::AccumulatorTuple) = keys(at.nt) +Base.:(==)(at1::AccumulatorTuple, at2::AccumulatorTuple) = at1.nt == at2.nt +Base.hash(at::AccumulatorTuple, h::UInt) = Base.hash((AccumulatorTuple, at.nt), h) +Base.copy(at::AccumulatorTuple) = AccumulatorTuple(map(copy, at.nt)) function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T} return AccumulatorTuple(convert(T, accs.nt)) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b92e49fba..5aa55c3dc 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -184,7 +184,6 @@ function assume( f = to_maybe_linked_internal_transform(vi, vn, dist) # TODO(mhauru) This should probably be call a function called setindex_internal! vi = BangBang.setindex!!(vi, f(r), vn) - setorder!(vi, vn, get_num_produce(vi)) else # Otherwise we just extract it. r = vi[vn, dist] diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl index ab538ba51..e5f3f6163 100644 --- a/src/default_accumulators.jl +++ b/src/default_accumulators.jl @@ -41,25 +41,40 @@ LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T) LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}() """ - NumProduceAccumulator{T} <: AbstractAccumulator + VariableOrderAccumulator{T} <: AbstractAccumulator -An accumulator that tracks the number of observations during model execution. +An accumulator that tracks the order of variables in a `VarInfo`. + +This doesn't track the full ordering, but rather how many observations have taken place +before the assume statement for each variable. This is needed for particle methods, where +the model is segmented into parts by each observation, and we need to know which part each +assume statement is in. # Fields $(TYPEDFIELDS) """ -struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator +struct VariableOrderAccumulator{Eltype<:Integer,VNType<:VarName} <: AbstractAccumulator "the number of observations" - num::T + num_produce::Eltype + "mapping of variable names to their order in the model" + order::Dict{VNType,Eltype} end """ - NumProduceAccumulator{T<:Integer}() + VariableOrderAccumulator{T<:Integer}(n=zero(T)) -Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero. 
+Create a new `VariableOrderAccumulator` accumulator with the number of observations set to n """ -NumProduceAccumulator{T}() where {T<:Integer} = NumProduceAccumulator(zero(T)) -NumProduceAccumulator() = NumProduceAccumulator{Int}() +VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} = + VariableOrderAccumulator(convert(T, n), Dict{VarName,T}()) +VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n) +VariableOrderAccumulator() = VariableOrderAccumulator{Int}() + +Base.copy(acc::LogPriorAccumulator) = acc +Base.copy(acc::LogLikelihoodAccumulator) = acc +function Base.copy(acc::VariableOrderAccumulator) + return VariableOrderAccumulator(acc.num_produce, copy(acc.order)) +end function Base.show(io::IO, acc::LogPriorAccumulator) return print(io, "LogPriorAccumulator($(repr(acc.logp)))") @@ -67,17 +82,48 @@ end function Base.show(io::IO, acc::LogLikelihoodAccumulator) return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))") end -function Base.show(io::IO, acc::NumProduceAccumulator) - return print(io, "NumProduceAccumulator($(repr(acc.num)))") +function Base.show(io::IO, acc::VariableOrderAccumulator) + return print( + io, "VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))" + ) +end + +# Note that == and isequal are different, and equality under the latter should imply +# equality of hashes. Both of the below implementations are also different from the default +# implementation for structs. +Base.:(==)(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) = acc1.logp == acc2.logp +function Base.:(==)(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) + return acc1.logp == acc2.logp +end +function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + return acc1.num_produce == acc2.num_produce && acc1.order == acc2.order +end + +function Base.isequal(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) + return isequal(acc1.logp, acc2.logp) +end +function Base.isequal(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) + return isequal(acc1.logp, acc2.logp) +end +function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + return isequal(acc1.num_produce, acc2.num_produce) && isequal(acc1.order, acc2.order) +end + +Base.hash(acc::LogPriorAccumulator, h::UInt) = hash((LogPriorAccumulator, acc.logp), h) +function Base.hash(acc::LogLikelihoodAccumulator, h::UInt) + return hash((LogLikelihoodAccumulator, acc.logp), h) +end +function Base.hash(acc::VariableOrderAccumulator, h::UInt) + return hash((VariableOrderAccumulator, acc.num_produce, acc.order), h) end accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood -accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce +accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T)) split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T)) -split(acc::NumProduceAccumulator) = acc +split(acc::VariableOrderAccumulator) = acc function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator) return LogPriorAccumulator(acc.logp + acc2.logp) @@ -85,8 +131,12 @@ end function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return LogLikelihoodAccumulator(acc.logp + acc2.logp) end -function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator) - return NumProduceAccumulator(max(acc.num, acc2.num)) +function 
combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + # Note that assumptions are not allowed within in parallelised blocks, and thus the + # dictionaries should be identical. + return VariableOrderAccumulator( + max(acc.num_produce, acc2.num_produce), merge(acc.order, acc2.order) + ) end function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) @@ -95,11 +145,12 @@ end function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return LogLikelihoodAccumulator(acc1.logp + acc2.logp) end -increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num)) +function increment(acc::VariableOrderAccumulator) + return VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order) +end Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp)) Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp)) -Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num)) function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) return acc + LogPriorAccumulator(logpdf(right, val) + logjac) @@ -114,8 +165,11 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left)) end -accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc -accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc) +function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right) + acc.order[vn] = acc.num_produce + return acc +end +accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment(acc) function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T} return LogPriorAccumulator(convert(T, acc.logp)) @@ -126,15 +180,19 @@ function Base.convert( return LogLikelihoodAccumulator(convert(T, acc.logp)) end function Base.convert( - ::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator -) where {T} - return NumProduceAccumulator(convert(T, acc.num)) + ::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator +) where {ElType,VnType} + order = Dict{VnType,ElType}() + for (k, v) in acc.order + order[convert(VnType, k)] = convert(ElType, v) + end + return VariableOrderAccumulator(convert(ElType, acc.num_produce), order) end # TODO(mhauru) -# We ignore the convert_eltype calls for NumProduceAccumulator, by letting them fallback on +# We ignore the convert_eltype calls for VariableOrderAccumulator, by letting them fallback on # convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to -# deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator. This is +# deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is # horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. 
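
For context on the accumulator methods added above, here is a minimal, illustrative sketch of how `VariableOrderAccumulator` responds to the assume/observe hooks. The variable names, values, and distributions are hypothetical, and real model evaluation reaches these functions through the tilde pipeline rather than by direct calls:

```julia
using DynamicPPL, Distributions
using DynamicPPL: VariableOrderAccumulator, accumulate_assume!!, accumulate_observe!!

acc = VariableOrderAccumulator()            # num_produce == 0, empty order map
vn_x, vn_y = @varname(x), @varname(y)

# An assume statement records the current number of observations as the order of the variable.
acc = accumulate_assume!!(acc, 1.0, 0.0, vn_x, Normal())
acc.order[vn_x]  # 0: no observations had run before x was assumed

# An observe statement increments num_produce, so the next assume gets order 1.
acc = accumulate_observe!!(acc, Normal(), 2.0, nothing)
acc = accumulate_assume!!(acc, 0.5, 0.0, vn_y, Normal())
acc.order[vn_y]  # 1
```

Because `order` is a `Dict` keyed by `VarName`, the `Base.copy` method defined earlier in this hunk duplicates the dictionary, so copied accumulators do not share mutable order state.
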
function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T} return LogPriorAccumulator(convert(T, acc.logp)) @@ -149,6 +207,6 @@ function default_accumulators( return AccumulatorTuple( LogPriorAccumulator{FloatT}(), LogLikelihoodAccumulator{FloatT}(), - NumProduceAccumulator{IntT}(), + VariableOrderAccumulator{IntT}(), ) end diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 9047c9f0a..5f7e0dd52 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -4,6 +4,10 @@ end PriorDistributionAccumulator() = PriorDistributionAccumulator(OrderedDict{VarName,Any}()) +function Base.copy(acc::PriorDistributionAccumulator) + return PriorDistributionAccumulator(copy(acc.priors)) +end + accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors)) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index b6b97c8f9..47b882b46 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -31,6 +31,10 @@ function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps) end +function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps)) +end + function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp) logps = acc.logps # The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys. diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 42fcedfb8..d33eaa5e8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -125,7 +125,7 @@ Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) +Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) Positive probability mass on negative numbers! getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) @@ -133,7 +133,7 @@ julia> # (✓) Positive probability mass on negative numbers! julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) +SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) No probability mass on negative numbers! getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) @@ -198,6 +198,12 @@ struct SimpleVarInfo{NT,Accs<:AccumulatorTuple where {N},C<:AbstractTransformati transformation::C end +function Base.:(==)(vi1::SimpleVarInfo, vi2::SimpleVarInfo) + return vi1.values == vi2.values && + vi1.accs == vi2.accs && + vi1.transformation == vi2.transformation +end + transformation(vi::SimpleVarInfo) = vi.transformation function SimpleVarInfo(values, accs) @@ -242,7 +248,7 @@ end # Constructor from `VarInfo`. 
function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} values = values_as(vi, D) - return SimpleVarInfo(values, deepcopy(getaccs(vi))) + return SimpleVarInfo(values, copy(getaccs(vi))) end function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} values = values_as(vi, D) @@ -441,7 +447,7 @@ _subset(x::VarNamedVector, vns) = subset(x, vns) # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) - accs = deepcopy(getaccs(varinfo_right)) + accs = copy(getaccs(varinfo_right)) transformation = merge_transformations( varinfo_left.transformation, varinfo_right.transformation ) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 7d2d768a6..83f26bcf3 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -78,7 +78,9 @@ end syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) -setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) +function setorder!!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) + return ThreadSafeVarInfo(setorder!!(vi.varinfo, vn, index), vi.accs_by_thread) +end setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 4d6225c10..964474a09 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -20,6 +20,10 @@ function ValuesAsInModelAccumulator(include_colon_eq) return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq) end +function Base.copy(acc::ValuesAsInModelAccumulator) + return ValuesAsInModelAccumulator(copy(acc.values), acc.include_colon_eq) +end + accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel function split(acc::ValuesAsInModelAccumulator) diff --git a/src/varinfo.jl b/src/varinfo.jl index 6a968da4d..f9dd9da8d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -15,10 +15,9 @@ not. Let `md` be an instance of `Metadata`: - `md.vns` is the vector of all `VarName` instances. - `md.idcs` is the dictionary that maps each `VarName` instance to its index in - `md.vns`, `md.ranges` `md.dists`, `md.orders` and `md.flags`. + `md.vns`, `md.ranges` `md.dists`, and `md.flags`. - `md.vns[md.idcs[vn]] == vn`. - `md.dists[md.idcs[vn]]` is the distribution of `vn`. -- `md.orders[md.idcs[vn]]` is the number of `observe` statements before `vn` is sampled. - `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`. - `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`. - `md.flags` is a dictionary of true/false flags. 
`md.flags[flag][md.idcs[vn]]` is the @@ -57,13 +56,21 @@ struct Metadata{ # Vector of distributions correpsonding to `vns` dists::TDists # AbstractVector{<:Distribution} - # Number of `observe` statements before each random variable is sampled - orders::Vector{Int} - # Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresonding to `vns[i]` flags::Dict{String,BitVector} end +function Base.:(==)(md1::Metadata, md2::Metadata) + return ( + md1.idcs == md2.idcs && + md1.vns == md2.vns && + md1.ranges == md2.ranges && + md1.vals == md2.vals && + md1.dists == md2.dists && + md1.flags == md2.flags + ) +end + ########### # VarInfo # ########### @@ -159,6 +166,10 @@ const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } +function Base.:(==)(vi1::VarInfo, vi2::VarInfo) + return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs) +end + # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` # multiple times. @@ -262,8 +273,6 @@ function typed_varinfo(vi::UntypedVarInfo) sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) # New dists sym_dists = getindex.((meta.dists,), inds) - # New orders - sym_orders = getindex.((meta.orders,), inds) # New flags sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) @@ -281,13 +290,11 @@ function typed_varinfo(vi::UntypedVarInfo) push!( new_metas, - Metadata( - sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_orders, sym_flags - ), + Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_flags), ) end nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, deepcopy(vi.accs)) + return VarInfo(nt, copy(vi.accs)) end function typed_varinfo(vi::NTVarInfo) # This function preserves the behaviour of typed_varinfo(vi) where vi is @@ -348,7 +355,7 @@ single `VarNamedVector` as its metadata field. """ function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) - return VarInfo(md, deepcopy(vi.accs)) + return VarInfo(md, copy(vi.accs)) end function untyped_vector_varinfo( rng::Random.AbstractRNG, @@ -391,12 +398,12 @@ NamedTuple of `VarNamedVector`s as its metadata field. """ function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) - return VarInfo(md, deepcopy(vi.accs)) + return VarInfo(md, copy(vi.accs)) end function typed_vector_varinfo(vi::UntypedVectorVarInfo) new_metas = group_by_symbol(vi.metadata) nt = NamedTuple(new_metas) - return VarInfo(nt, deepcopy(vi.accs)) + return VarInfo(nt, copy(vi.accs)) end function typed_vector_varinfo( rng::Random.AbstractRNG, @@ -448,8 +455,7 @@ function unflatten(vi::VarInfo, x::AbstractVector) # convert to into an intermediate variable makes this unstable (constant propagation) # fails. Take care when editing. 
accs = map( - acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), - deepcopy(getaccs(vi)), + acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi)) ) return VarInfo(md, accs) end @@ -472,7 +478,7 @@ end end function unflatten_metadata(md::Metadata, x::AbstractVector) - return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.orders, md.flags) + return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.flags) end unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) @@ -498,7 +504,6 @@ function Metadata() Vector{UnitRange{Int}}(), vals, Vector{Distribution}(), - Vector{Int}(), flags, ) end @@ -516,7 +521,6 @@ function empty!(meta::Metadata) empty!(meta.ranges) empty!(meta.vals) empty!(meta.dists) - empty!(meta.orders) for k in keys(meta.flags) empty!(meta.flags[k]) end @@ -533,7 +537,7 @@ end function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, deepcopy(varinfo.accs)) + return VarInfo(metadata, copy(varinfo.accs)) end function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) @@ -605,15 +609,7 @@ function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:Va end flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags) - return Metadata( - indices, - vns, - ranges, - vals, - metadata.dists[indices_for_vns], - metadata.orders[indices_for_vns], - flags, - ) + return Metadata(indices, vns, ranges, vals, metadata.dists[indices_for_vns], flags) end function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) @@ -622,7 +618,7 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - return VarInfo(metadata, deepcopy(varinfo_right.accs)) + return VarInfo(metadata, copy(varinfo_right.accs)) end function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) @@ -683,7 +679,6 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] - orders = Int[] flags = Dict{String,BitVector}() # Initialize the `flags`. 
for k in union(keys(metadata_left.flags), keys(metadata_right.flags)) @@ -705,13 +700,12 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) offset = r[end] dist = getdist(metadata_for_vn, vn) push!(dists, dist) - push!(orders, getorder(metadata_for_vn, vn)) for k in keys(flags) push!(flags[k], is_flagged(metadata_for_vn, vn, k)) end end - return Metadata(idcs, vns, ranges, vals, dists, orders, flags) + return Metadata(idcs, vns, ranges, vals, dists, flags) end const VarView = Union{Int,UnitRange,Vector{Int}} @@ -1375,7 +1369,6 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.orders, metadata.flags, ), cumulative_logjac @@ -1541,7 +1534,6 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.orders, metadata.flags, ), cumulative_logjac @@ -1723,7 +1715,6 @@ function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) ("VarNames", vi.metadata.vns), ("Range", vi.metadata.ranges), ("Vals", vi.metadata.vals), - ("Orders", vi.metadata.orders), ] for accname in acckeys(vi) push!(lines, (string(accname), getacc(vi, Val(accname)))) @@ -1808,13 +1799,12 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) [1:length(val)], val, [dist], - [get_num_produce(vi)], Dict{String,BitVector}("trans" => [false], "del" => [false]), ) vi = Accessors.@set vi.metadata[sym] = md else meta = getmetadata(vi, vn) - push!(meta, vn, r, dist, get_num_produce(vi)) + push!(meta, vn, r, dist) end return vi @@ -1834,7 +1824,7 @@ end # exist in the NTVarInfo already. We could implement it in the cases where it it does # exist, but that feels a bit pointless. I think we should rather rely on `push!!`. -function Base.push!(meta::Metadata, vn, r, dist, num_produce) +function Base.push!(meta::Metadata, vn, r, dist) val = tovec(r) meta.idcs[vn] = length(meta.idcs) + 1 push!(meta.vns, vn) @@ -1843,7 +1833,6 @@ function Base.push!(meta::Metadata, vn, r, dist, num_produce) push!(meta.ranges, (l + 1):(l + n)) append!(meta.vals, val) push!(meta.dists, dist) - push!(meta.orders, num_produce) push!(meta.flags["del"], false) push!(meta.flags["trans"], false) return meta @@ -1854,31 +1843,6 @@ function Base.delete!(vi::VarInfo, vn::VarName) return vi end -""" - setorder!(vi::VarInfo, vn::VarName, index::Int) - -Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe -statements run before sampling `vn`. -""" -function setorder!(vi::VarInfo, vn::VarName, index::Int) - setorder!(getmetadata(vi, vn), vn, index) - return vi -end -function setorder!(metadata::Metadata, vn::VarName, index::Int) - metadata.orders[metadata.idcs[vn]] = index - return metadata -end -setorder!(vnv::VarNamedVector, ::VarName, ::Int) = vnv - -""" - getorder(vi::VarInfo, vn::VarName) - -Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements -run before sampling `vn`. 
-""" -getorder(vi::VarInfo, vn::VarName) = getorder(getmetadata(vi, vn), vn) -getorder(metadata::Metadata, vn::VarName) = metadata.orders[getidx(metadata, vn)] - ####################################### # Rand & replaying method for VarInfo # ####################################### diff --git a/test/accumulators.jl b/test/accumulators.jl index 36bb95e46..5963ad8b5 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -7,7 +7,7 @@ using DynamicPPL: AccumulatorTuple, LogLikelihoodAccumulator, LogPriorAccumulator, - NumProduceAccumulator, + VariableOrderAccumulator, accumulate_assume!!, accumulate_observe!!, combine, @@ -31,11 +31,11 @@ using DynamicPPL: LogLikelihoodAccumulator{Float64}() == LogLikelihoodAccumulator{Float64}(0.0) == zero(LogLikelihoodAccumulator(1.0)) - @test NumProduceAccumulator(0) == - NumProduceAccumulator() == - NumProduceAccumulator{Int}() == - NumProduceAccumulator{Int}(0) == - zero(NumProduceAccumulator(1)) + @test VariableOrderAccumulator(0) == + VariableOrderAccumulator() == + VariableOrderAccumulator{Int}() == + VariableOrderAccumulator{Int}(0) == + VariableOrderAccumulator(0, Dict{VarName,Int}()) end @testset "addition and incrementation" begin @@ -47,19 +47,19 @@ using DynamicPPL: LogLikelihoodAccumulator(2.0f0) @test LogLikelihoodAccumulator(1.0) + LogLikelihoodAccumulator(1.0f0) == LogLikelihoodAccumulator(2.0) - @test increment(NumProduceAccumulator()) == NumProduceAccumulator(1) - @test increment(NumProduceAccumulator{UInt8}()) == - NumProduceAccumulator{UInt8}(1) + @test increment(VariableOrderAccumulator()) == VariableOrderAccumulator(1) + @test increment(VariableOrderAccumulator{UInt8}()) == + VariableOrderAccumulator{UInt8}(1) end @testset "split and combine" begin for acc in [ LogPriorAccumulator(1.0), LogLikelihoodAccumulator(1.0), - NumProduceAccumulator(1), + VariableOrderAccumulator(1), LogPriorAccumulator(1.0f0), LogLikelihoodAccumulator(1.0f0), - NumProduceAccumulator(UInt8(1)), + VariableOrderAccumulator(UInt8(1)), ] @test combine(acc, split(acc)) == acc end @@ -71,8 +71,9 @@ using DynamicPPL: @test convert( LogLikelihoodAccumulator{Float32}, LogLikelihoodAccumulator(1.0) ) == LogLikelihoodAccumulator{Float32}(1.0f0) - @test convert(NumProduceAccumulator{UInt8}, NumProduceAccumulator(1)) == - NumProduceAccumulator{UInt8}(1) + @test convert( + VariableOrderAccumulator{UInt8,VarName}, VariableOrderAccumulator(1) + ) == VariableOrderAccumulator{UInt8}(1) @test convert_eltype(Float32, LogPriorAccumulator(1.0)) == LogPriorAccumulator{Float32}(1.0f0) @@ -90,8 +91,8 @@ using DynamicPPL: @test accumulate_assume!!( LogLikelihoodAccumulator(1.0), val, logjac, vn, dist ) == LogLikelihoodAccumulator(1.0) - @test accumulate_assume!!(NumProduceAccumulator(1), val, logjac, vn, dist) == - NumProduceAccumulator(1) + @test accumulate_assume!!(VariableOrderAccumulator(1), val, logjac, vn, dist) == + VariableOrderAccumulator(1, Dict{VarName,Int}((vn => 1))) end @testset "accumulate_observe" begin @@ -102,8 +103,8 @@ using DynamicPPL: LogPriorAccumulator(1.0) @test accumulate_observe!!(LogLikelihoodAccumulator(1.0), right, left, vn) == LogLikelihoodAccumulator(1.0 + logpdf(right, left)) - @test accumulate_observe!!(NumProduceAccumulator(1), right, left, vn) == - NumProduceAccumulator(2) + @test accumulate_observe!!(VariableOrderAccumulator(1), right, left, vn) == + VariableOrderAccumulator(2) end end @@ -113,7 +114,7 @@ using DynamicPPL: lp_f32 = LogPriorAccumulator(1.0f0) ll_f64 = LogLikelihoodAccumulator(1.0) ll_f32 = LogLikelihoodAccumulator(1.0f0) - 
np_i64 = NumProduceAccumulator(1) + np_i64 = VariableOrderAccumulator(1) @testset "constructors" begin @test AccumulatorTuple(lp_f64, ll_f64) == AccumulatorTuple((lp_f64, ll_f64)) @@ -131,12 +132,12 @@ using DynamicPPL: @test at_all64[:LogPrior] == lp_f64 @test at_all64[:LogLikelihood] == ll_f64 - @test at_all64[:NumProduce] == np_i64 + @test at_all64[:VariableOrder] == np_i64 - @test haskey(AccumulatorTuple(np_i64), Val(:NumProduce)) + @test haskey(AccumulatorTuple(np_i64), Val(:VariableOrder)) @test ~haskey(AccumulatorTuple(np_i64), Val(:LogPrior)) @test length(AccumulatorTuple(lp_f64, ll_f64, np_i64)) == 3 - @test keys(at_all64) == (:LogPrior, :LogLikelihood, :NumProduce) + @test keys(at_all64) == (:LogPrior, :LogLikelihood, :VariableOrder) @test collect(at_all64) == [lp_f64, ll_f64, np_i64] # Replace the existing LogPriorAccumulator diff --git a/test/varinfo.jl b/test/varinfo.jl index 1c597f951..27ede6af8 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -21,12 +21,13 @@ function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) if !haskey(vi, vn) r = rand(dist) push!!(vi, vn, r, dist) + vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi)) r elseif DynamicPPL.is_flagged(vi, vn, "del") DynamicPPL.unset_flag!(vi, vn, "del") r = rand(dist) vi[vn] = DynamicPPL.tovec(r) - DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) + vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi)) r else vi[vn] @@ -54,7 +55,6 @@ end ind = meta.idcs[vn] tind = fmeta.idcs[vn] @test meta.dists[ind] == fmeta.dists[tind] - @test meta.orders[ind] == fmeta.orders[tind] for flag in keys(meta.flags) @test meta.flags[flag][ind] == fmeta.flags[flag][tind] end @@ -213,7 +213,7 @@ end @test_throws "has no field LogLikelihood" getloglikelihood(vi) @test_throws "has no field LogLikelihood" getlogp(vi) @test_throws "has no field LogLikelihood" getlogjoint(vi) - @test_throws "has no field NumProduce" get_num_produce(vi) + @test_throws "has no field VariableOrder" get_num_produce(vi) @test begin vi = acclogprior!!(vi, 1.0) getlogprior(vi) == lp_a + lp_b + 1.0 @@ -225,7 +225,7 @@ end vi = last( DynamicPPL.evaluate!!( - m, DynamicPPL.setaccs!!(deepcopy(vi), (NumProduceAccumulator(),)) + m, DynamicPPL.setaccs!!(deepcopy(vi), (VariableOrderAccumulator(),)) ), ) @test_throws "has no field LogPrior" getlogprior(vi) @@ -240,8 +240,8 @@ end @test_throws "has no field LogLikelihood" getloglikelihood(vi) @test_throws "has no field LogPrior" getlogp(vi) @test_throws "has no field LogPrior" getlogjoint(vi) - @test_throws "has no field NumProduce" get_num_produce(vi) - @test_throws "has no field NumProduce" reset_num_produce!!(vi) + @test_throws "has no field VariableOrder" get_num_produce(vi) + @test_throws "has no field VariableOrder" reset_num_produce!!(vi) end @testset "flags" begin @@ -1089,7 +1089,12 @@ end randr(vi, vn_a2, dists[2]) vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) - @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] + @test DynamicPPL.getorder(vi, vn_z1) == 1 + @test DynamicPPL.getorder(vi, vn_a1) == 1 + @test DynamicPPL.getorder(vi, vn_b) == 2 + @test DynamicPPL.getorder(vi, vn_z2) == 2 + @test DynamicPPL.getorder(vi, vn_a2) == 2 + @test DynamicPPL.getorder(vi, vn_z3) == 3 @test DynamicPPL.get_num_produce(vi) == 3 vi = DynamicPPL.reset_num_produce!!(vi) @@ -1108,7 +1113,12 @@ end vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) randr(vi, vn_a2, dists[2]) - @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] + @test 
DynamicPPL.getorder(vi, vn_z1) == 1 + @test DynamicPPL.getorder(vi, vn_a1) == 1 + @test DynamicPPL.getorder(vi, vn_b) == 2 + @test DynamicPPL.getorder(vi, vn_z2) == 2 + @test DynamicPPL.getorder(vi, vn_z3) == 3 + @test DynamicPPL.getorder(vi, vn_a2) == 3 @test DynamicPPL.get_num_produce(vi) == 3 vi = empty!!(DynamicPPL.typed_varinfo(vi)) @@ -1123,9 +1133,12 @@ end randr(vi, vn_a2, dists[2]) vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 2] - @test vi.metadata.b.orders == [2] + @test DynamicPPL.getorder(vi, vn_z1) == 1 + @test DynamicPPL.getorder(vi, vn_z2) == 2 + @test DynamicPPL.getorder(vi, vn_z3) == 3 + @test DynamicPPL.getorder(vi, vn_a1) == 1 + @test DynamicPPL.getorder(vi, vn_a2) == 2 + @test DynamicPPL.getorder(vi, vn_b) == 2 @test DynamicPPL.get_num_produce(vi) == 3 vi = DynamicPPL.reset_num_produce!!(vi) @@ -1144,9 +1157,12 @@ end vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) randr(vi, vn_a2, dists[2]) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 3] - @test vi.metadata.b.orders == [2] + @test DynamicPPL.getorder(vi, vn_z1) == 1 + @test DynamicPPL.getorder(vi, vn_z2) == 2 + @test DynamicPPL.getorder(vi, vn_z3) == 3 + @test DynamicPPL.getorder(vi, vn_a1) == 1 + @test DynamicPPL.getorder(vi, vn_a2) == 3 + @test DynamicPPL.getorder(vi, vn_b) == 2 @test DynamicPPL.get_num_produce(vi) == 3 end
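
For reference, a hedged usage sketch at the `VarInfo` level mirroring the `randr` test helper above, assuming `VarInfo()` carries the default accumulators (which after this change include a `VariableOrderAccumulator`); the variable name is illustrative:

```julia
using DynamicPPL, Distributions

vi = DynamicPPL.VarInfo()
vn = @varname(x)

vi = DynamicPPL.increment_num_produce!!(vi)   # pretend one observe statement has run
vi = push!!(vi, vn, rand(Normal()), Normal())
vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi))

DynamicPPL.getorder(vi, vn)      # 1: one observation before x was sampled
DynamicPPL.get_num_produce(vi)   # 1

# reset_num_produce!! zeroes the counter but keeps the per-variable order map,
# since set_num_produce!! reuses the accumulator's existing `order` dictionary.
vi = DynamicPPL.reset_num_produce!!(vi)
DynamicPPL.get_num_produce(vi)   # 0
```
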