diff --git a/Project.toml b/Project.toml index 91a81092..0461606f 100644 --- a/Project.toml +++ b/Project.toml @@ -3,11 +3,12 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "6.0.7" +version = "7.0.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" +DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" @@ -29,6 +30,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] AbstractMCMC = "0.4, 0.5, 1.0, 2.0, 3.0, 4, 5" AxisArrays = "0.4.4" +DataAPI = "1.16.0" Dates = "<0.0.1, 1" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" IteratorInterfaceExtensions = "0.1.1, 1" diff --git a/docs/Project.toml b/docs/Project.toml index 6fbefc4e..bbe55cf0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -15,7 +15,7 @@ CategoricalArrays = "0.8, 0.9, 0.10" DataFrames = "0.22, 1" Documenter = "0.26, 0.27, 1" Gadfly = "1.3.4" -MCMCChains = "6" +MCMCChains = "7" MLJBase = "0.19, 0.20, 0.21, 1" MLJDecisionTreeInterface = "0.3, 0.4" StatsPlots = "0.14, 0.15" diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl index dc7b8784..0592a501 100644 --- a/src/MCMCChains.jl +++ b/src/MCMCChains.jl @@ -8,8 +8,18 @@ using Distributions using RecipesBase using Dates using KernelDensity: kde, pdf -import StatsBase: autocov, counts, sem, AbstractWeights, - autocor, describe, quantile, sample, summarystats, cov +import DataAPI +import StatsBase: + autocov, + counts, + sem, + AbstractWeights, + autocor, + describe, + quantile, + sample, + summarystats, + cov import MCMCDiagnosticTools import MLJModelInterface @@ -34,9 +44,21 @@ export ChainDataFrame export summarize # Reexport diagnostics functions -using MCMCDiagnosticTools: discretediag, ess, ess_rhat, AutocovMethod, FFTAutocovMethod, - BDAAutocovMethod, gelmandiag, gelmandiag_multivariate, gewekediag, heideldiag, mcse, - rafterydiag, rhat, rstar +using MCMCDiagnosticTools: + discretediag, + ess, + ess_rhat, + AutocovMethod, + FFTAutocovMethod, + BDAAutocovMethod, + gelmandiag, + gelmandiag_multivariate, + gewekediag, + heideldiag, + mcse, + rafterydiag, + rhat, + rstar export discretediag export ess, ess_rhat, rhat, AutocovMethod, FFTAutocovMethod, BDAAutocovMethod export gelmandiag, gelmandiag_multivariate @@ -59,7 +81,8 @@ Parameters: - `info` : A `NamedTuple` containing miscellaneous information relevant to the chain. The `info` field can be set using `setinfo(c::Chains, n::NamedTuple)`. """ -struct Chains{T,A<:AxisArray{T,3},L,K<:NamedTuple,I<:NamedTuple} <: AbstractMCMC.AbstractChains +struct Chains{T,A<:AxisArray{T,3},L,K<:NamedTuple,I<:NamedTuple} <: + AbstractMCMC.AbstractChains value::A logevidence::L name_map::K diff --git a/src/chains.jl b/src/chains.jl index 877316a0..6e131783 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -3,15 +3,15 @@ ## Constructors ## # Constructor to handle a vector of vectors. -Chains(val::AbstractVector{<:AbstractVector{<:Union{Missing, Real}}}, args...; kwargs...) = - Chains(copy(reduce(hcat, val)'), args...; kwargs...) +Chains(val::AbstractVector{<:AbstractVector{<:Union{Missing,Real}}}, args...; kwargs...) = + Chains(copy(reduce(hcat, val)'), args...; kwargs...) # Constructor to handle a 1D array. -Chains(val::AbstractVector{<:Union{Missing, Real}}, args...; kwargs...) = - Chains(reshape(val, :, 1, 1), args...; kwargs...) +Chains(val::AbstractVector{<:Union{Missing,Real}}, args...; kwargs...) = + Chains(reshape(val, :, 1, 1), args...; kwargs...) # Constructor to handle a 2D array -Chains(val::AbstractMatrix{<:Union{Missing, Real}}, args...; kwargs...) = +Chains(val::AbstractMatrix{<:Union{Missing,Real}}, args...; kwargs...) = Chains(reshape(val, size(val, 1), size(val, 2), 1), args...; kwargs...) # Constructor to handle parameter names that are not Symbols. @@ -19,26 +19,31 @@ function Chains( val::AbstractArray{<:Union{Missing,Real},3}, parameter_names::AbstractVector, args...; - kwargs... + kwargs..., ) return Chains(val, Symbol.(parameter_names), args...; kwargs...) end # Generic chain constructor. function Chains( - val::AbstractArray{<:Union{Missing, Real},3}, + val::AbstractArray{<:Union{Missing,Real},3}, parameter_names::AbstractVector{Symbol} = Symbol.(:param_, 1:size(val, 2)), name_map = (parameters = parameter_names,); start::Int = 1, thin::Int = 1, - iterations::AbstractVector{Int} = range(start; step=thin, length=size(val, 1)), + iterations::AbstractVector{Int} = range(start; step = thin, length = size(val, 1)), evidence = missing, - info::NamedTuple = NamedTuple() + info::NamedTuple = NamedTuple(), ) # Check that iteration numbers are reasonable if length(iterations) != size(val, 1) - error("length of `iterations` (", length(iterations), - ") is not equal to the number of iterations (", size(val, 1), ")") + error( + "length of `iterations` (", + length(iterations), + ") is not equal to the number of iterations (", + size(val, 1), + ")", + ) end if !isempty(iterations) && first(iterations) < 1 error("iteration numbers must be positive integers") @@ -70,10 +75,7 @@ function Chains( append!(_name_map[:parameters], unassigned) # Construct the AxisArray. - arr = AxisArray(val; - iter = iterations, - var = parameter_names, - chain = 1:size(val, 3)) + arr = AxisArray(val; iter = iterations, var = parameter_names, chain = 1:size(val, 3)) # Create the new chain. return Chains(arr, evidence, _name_map, info) @@ -104,10 +106,10 @@ julia> names(chn2) Chains(c::Chains, section::Union{Symbol,AbstractString}) = Chains(c, (section,)) function Chains(chn::Chains, sections) # Make sure the sections exist first. - all(haskey(chn.name_map, Symbol(x)) for x in sections) || - error("some sections are not present in the chain") + all(haskey(chn.name_map, Symbol(x)) for x in sections) || + error("some sections are not present in the chain") - # Create the new section map. + # Create the new section map. name_map = (; (Symbol(k) => chn.name_map[Symbol(k)] for k in sections)...) # Extract wanted values. @@ -153,8 +155,9 @@ julia> namesingroup(chn, :A; index_type=:dot) Symbol("A.2") ``` """ -namesingroup(chains::Chains, sym::AbstractString; kwargs...) = namesingroup(chains, Symbol(sym); kwargs...) -function namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket) +namesingroup(chains::Chains, sym::AbstractString; kwargs...) = + namesingroup(chains, Symbol(sym); kwargs...) +function namesingroup(chains::Chains, sym::Symbol; index_type::Symbol = :bracket) if index_type !== :bracket && index_type !== :dot error("index_type must be :bracket or :dot") end @@ -239,16 +242,16 @@ julia> get(chn, :param_1; flatten=true) (param_1 = 1,) ``` """ -function Base.get(c::Chains, vs::Vector{Symbol}; flatten=false) +function Base.get(c::Chains, vs::Vector{Symbol}; flatten = false) pairs = OrderedCollections.OrderedDict() for v in vs syms = namesingroup(c, v) len = length(syms) val = () if len > 1 - val = ntuple(i -> c.value[:,syms[i],:], length(syms)) + val = ntuple(i -> c.value[:, syms[i], :], length(syms)) elseif len == 1 - val = c.value[:,syms[1],:] + val = c.value[:, syms[1], :] else continue end @@ -263,7 +266,7 @@ function Base.get(c::Chains, vs::Vector{Symbol}; flatten=false) end return _dict2namedtuple(pairs) end -Base.get(c::Chains, v::Symbol; flatten=false) = get(c, [v]; flatten=flatten) +Base.get(c::Chains, v::Symbol; flatten = false) = get(c, [v]; flatten = flatten) """ get(c::Chains; section::Union{Symbol,AbstractVector{Symbol}}, flatten=false) @@ -284,11 +287,7 @@ julia> get(chn; section=[:internals]) (a = [1; 2;;],) ``` """ -function Base.get( - c::Chains; - section::Union{Symbol,AbstractVector{Symbol}}, - flatten = false -) +function Base.get(c::Chains; section::Union{Symbol,AbstractVector{Symbol}}, flatten = false) names = OrderedCollections.OrderedSet(Symbol[]) regex = r"[^\[]*" _section = section isa Symbol ? (section,) : section @@ -337,7 +336,7 @@ And data, a 5×1 Matrix{Int64}: 5 ``` """ -get_params(c::Chains; flatten = false) = get(c, section = sections(c), flatten=flatten) +get_params(c::Chains; flatten = false) = get(c, section = sections(c), flatten = flatten) #################### Base Methods #################### @@ -347,13 +346,7 @@ end function Base.show(io::IO, mime::MIME"text/plain", chains::Chains) print(io, "Chains ", chains, ":\n\n", header(chains)) - - # Show summary stats. - summaries = describe(chains) - for summary in summaries - println(io) - show(io, mime, summary) - end + println(io, "\nUse `describe(chains)` for summary statistics and quantiles.") end Base.keys(c::Chains) = names(c) @@ -371,7 +364,7 @@ Base.convert(::Type{Array}, chn::Chains) = convert(Array, chn.value) to_datetime(t::DateTime) = t to_datetime(t::Float64) = unix2datetime(t) to_datetime(t) = missing -to_datetime_vec(t::Union{Float64, DateTime}) = [to_datetime(t)] +to_datetime_vec(t::Union{Float64,DateTime}) = [to_datetime(t)] to_datetime_vec(t::DateTime) = [to_datetime(t)] to_datetime_vec(ts::Vector) = map(to_datetime, ts) to_datetime_vec(ts) = [missing] @@ -424,7 +417,7 @@ Calculate the wall clock time for all chains in seconds. The duration is calculated as `stop - start`, where as default `stop` is the latest stopping time and `start` is the earliest starting time. """ -function wall_duration(c::Chains; start=min_start(c), stop=max_stop(c)) +function wall_duration(c::Chains; start = min_start(c), stop = max_stop(c)) # DateTime - DateTime returns a Millisecond value, # divide by 1k to get seconds. return if start === missing || stop === missing @@ -444,11 +437,7 @@ The duration is calculated as the sum of `start - stop` in seconds. `compute_duration` is more useful in cases of parallel sampling, where `wall_duration` may understate how much computation time was utilitzed. """ -function compute_duration( - c::Chains; - start=start_times(c), - stop=stop_times(c) -) +function compute_duration(c::Chains; start = start_times(c), stop = stop_times(c)) # Calculate total time for each chain, then add it up. if start === missing || stop === missing return missing @@ -480,16 +469,25 @@ The new chain and `chains` share the same data in memory. """ function setrange(chains::Chains, range::AbstractVector{Int}) if length(chains) != length(range) - error("length of `range` (", length(range), - ") is not equal to the number of iterations (", length(chains), ")") + error( + "length of `range` (", + length(range), + ") is not equal to the number of iterations (", + length(chains), + ")", + ) end if !isempty(range) && first(range) < 1 error("iteration numbers must be positive integers") end isstrictlyincreasing(range) || error("iteration numbers must be strictly increasing") - value = AxisArray(chains.value.data; - iter = range, var = names(chains), chain = MCMCChains.chains(chains)) + value = AxisArray( + chains.value.data; + iter = range, + var = names(chains), + chain = MCMCChains.chains(chains), + ) return Chains(value, chains.logevidence, chains.name_map, chains.info) end @@ -525,7 +523,8 @@ Base.names(chains::Chains) = chains.value[Axis{:var}].val Return the parameter names of a `section` in the `chains`. """ -Base.names(chains::Chains, section::Symbol) = convert(Vector{Symbol}, chains.name_map[section]) +Base.names(chains::Chains, section::Symbol) = + convert(Vector{Symbol}, chains.name_map[section]) """ names(chains::Chains, sections) @@ -548,7 +547,8 @@ Return multiple `Chains` objects, each containing only a single section. function get_sections(chains::Chains, sections = keys(chains.name_map)) return [Chains(chains, section) for section in sections] end -get_sections(chains::Chains, section::Union{Symbol, AbstractString}) = Chains(chains, section) +get_sections(chains::Chains, section::Union{Symbol,AbstractString}) = + Chains(chains, section) """ sections(c::Chains) @@ -573,14 +573,14 @@ header(chn) header(chn, section = :parameter) ``` """ -function header(c::Chains; section=missing) +function header(c::Chains; section = missing) rng = range(c) # Function to make section strings. section_str(sec, arr) = string( "$sec", repeat(" ", 18 - length(string(sec))), - "= $(join(map(string, arr), ", "))\n" + "= $(join(map(string, arr), ", "))\n", ) # Get the timing stats @@ -610,20 +610,18 @@ function header(c::Chains; section=missing) "Number of chains = $(size(c, 3))\n", "Samples per chain = $(length(range(c)))\n", ismissing(wall) ? "" : "Wall duration = $(round(wall, digits=2)) seconds\n", - ismissing(compute) ? "" : "Compute duration = $(round(compute, digits=2)) seconds\n", - section_strings... + ismissing(compute) ? "" : + "Compute duration = $(round(compute, digits=2)) seconds\n", + section_strings..., ) end -function indiscretesupport( - c::Chains, - bounds::Tuple{Real, Real}=(0, Inf) -) +function indiscretesupport(c::Chains, bounds::Tuple{Real,Real} = (0, Inf)) nrows, nvars, nchains = size(c.value) result = Array{Bool}(undef, nvars * (nrows > 0)) - for i in 1:nvars + for i = 1:nvars result[i] = true - for j in 1:nrows, k in 1:nchains + for j = 1:nrows, k = 1:nchains x = c.value[j, i, k] if !isinteger(x) || x < bounds[1] || x > bounds[2] result[i] = false @@ -658,7 +656,7 @@ function Base.sort(c::Chains; lt = NaturalSort.natural) v = c.value x, y, z = size(v) unsorted = collect(zip(1:y, v.axes[2].val)) - sorted = sort(unsorted, by = x -> string(x[2]), lt=lt) + sorted = sort(unsorted, by = x -> string(x[2]), lt = lt) new_axes = (v.axes[1], Axis{:var}([n for (_, n) in sorted]), v.axes[3]) new_v = copy(v.data) for i in eachindex(sorted) @@ -670,7 +668,7 @@ function Base.sort(c::Chains; lt = NaturalSort.natural) # Sort the name map too: namemap = deepcopy(c.name_map) for names in namemap - sort!(names, by=string, lt=lt) + sort!(names, by = string, lt = lt) end return Chains(aa, c.logevidence, namemap, c.info) @@ -717,7 +715,7 @@ function set_section(chains::Chains, namemap) # Assign everything that is missing to :parameters. if !isempty(missingnames) @warn "Section mapping does not contain all parameter names, " * - "$missingnames assigned to :parameters." + "$missingnames assigned to :parameters." for name in missingnames push!(_namemap.parameters, name) end @@ -742,13 +740,13 @@ _clean_sections(::Chains, ::Nothing) = nothing #################### Concatenation #################### Base.cat(c::Chains, cs::Chains...; dims = Val(1)) = _cat(dims, c, cs...) -Base.cat(c::T, cs::T...; dims = Val(1)) where T<:Chains = _cat(dims, c, cs...) +Base.cat(c::T, cs::T...; dims = Val(1)) where {T<:Chains} = _cat(dims, c, cs...) Base.vcat(c::Chains, cs::Chains...) = _cat(Val(1), c, cs...) -Base.vcat(c::T, cs::T...) where T<:Chains = _cat(Val(1), c, cs...) +Base.vcat(c::T, cs::T...) where {T<:Chains} = _cat(Val(1), c, cs...) Base.hcat(c::Chains, cs::Chains...) = _cat(Val(2), c, cs...) -Base.hcat(c::T, cs::T...) where T<:Chains = _cat(Val(2), c, cs...) +Base.hcat(c::T, cs::T...) where {T<:Chains} = _cat(Val(2), c, cs...) AbstractMCMC.chainscat(c::Chains, cs::Chains...) = _cat(Val(3), c, cs...) @@ -768,10 +766,12 @@ function _cat(::Val{1}, c1::Chains, args::Chains...) # concatenate all chains data = mapreduce(c -> c.value.data, vcat, args; init = c1.value.data) - value = AxisArray(data; - iter = mapreduce(range, vcat, args; init=range(c1)), - var = nms, - chain = chns) + value = AxisArray( + data; + iter = mapreduce(range, vcat, args; init = range(c1)), + var = nms, + chain = chns, + ) return Chains(value, missing, c1.name_map, c1.info) end @@ -810,28 +810,30 @@ function _cat(::Val{3}, c1::Chains, args::Chains...) # concatenate all chains data = mapreduce( - c -> c.value.data, - (x, y) -> cat(x, y; dims = 3), - args; - init = c1.value.data + c -> c.value.data, + (x, y) -> cat(x, y; dims = 3), + args; + init = c1.value.data, ) value = AxisArray(data; iter = rng, var = nms, chain = 1:size(data, 3)) # Concatenate times, if available starts = mapreduce( - c -> get(c.info, :start_time, nothing), - vcat, - args, - init = get(c1.info, :start_time, nothing) + c -> get(c.info, :start_time, nothing), + vcat, + args, + init = get(c1.info, :start_time, nothing), ) stops = mapreduce( - c -> get(c.info, :stop_time, nothing), - vcat, - args, - init = get(c1.info, :stop_time, nothing) + c -> get(c.info, :stop_time, nothing), + vcat, + args, + init = get(c1.info, :stop_time, nothing), ) - nontime_props = filter(x -> !(x in [:start_time, :stop_time]), [propertynames(c1.info)...]) - new_info = NamedTuple{tuple(nontime_props...)}(tuple([c1.info[n] for n in nontime_props]...)) + nontime_props = + filter(x -> !(x in [:start_time, :stop_time]), [propertynames(c1.info)...]) + new_info = + NamedTuple{tuple(nontime_props...)}(tuple([c1.info[n] for n in nontime_props]...)) new_info = merge(new_info, (start_time = starts, stop_time = stops)) return Chains(value, missing, c1.name_map, new_info) @@ -840,7 +842,7 @@ end function pool_chain(c::Chains) data = c.value.data pool_data = reshape(permutedims(data, [1, 3, 2]), :, size(data, 2), 1) - return Chains(pool_data, names(c), c.name_map; info=c.info) + return Chains(pool_data, names(c), c.name_map; info = c.info) end """ @@ -885,7 +887,9 @@ function replacenames(chains::Chains, old_new::Pair...) value = AxisArray( chains.value.data; - iter = range(chains), var = names_of_params, chain = 1:size(chains, 3) + iter = range(chains), + var = names_of_params, + chain = 1:size(chains, 3), ) return Chains(value, chains.logevidence, namemap, chains.info) diff --git a/src/stats.jl b/src/stats.jl index e31044ea..0bbc85fd 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -20,20 +20,21 @@ function autocor( append_chains = true, demean::Bool = true, lags::AbstractVector{<:Integer} = _default_lags(chains, append_chains), - kwargs... + kwargs..., ) funs = Function[] func_names = @. Symbol("lag ", lags) for i in lags - push!(funs, x -> autocor(x, [i], demean=demean)[1]) + push!(funs, x -> autocor(x, [i], demean = demean)[1]) end return summarize( - chains, funs...; + chains, + funs...; func_names = func_names, append_chains = append_chains, name = "Autocorrelation", - kwargs... + kwargs..., ) end @@ -64,7 +65,7 @@ function cor( chains::Chains; sections = _default_sections(chains), append_chains = true, - kwargs... + kwargs..., ) # Subset the chain. _chains = Chains(chains, _clean_sections(chains, sections)) @@ -77,10 +78,8 @@ function cor( return df else vector_of_df = [ - chaindataframe_cor( - "Correlation - Chain $i", names_of_params, data - ) - for (i, data) in enumerate(to_vector_of_matrices(_chains)) + chaindataframe_cor("Correlation - Chain $i", names_of_params, data) for + (i, data) in enumerate(to_vector_of_matrices(_chains)) ] return vector_of_df end @@ -91,8 +90,10 @@ function chaindataframe_cor(name, names_of_params, chains::AbstractMatrix; kwarg cormat = cor(chains) # Summarize the results in a named tuple. - nt = (; parameters = names_of_params, - zip(names_of_params, (cormat[:, i] for i in axes(cormat, 2)))...) + nt = (; + parameters = names_of_params, + zip(names_of_params, (cormat[:, i] for i in axes(cormat, 2)))..., + ) # Create a ChainDataFrame. return ChainDataFrame(name, nt; kwargs...) @@ -110,7 +111,7 @@ function changerate( chains::Chains{<:Real}; sections = _default_sections(chains), append_chains = true, - kwargs... + kwargs..., ) # Subset the chain. _chains = Chains(chains, _clean_sections(chains, sections)) @@ -123,9 +124,7 @@ function changerate( return df else vector_of_df = [ - chaindataframe_changerate( - "Change Rate - Chain $i", names_of_params, data - ) + chaindataframe_changerate("Change Rate - Chain $i", names_of_params, data) for (i, data) in enumerate(to_vector_of_matrices(_chains)) ] return vector_of_df @@ -150,10 +149,10 @@ function changerate(chains::AbstractArray{<:Real,3}) changerates = zeros(nparams) mvchangerate = 0.0 - for chain in 1:nchains, iter in 2:niters + for chain = 1:nchains, iter = 2:niters isanychanged = false - for param in 1:nparams + for param = 1:nparams # update if the sample is different from the one in the previous iteration if chains[iter-1, param, chain] != chains[iter, param, chain] changerates[param] += 1 @@ -171,35 +170,41 @@ function changerate(chains::AbstractArray{<:Real,3}) changerates, mvchangerate end -describe(c::Chains; args...) = describe(stdout, c; args...) - """ describe(io, chains[; q = [0.025, 0.25, 0.5, 0.75, 0.975], etype = :bm, kwargs...]) - -Print the summary statistics and quantiles for the chain. +Print chain metadata, summary statistics, and quantiles. Use `describe(chains)` for REPL output to `stdout`, or specify `io` for other streams (e.g., file output). """ -function describe( +function DataAPI.describe( io::IO, chains::Chains; q = [0.025, 0.25, 0.5, 0.75, 0.975], etype = :bm, - kwargs... + kwargs..., ) - dfs = vcat(summarystats(chains; etype = etype, kwargs...), - quantile(chains; q = q, kwargs...)) - return dfs + print(io, "Chains ", chains, ":\n\n", header(chains)) + + summstats = summarystats(chains; etype = etype, kwargs...) + println(io) + show(io, MIME("text/plain"), summstats) + + qs = quantile(chains; q = q, kwargs...) + println(io) + show(io, MIME("text/plain"), qs) end -function _hpd(x::AbstractVector{<:Real}; alpha::Real=0.05) +# Convenience method for default IO +DataAPI.describe(chains::Chains; kwargs...) = DataAPI.describe(stdout, chains; kwargs...) + +function _hpd(x::AbstractVector{<:Real}; alpha::Real = 0.05) n = length(x) m = max(1, ceil(Int, alpha * n)) y = sort(x) a = y[1:m] - b = y[(n - m + 1):n] + b = y[(n-m+1):n] _, i = findmin(b - a) return [a[i], b[i]] @@ -227,10 +232,10 @@ HPD b 0.0114 0.9460 ``` """ -function hpd(chn::Chains; alpha::Real=0.05, kwargs...) +function hpd(chn::Chains; alpha::Real = 0.05, kwargs...) labels = [:lower, :upper] - l(x) = _hpd(x, alpha=alpha)[1] - u(x) = _hpd(x, alpha=alpha)[2] + l(x) = _hpd(x, alpha = alpha)[1] + u(x) = _hpd(x, alpha = alpha)[2] return summarize(chn, l, u; name = "HPD", func_names = labels, kwargs...) end @@ -246,7 +251,7 @@ function quantile( chains::Chains; q::AbstractVector = [0.025, 0.25, 0.5, 0.75, 0.975], append_chains = true, - kwargs... + kwargs..., ) # compute quantiles funs = Function[] @@ -256,11 +261,12 @@ function quantile( end return summarize( - chains, funs...; + chains, + funs...; func_names = func_names, append_chains = append_chains, name = "Quantiles", - kwargs... + kwargs..., ) end @@ -290,10 +296,10 @@ function summarystats( autocov_method::MCMCDiagnosticTools.AbstractAutocovMethod = AutocovMethod(), maxlag = 250, name = "Summary Statistics", - kwargs... + kwargs..., ) # Store everything. - funs = [mean∘cskip, std∘cskip] + funs = [mean ∘ cskip, std ∘ cskip] func_names = [:mean, :std] # Subset the chain. @@ -303,30 +309,41 @@ function summarystats( nt_additional = NamedTuple() try mcse_df = MCMCDiagnosticTools.mcse( - _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, + _chains; + sections = nothing, + autocov_method = autocov_method, + maxlag = maxlag, ) - nt_additional = merge(nt_additional, (; mcse=mcse_df.nt.mcse)) + nt_additional = merge(nt_additional, (; mcse = mcse_df.nt.mcse)) catch e @warn "MCSE calculation failed: $e" end try ess_tail_df = MCMCDiagnosticTools.ess( - _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:tail + _chains; + sections = nothing, + autocov_method = autocov_method, + maxlag = maxlag, + kind = :tail, ) - nt_additional = merge(nt_additional, (ess_tail=ess_tail_df.nt.ess,)) + nt_additional = merge(nt_additional, (ess_tail = ess_tail_df.nt.ess,)) catch e @warn "Tail ESS calculation failed: $e" end try ess_rhat_rank_df = MCMCDiagnosticTools.ess_rhat( - _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:rank + _chains; + sections = nothing, + autocov_method = autocov_method, + maxlag = maxlag, + kind = :rank, ) nt_ess_rhat_rank = ( - ess_bulk=ess_rhat_rank_df.nt.ess, - rhat=ess_rhat_rank_df.nt.rhat, - ess_per_sec=ess_rhat_rank_df.nt.ess_per_sec + ess_bulk = ess_rhat_rank_df.nt.ess, + rhat = ess_rhat_rank_df.nt.rhat, + ess_per_sec = ess_rhat_rank_df.nt.ess_per_sec, ) nt_additional = merge(nt_additional, nt_ess_rhat_rank) catch e @@ -335,16 +352,20 @@ function summarystats( # Possibly re-order the columns to stay backwards-compatible. additional_keys = (:mcse, :ess_bulk, :ess_tail, :rhat, :ess_per_sec) - additional_df = ChainDataFrame("Additional", (; ((k, nt_additional[k]) for k in additional_keys if k ∈ keys(nt_additional))...)) + additional_df = ChainDataFrame( + "Additional", + (; ((k, nt_additional[k]) for k in additional_keys if k ∈ keys(nt_additional))...), + ) # Summarize. summary_df = summarize( - _chains, funs...; + _chains, + funs...; func_names, append_chains, additional_df, name, - sections = nothing + sections = nothing, ) return summary_df @@ -357,16 +378,12 @@ Calculate the mean of a chain. """ function mean(chains::Chains; kwargs...) # Store everything. - funs = [mean∘cskip] + funs = [mean ∘ cskip] func_names = [:mean] # Summarize. - summary_df = summarize( - chains, funs...; - func_names = func_names, - name = "Mean", - kwargs... - ) + summary_df = + summarize(chains, funs...; func_names = func_names, name = "Mean", kwargs...) return summary_df end diff --git a/test/Project.toml b/test/Project.toml index 078d447e..4166a6e5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -35,7 +35,7 @@ FFTW = "1.1" IteratorInterfaceExtensions = "1" KernelDensity = "0.6.2" Logging = "<0.0.1, 1" -MCMCChains = "6" +MCMCChains = "7" MCMCDiagnosticTools = "0.3.10" MLJBase = "1" MLJDecisionTreeInterface = "0.4" diff --git a/test/tables_tests.jl b/test/tables_tests.jl index ba3cb71c..2dc99914 100644 --- a/test/tables_tests.jl +++ b/test/tables_tests.jl @@ -18,10 +18,10 @@ using DataFrames @test Tables.columnaccess(typeof(chn)) @test Tables.columns(chn) === chn @test Tables.columnnames(chn) == - (:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h) + (:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h) @test Tables.getcolumn(chn, :iteration) == [1:1000; 1:1000; 1:1000; 1:1000] @test Tables.getcolumn(chn, :chain) == - [fill(1, 1000); fill(2, 1000); fill(3, 1000); fill(4, 1000)] + [fill(1, 1000); fill(2, 1000); fill(3, 1000); fill(4, 1000)] @test Tables.getcolumn(chn, :a) == [ vec(chn[:, :a, 1]) vec(chn[:, :a, 2]) @@ -43,10 +43,10 @@ using DataFrames rows = collect(Tables.rows(chn)) @test eltype(rows) <: Tables.AbstractRow @test size(rows) === (4000,) - for chainid in 1:4, iterid in 1:1000 - row = rows[(chainid - 1) * 1000 + iterid] + for chainid = 1:4, iterid = 1:1000 + row = rows[(chainid-1)*1000+iterid] @test Tables.columnnames(row) == - (:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h) + (:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h) @test Tables.getcolumn(row, 1) == iterid @test Tables.getcolumn(row, 2) == chainid @test Tables.getcolumn(row, 3) == chn[iterid, :a, chainid] @@ -61,22 +61,25 @@ using DataFrames @testset "integration tests" begin @test length(Tables.rowtable(chn)) == 4000 nt = Tables.rowtable(chn)[1] - @test nt == - (; (k => Tables.getcolumn(chn, k)[1] for k in Tables.columnnames(chn))...) + @test nt == (; + (k => Tables.getcolumn(chn, k)[1] for k in Tables.columnnames(chn))... + ) @test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 1))[1] nt = Tables.rowtable(chn)[2] - @test nt == - (; (k => Tables.getcolumn(chn, k)[2] for k in Tables.columnnames(chn))...) + @test nt == (; + (k => Tables.getcolumn(chn, k)[2] for k in Tables.columnnames(chn))... + ) @test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 2))[2] @test Tables.matrix(chn[:, :, 1])[:, 3:end] ≈ chn[:, :, 1].value @test Tables.matrix(chn[:, :, 2])[:, 3:end] ≈ chn[:, :, 2].value - @test Tables.matrix(Tables.rowtable(chn)) == Tables.matrix(Tables.columntable(chn)) + @test Tables.matrix(Tables.rowtable(chn)) == + Tables.matrix(Tables.columntable(chn)) end @testset "schema" begin @test Tables.schema(chn) isa Tables.Schema @test Tables.schema(chn).names === - (:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h) + (:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h) @test Tables.schema(chn).types === ( Int, Int, @@ -134,77 +137,105 @@ using DataFrames colnames = ["a", "b", "c", "d", "e", "f", "g", "h"] internal_colnames = ["c", "d", "e", "f", "g", "h"] chn = Chains(val, colnames, Dict(:internals => internal_colnames)) - cdf = describe(chn)[1] - @testset "Tables interface" begin - @test Tables.istable(typeof(cdf)) - - @testset "column access" begin - @test Tables.columnaccess(typeof(cdf)) - @test Tables.columns(cdf) === cdf - @test Tables.columnnames(cdf) == keys(cdf.nt) - for (k, v) in pairs(cdf.nt) - @test isequal(Tables.getcolumn(cdf, k), v) + # Get ChainDataFrame objects + summstats = summarystats(chn) + qs = quantile(chn) + + # Helper function to test any ChainDataFrame + function test_chaindataframe(cdf::ChainDataFrame) + @testset "Tables interface" begin + @test Tables.istable(typeof(cdf)) + + @testset "column access" begin + @test Tables.columnaccess(typeof(cdf)) + @test Tables.columns(cdf) === cdf + @test Tables.columnnames(cdf) == keys(cdf.nt) + for (k, v) in pairs(cdf.nt) + @test isequal(Tables.getcolumn(cdf, k), v) + end + @test Tables.getcolumn(cdf, 1) == Tables.getcolumn(cdf, keys(cdf.nt)[1]) + @test Tables.getcolumn(cdf, 2) == Tables.getcolumn(cdf, keys(cdf.nt)[2]) + @test_throws Exception Tables.getcolumn(cdf, :blah) + @test_throws Exception Tables.getcolumn(cdf, length(cdf.nt) + 1) end - @test Tables.getcolumn(cdf, 1) == Tables.getcolumn(cdf, keys(cdf.nt)[1]) - @test Tables.getcolumn(cdf, 2) == Tables.getcolumn(cdf, keys(cdf.nt)[2]) - @test_throws Exception Tables.getcolumn(cdf, :blah) - @test_throws Exception Tables.getcolumn(cdf, length(cdf.nt) + 1) - end - @testset "row access" begin - @test Tables.rowaccess(typeof(cdf)) - @test Tables.rows(cdf) isa Tables.RowIterator - @test eltype(Tables.rows(cdf)) <: Tables.AbstractRow - rows = collect(Tables.rows(cdf)) - @test eltype(rows) <: Tables.AbstractRow - @test size(rows) === (2,) - @testset for i in 1:2 - row = rows[i] - @test Tables.columnnames(row) == keys(cdf.nt) - for j in length(cdf.nt) - @test isequal(Tables.getcolumn(row, j), cdf.nt[j][i]) - @test isequal(Tables.getcolumn(row, keys(cdf.nt)[j]), cdf.nt[j][i]) + @testset "row access" begin + @test Tables.rowaccess(typeof(cdf)) + @test Tables.rows(cdf) isa Tables.RowIterator + @test eltype(Tables.rows(cdf)) <: Tables.AbstractRow + rows = collect(Tables.rows(cdf)) + @test eltype(rows) <: Tables.AbstractRow + @test size(rows) === (2,) + @testset for i = 1:2 + row = rows[i] + @test Tables.columnnames(row) == keys(cdf.nt) + for j in length(cdf.nt) + @test isequal(Tables.getcolumn(row, j), cdf.nt[j][i]) + @test isequal( + Tables.getcolumn(row, keys(cdf.nt)[j]), + cdf.nt[j][i], + ) + end end end + + @testset "integration tests" begin + @test length(Tables.rowtable(cdf)) == length(cdf.nt[1]) + @test isequal(Tables.columntable(cdf), cdf.nt) + nt = Tables.rowtable(cdf)[1] + @test isequal(nt, (; (k => v[1] for (k, v) in pairs(cdf.nt))...)) + @test isequal( + nt, + collect(Iterators.take(Tables.namedtupleiterator(cdf), 1))[1], + ) + nt = Tables.rowtable(cdf)[2] + @test isequal(nt, (; (k => v[2] for (k, v) in pairs(cdf.nt))...)) + @test isequal( + nt, + collect(Iterators.take(Tables.namedtupleiterator(cdf), 2))[2], + ) + @test isequal( + Tables.matrix(Tables.rowtable(cdf)), + Tables.matrix(Tables.columntable(cdf)), + ) + end + + @testset "schema" begin + schema = Tables.schema(cdf) + @test schema isa Tables.Schema + @test schema.names == keys(cdf.nt) + @test schema.types == eltype.(values(cdf.nt)) + end end - @testset "integration tests" begin - @test length(Tables.rowtable(cdf)) == length(cdf.nt[1]) - @test isequal(Tables.columntable(cdf), cdf.nt) - nt = Tables.rowtable(cdf)[1] + @testset "TableTraits interface" begin + @test IteratorInterfaceExtensions.isiterable(cdf) + @test TableTraits.isiterabletable(cdf) + nt = collect( + Iterators.take(IteratorInterfaceExtensions.getiterator(cdf), 1), + )[1] @test isequal(nt, (; (k => v[1] for (k, v) in pairs(cdf.nt))...)) - @test isequal(nt, collect(Iterators.take(Tables.namedtupleiterator(cdf), 1))[1]) - nt = Tables.rowtable(cdf)[2] + nt = collect( + Iterators.take(IteratorInterfaceExtensions.getiterator(cdf), 2), + )[2] @test isequal(nt, (; (k => v[2] for (k, v) in pairs(cdf.nt))...)) - @test isequal(nt, collect(Iterators.take(Tables.namedtupleiterator(cdf), 2))[2]) - @test isequal( - Tables.matrix(Tables.rowtable(cdf)), - Tables.matrix(Tables.columntable(cdf)), - ) end - @testset "schema" begin - @test Tables.schema(cdf) isa Tables.Schema - @test Tables.schema(cdf).names === keys(cdf.nt) - @test Tables.schema(cdf).types === eltype.(values(cdf.nt)) + @testset "DataFrames.DataFrame constructor" begin + @inferred DataFrame(cdf) + df = DataFrame(cdf) + @test df isa DataFrame + @test isequal(Tables.columntable(df), cdf.nt) end end - @testset "TableTraits interface" begin - @test IteratorInterfaceExtensions.isiterable(cdf) - @test TableTraits.isiterabletable(cdf) - nt = collect(Iterators.take(IteratorInterfaceExtensions.getiterator(cdf), 1))[1] - @test isequal(nt, (; (k => v[1] for (k, v) in pairs(cdf.nt))...)) - nt = collect(Iterators.take(IteratorInterfaceExtensions.getiterator(cdf), 2))[2] - @test isequal(nt, (; (k => v[2] for (k, v) in pairs(cdf.nt))...)) + @testset "Summary Statistics" begin + test_chaindataframe(summstats) end - @testset "DataFrames.DataFrame constructor" begin - @inferred DataFrame(cdf) - df = DataFrame(cdf) - @test df isa DataFrame - @test isequal(Tables.columntable(df), cdf.nt) + @testset "Quantiles" begin + test_chaindataframe(qs) end end end