Description
The compilation time of the Gibbs sampler appears to grow linearly with the number of component samplers, i.e. O(N_component_samplers). However, tests show that compiling any given call to sample is expensive, and that each additional component sampler adds a large cost.
I have attached a test file that takes the number of component samplers to use (between 3 and 9) as an argument; in every case, 17 Normal(0,1) variables are sampled. Each component uses the default Metropolis-Hastings sampler, MH() (i.e. proposals from the prior), and no data is provided. We draw two samples so that the sample call is compiled both with and without varinfo being passed.
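For readers without the attachment, the variable-component setup can be sketched as below. This is a hypothetical helper, not the code in the attached .txt file: it assumes global_p1 always forms its own block and that the eight (p_i, vp_i) pairs are spread as evenly as possible over the remaining blocks, so that n = 9 reproduces the hand-written sampler in the script.

```julia
# Hypothetical helper (not the attached .txt file): split the 17 parameters of
# `test_gibbs_individual` into `n` Gibbs blocks. `global_p1` always gets its own
# block; the eight (p_i, vp_i) pairs are spread as evenly as possible over the
# remaining n - 1 blocks, so n = 9 gives one pair per block.
function component_groups(n::Int; npairs::Int = 8)
    2 <= n <= npairs + 1 || throw(ArgumentError("n must be between 2 and $(npairs + 1)"))
    nblocks = n - 1
    # Boundaries of `nblocks` contiguous chunks over the pair indices 1:npairs.
    bounds = round.(Int, range(0, npairs; length = nblocks + 1))
    groups = Vector{Tuple{Vararg{Symbol}}}()
    push!(groups, (:global_p1,))
    for k in 1:nblocks
        block = Symbol[]
        for i in (bounds[k] + 1):bounds[k + 1]
            # Keep each p_i together with its matching vp_i.
            push!(block, Symbol("p", i), Symbol("vp", i))
        end
        push!(groups, Tuple(block))
    end
    return groups
end

for n in 3:9
    println(n, " components: ", component_groups(n))
end
```

With Turing loaded, the blocks could then be turned into a sampler with something like `Gibbs((g => MH() for g in component_groups(n))...)`. This assumes `Gibbs` accepts tuples of `Symbol`s in place of `@varname`s, which recent Turing releases appear to support; if not, the symbols would need converting to VarNames first.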
On my local machine (not a rigorous benchmark), the test gives an estimate of ~25-30 seconds per extra component sampler: 3 components take ~150 seconds and 8 components take ~290 seconds.
This is significantly longer than using a plain MH sampler without Gibbs, which takes ~10 seconds.
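The per-component figure follows from the two reported timings via a straight line through (3 components, ~150 s) and (8 components, ~290 s). The sketch below only does that arithmetic on the approximate numbers quoted above; the "fixed" intercept is an extrapolation from two rough data points, not a measured quantity.

```julia
# Approximate timings reported above (seconds) for 3 and 8 component samplers.
n_components = [3, 8]
seconds      = [150.0, 290.0]

# Slope of the line through the two points: extra time per component sampler.
per_component = (seconds[2] - seconds[1]) / (n_components[2] - n_components[1])

# Intercept: the part of the total not attributed to individual components.
fixed_cost = seconds[1] - per_component * n_components[1]

println("≈ $(per_component) s per extra component, ≈ $(fixed_cost) s fixed")
```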
using Turing # v0.37.0
using Distributions
# --- Probabilistic Model Definition ---
@model function test_gibbs_individual()
    # Define the model parameters with Normal priors
    p1 ~ Normal(0, 1)
    p2 ~ Normal(0, 1)
    p3 ~ Normal(0, 1)
    p4 ~ Normal(0, 1)
    p5 ~ Normal(0, 1)
    p6 ~ Normal(0, 1)
    p7 ~ Normal(0, 1)
    p8 ~ Normal(0, 1)
    vp1 ~ Normal(0, 1)
    vp2 ~ Normal(0, 1)
    vp3 ~ Normal(0, 1)
    vp4 ~ Normal(0, 1)
    vp5 ~ Normal(0, 1)
    vp6 ~ Normal(0, 1)
    vp7 ~ Normal(0, 1)
    vp8 ~ Normal(0, 1)
    global_p1 ~ Normal(0, 1)
    return
end
# --- Gibbs Construction --- #
# Define a simple Gibbs construction for 9 compartments (a dynamic construction with the same number of parameters can be found in the attached .txt file)
gibbs_MH = Gibbs(
    (@varname(global_p1),) => MH(),
    (@varname(p1), @varname(vp1)) => MH(),
    (@varname(p2), @varname(vp2)) => MH(),
    (@varname(p3), @varname(vp3)) => MH(),
    (@varname(p4), @varname(vp4)) => MH(),
    (@varname(p5), @varname(vp5)) => MH(),
    (@varname(p6), @varname(vp6)) => MH(),
    (@varname(p7), @varname(vp7)) => MH(),
    (@varname(p8), @varname(vp8)) => MH()
)
# --- Time compilation --- #
# Call sample and use this to estimate the compilation time (model evaluations should be fast)
time1 = time_ns()
sample(test_gibbs_individual(), gibbs_MH, MCMCSerial(), 2, 1; progress=false) # progress=false to avoid output interfering with timing
time2 = time_ns()
# --- Output Results --- #
println("n_components: 9")
println("Approximate Compilation Time: $(time2 - time1) ns")
println("Approximate Compilation Time: $((time2 - time1)/1e9) s")
The code has been attached as a .txt file.
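Because the script times a single sample call, its figure mixes compilation with a small amount of sampling work. A minimal sketch of one way to separate the two, assuming a second identical call reuses the already-compiled code, is to time two consecutive calls. The helper name first_vs_second_call is hypothetical, and a trivial stand-in workload is used so the block runs without Turing.

```julia
# Hypothetical helper: time two consecutive calls of the zero-argument
# function `f`. The first call includes JIT compilation; the second runs
# already-compiled code, so the difference roughly estimates compile time.
function first_vs_second_call(f)
    t_first  = @elapsed f()
    t_second = @elapsed f()
    return (first = t_first, second = t_second,
            compile_estimate = t_first - t_second)
end

# Stand-in workload so the sketch runs without Turing installed.
work() = sum(abs2, 1:10_000)

timings = first_vs_second_call(work)
println("first call: $(timings.first) s, second call: $(timings.second) s")
```

Applied to the script above, `f` would be `() -> sample(test_gibbs_individual(), gibbs_MH, MCMCSerial(), 2, 1; progress=false)`. On Julia 1.6 and later, `@time` also reports the percentage of elapsed time spent compiling, which gives a similar breakdown from a single call.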