Gibbs Compilation Performance #2542

Open
@joelkandiah

Compilation time for the Gibbs sampler appears to scale as O(N_component_samplers). However, some tests show that the fixed cost of compiling any call to sample is high, and so is the additional cost of each individual component sampler.

I have attached a test file that takes the number of components to use (between 3 and 9) as an argument and samples 17 Normal(0,1) variables. Each component uses a default Metropolis-Hastings sampler (i.e. proposals are drawn from the prior), and no data is provided. We draw two samples so that the sample call is compiled both with and without varinfo being passed.

Timed on my local machine (so treat these as rough numbers), this gives ~25-30 seconds per extra component sampler: ~150 seconds for 3 components and ~290 seconds for 8 components.
This is significantly longer than the ~10 seconds taken by a single MH sampler without Gibbs.
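
As a quick sanity check on the per-component figure, the two timings quoted above can be turned into a slope. This is only a sketch: it assumes the cost is exactly linear in the number of components, and the variable names are made up for illustration.

```julia
# Approximate timings quoted in this issue (single machine, not a rigorous benchmark).
n_small, t_small = 3, 150.0   # components, seconds
n_large, t_large = 8, 290.0   # components, seconds

# Slope of the line through the two points: extra seconds per component sampler.
per_component = (t_large - t_small) / (n_large - n_small)

# Intercept of the same line; only meaningful if the linear assumption holds.
fixed_overhead = t_small - per_component * n_small

println("per component: $(per_component) s, implied fixed overhead: $(fixed_overhead) s")
```

The slope lands inside the ~25-30 second range stated above. The intercept is an extrapolation from two points, not a measurement.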

using Turing # v0.37.0
using Distributions

# --- Probabilistic Model Definition ---
@model function test_gibbs_individual()
    # Define the model parameters with Normal priors
    p1 ~ Normal(0, 1)
    p2 ~ Normal(0, 1)
    p3 ~ Normal(0, 1)
    p4 ~ Normal(0, 1)
    p5 ~ Normal(0, 1)
    p6 ~ Normal(0, 1)
    p7 ~ Normal(0, 1)
    p8 ~ Normal(0, 1)

    vp1 ~ Normal(0, 1)
    vp2 ~ Normal(0, 1)
    vp3 ~ Normal(0, 1)
    vp4 ~ Normal(0, 1)
    vp5 ~ Normal(0, 1)
    vp6 ~ Normal(0, 1)
    vp7 ~ Normal(0, 1)
    vp8 ~ Normal(0, 1)

    global_p1 ~ Normal(0, 1)

    return
end


# --- Gibbs Construction --- #
# Define a simple Gibbs construction for 9 components (a dynamic construction with the same number of parameters can be found in the attached .txt file)
gibbs_MH = Gibbs(
    (@varname(global_p1),) => MH(),
    (@varname(p1), @varname(vp1)) => MH(),
    (@varname(p2), @varname(vp2)) => MH(),
    (@varname(p3), @varname(vp3)) => MH(),
    (@varname(p4), @varname(vp4)) => MH(),
    (@varname(p5), @varname(vp5)) => MH(),
    (@varname(p6), @varname(vp6)) => MH(),
    (@varname(p7), @varname(vp7)) => MH(),
    (@varname(p8), @varname(vp8)) => MH()
)

# --- Time compilation --- #
# Call sample once and use the elapsed time to estimate the compilation time (model evaluations should be fast)
time1 = time_ns()
sample(test_gibbs_individual(), gibbs_MH, MCMCSerial(), 2, 1; progress=false) # progress=false to avoid output interfering with timing
time2 = time_ns()

# --- Output Results --- #
println("n_components: 9")
println("Approximate Compilation Time: $(time2 - time1) ns")
println("Approximate Compilation Time: $((time2 - time1)/1e9) s")

The code has been attached as a .txt file

main.txt
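
One way to check how much of the measured time is compilation rather than sampling is to time the same call twice in one session and compare. The helper below is a hypothetical sketch (`time_first_and_second` is not part of Turing or of the attached file) and uses only `@elapsed` from Base. The difference is an estimate, since the second call still includes the actual run time.

```julia
# Hypothetical helper: run a zero-argument function twice and compare timings.
# The first call includes JIT compilation; the second mostly does not.
function time_first_and_second(f)
    t_first = @elapsed f()
    t_second = @elapsed f()
    return (first = t_first, second = t_second, compile_estimate = t_first - t_second)
end

# Self-contained demonstration on a trivial function.
result = time_first_and_second(() -> sum(rand(1_000)))
println("first: $(result.first) s, second: $(result.second) s")

# Applied to the script in this issue it would look like this
# (left commented out because it needs Turing and the definitions above):
# time_first_and_second() do
#     sample(test_gibbs_individual(), gibbs_MH, MCMCSerial(), 2, 1; progress=false)
# end
```

If the second call is fast, the times reported above are dominated by compilation, as assumed.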
