
New libtask interface #114

Open · wants to merge 15 commits into main

Conversation

@FredericWantiez (Member) commented Mar 23, 2025

Integrate refactor from TuringLang/Libtask.jl#179

Two things worth noting:

  1. Dealing with the RNG will be the user's responsibility. Before:
mutable struct Model <: AdvancedPS.AbstractGenericModel
  mu::Float64
  sig::Float64

  Model() = new()
end


function (model::Model)(rng::Random.AbstractRNG)
  model.sig = rand(rng, Beta(1, 1))  # AdvancedPS took care of syncing these
  Libtask.produce(model.sig)

  model.mu = rand(rng, Normal())
  Libtask.produce(model.mu)
end

and now:

function (model::Model)()
  rng = Libtask.get_dynamic_scope() # We now need to query the RNG explicitly
  model.sig = rand(rng, Beta(1, 1))
  Libtask.produce(model.sig)

  rng = Libtask.get_dynamic_scope() # and do it every time we want to sample random values
  model.mu = rand(rng, Normal())
  Libtask.produce(model.mu)
end
  2. How do we keep track of model state between tasks? Pretty sure we don't want to look inside tapedtask.fargs:
function AdvancedPS.forkr(trace::LibtaskTrace)
    newf = AdvancedPS.reset_model(trace.model.ctask.fargs[1])
    Random123.set_counter!(trace.rng, 1)


@FredericWantiez FredericWantiez changed the title New libtask interface [WIP] New libtask interface Mar 23, 2025
@willtebbutt (Member) commented

Thanks for having a look at this!

  1. Dealing with the RNG will be the user's responsibility.

Does this have any implications for integration with Turing.jl? I.e., does not passing an RNG into the model cause any trouble downstream? (To be clear, I have no idea -- I'm not suggesting that it does or doesn't.)

  2. How do we keep track of model state between tasks? Pretty sure we don't want to look inside tapedtask.fargs

I agree about not wanting to dig into tapedtask.fargs. Could you elaborate a little on what is required here? My understanding was that task copying would handle this -- i.e. when you copy a task, all references to the model get updated, so from the perspective of the code inside the task, things just continue as normal.

As with the first item, I'm not sure exactly what the requirements are here, so I may have misunderstood something basic about what you need to do.
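For what it's worth, the copying intuition above can be illustrated self-containedly with `deepcopy` (plain Julia, no Libtask involved; `ToyModel` is just an illustrative stand-in):

```julia
# Mutable model-like state: deep-copying it gives the copy its own
# independent state, analogous to task copying updating model references.
mutable struct ToyModel
    x::Float64
end

m1 = ToyModel(1.0)
m2 = deepcopy(m1)   # the copy owns a fresh ToyModel
m2.x = 2.0          # mutating the copy...

(m1.x, m2.x)        # ...leaves the original untouched: (1.0, 2.0)
```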

@FredericWantiez (Member, Author) commented Mar 25, 2025

  1. We can drop this one; it really only applies when AdvancedPS is used with Libtask outside of Turing. We will probably sunset that (or target people who presumably know enough about Libtask).

  2. Still not 100% sure about Turing, but we need something like this to manage the reference particle in the Particle Gibbs loop. Here's an MVP that should replicate a simple loop of the algorithm:

using AdvancedPS
using Libtask
using Random
using Distributions
using SSMProblems


mutable struct Model <: AdvancedPS.AbstractGenericModel
  x::Float64
  y::Float64

  Model() = new()
end


function (model::Model)()
  rng = Libtask.get_dynamic_scope()
  model.x = rand(rng, Beta(1,1))
  Libtask.produce(model.x)

  rng = Libtask.get_dynamic_scope()
  model.y = rand(rng, Normal(0, model.x))
  Libtask.produce(model.y)
end

rng = AdvancedPS.TracedRNG()
Random.seed!(rng, 10)
model = Model()

trace = AdvancedPS.Trace(model, rng)
# Sample `x`
AdvancedPS.advance!(trace)

trace2 = AdvancedPS.fork(trace)

key = AdvancedPS.state(trace.rng.rng)
seeds = AdvancedPS.split(key, 2)

Random.seed!(trace.rng, seeds[1])
Random.seed!(trace2.rng, seeds[2])

# Inherit `x` across independent particles
AdvancedPS.advance!(trace)
AdvancedPS.advance!(trace2)

println("Parent particle")
println(trace.model.f)
println("Child particle")
println(trace2.model.f)
println("Model with actual sampled values is in ctask.fargs")
println(trace2.model.ctask.fargs[1])

# Create reference particle
# Suppose we select the previous 'child' particle
ref = AdvancedPS.forkr(trace2)
println("Did we keep all the generated values ?")
println(ref.model.f) # If we just copy the tapedtask, we don't get the sampled values in the `Model`
# Note, this is only a problem when creating a reference trajectory, 
# sampled values are properly captured during the execution of the task

@yebai (Member) commented Mar 26, 2025

println(ref.model.f) # If we just copy the tapedtask, we don't get the sampled values in the Model
Note, this is only a problem when creating a reference trajectory,

@FredericWantiez can we store trace.rng inside TapedTask instead of inside trace? That way, when copying a TapedTask, we also copy trace.rng.
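As a hedged sketch, the pattern suggested here might look roughly like the following; the names `ctask`, `taped_globals`, `TapedGlobals`, and `set_taped_globals!` follow usage that appears later in this thread and are assumptions, not a committed API:

```julia
# Sketch only: attach the RNG to the task so that copying the task
# also carries the RNG state along with it.
function refresh_rng!(trace)
    globals = trace.model.ctask.taped_globals
    new_rng = deepcopy(globals.rng)   # RNG state travels with the task
    trace.rng = new_rng               # keep the Trace's view in sync
    Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(new_rng, globals.other))
    return trace
end
```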

@FredericWantiez (Member, Author) commented Apr 5, 2025

Two small issues I found cleaning up the tests.

Libtask returns a value after the last produce statement:

function f()
  Libtask.produce(1)
  Libtask.produce(2)
end

t1 = TapedTask(nothing, f)
consume(t1)  # 1
consume(t1)  # 2
consume(t1)  # 2 (?)

Libtask doesn't catch some of the produce statements:

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
  a::Float64
  b::Float64

  NormalModel() = new()
end

function (m::NormalModel)()
  # First latent variable.
  rng = Libtask.get_dynamic_scope()
  m.a = a = rand(rng, Normal(4, 5))

  # First observation.
  AdvancedPS.observe(Normal(a, 2), 3)

  # Second latent variable.
  rng = Libtask.get_dynamic_scope()
  m.b = b = rand(rng, Normal(a, 1))

  # Second observation.
  AdvancedPS.observe(Normal(b, 2), 1.5)
  return nothing
end

rng = AdvancedPS.TracedRNG()
t = TapedTask(rng, NormalModel())

consume(t) # some float
consume(t) # 0 (?)
consume(t) # 0 (?)

This works fine if I call Libtask.produce explicitly instead of observe.

EDIT: Changing observe to something like this seems to work:

function AdvancedPS.observe(dist::Distributions.Distribution, x)
    Libtask.produce(Distributions.loglikelihood(dist, x))
    return nothing
end

@yebai (Member) commented Apr 8, 2025

If we store both rng and varinfo in the scoped variable, then the following suggestions will address (2):

  • store varinfo in the Trace struct, then change here to Libtask.set_dynamic_scope!(trace.model.ctask, (trace.rng, trace.varinfo))
  • change here and here to rng, varinfo = Libtask.get_dynamic_scope()
  • change here to transition = SMCTransition(model, particle.varinfo, weight)
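Restated as code, the three suggested changes look roughly like this (the exact call sites sit behind the elided "here" links, so the surrounding context is assumed):

```julia
# (1) Store both pieces of per-particle state together in the scope.
Libtask.set_dynamic_scope!(trace.model.ctask, (trace.rng, trace.varinfo))

# (2) Destructure the tuple at every point the model queries the scope.
rng, varinfo = Libtask.get_dynamic_scope()

# (3) Build the transition from the particle's own varinfo.
transition = SMCTransition(model, particle.varinfo, weight)
```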

@FredericWantiez (Member, Author) commented Apr 8, 2025

That should work. I have a branch against Turing that tries to do this, but it seems like one copy is not quite correct.

The other solution is to use one replay step before the transition, to repopulate the varinfo properly:

    new_particle = AdvancedPS.replay(particle)
    transition = SMCTransition(model, new_particle.model.f.varinfo, weight)
    state = SMCState(particles, 2, logevidence)
    return transition, state

@FredericWantiez (Member, Author) commented Apr 8, 2025

@willtebbutt running models against this PR I see a large performance drop:

using Libtask
using AdvancedPS
using Distributions
using Random

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64

    NormalModel() = new()
end

function (m::NormalModel)()
    # First latent variable.
    rng = Libtask.get_dynamic_scope()
    m.a = a = rand(rng, Normal(4, 5))

    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)

    # Second latent variable.
    rng = Libtask.get_dynamic_scope()
    m.b = b = rand(rng, Normal(a, 1))

    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
end

@time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false)

On master:

1.816623 seconds (5.92 M allocations: 311.647 MiB, 1.52% gc time, 96.09% compilation time)

On this PR:

72.085056 seconds (369.62 M allocations: 17.322 GiB, 2.83% gc time, 77.21% compilation time)

@willtebbutt (Member) commented

Thanks for the data point. Essentially, the final item on my to-do list is sorting out various type-inference issues in the current implementation. Once those are fixed, we should see substantially improved performance.

@yebai (Member) commented Apr 9, 2025

That should work, I have a branch against Turing that tries to do this but seems like one copy is not quite correct.

The varinfo variable is updated during inference. I think we have to carefully ensure the correct varinfo is stored in the scoped variable.

cc @mhauru @FredericWantiez

@willtebbutt (Member) commented

@willtebbutt running models against this PR I see a large performance drop:

@FredericWantiez I'm finally looking at sorting out the performance of the Libtask updates. I'm struggling to replicate the performance numbers from your example on the current versions of the packages, because I find that it errors. My environment is:

(jl_4fXu3W) pkg> st
Status `/private/var/folders/z7/0fkyw8ms795b7znc_3vbvrsw0000gn/T/jl_4fXu3W/Project.toml`
  [576499cb] AdvancedPS v0.6.1
  [31c24e10] Distributions v0.25.118
  [6f1fad26] Libtask v0.8.8
  [9a3f8284] Random v1.11.0

I tried it on LTS and 1.11.4.

In particular, I'm seeing the error:

ERROR: BoundsError: attempt to access 0-element Vector{Any} at index [1]
Stacktrace:
  [1] throw_boundserror(A::Vector{Any}, I::Tuple{Int64})
    @ Base ./essentials.jl:14
  [2] getindex
    @ ./essentials.jl:916 [inlined]
  [3] _infer(f::NormalModel, args_type::Tuple{DataType})
    @ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:45
  [4] Libtask.TapedFunction{…}(f::NormalModel, args::AdvancedPS.TracedRNG{…}; cache::Bool, deepcopy_types::Type)
    @ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:72
  [5] TapedFunction
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:62 [inlined]
  [6] _
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
  [7] TapedFunction
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
  [8] #TapedTask#15
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:76 [inlined]
  [9] TapedTask
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:70 [inlined]
 [10] LibtaskModel
    @ ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:27 [inlined]
 [11] AdvancedPS.Trace(::NormalModel, ::AdvancedPS.TracedRNG{UInt64, 1, Random123.Philox2x{UInt64, 10}})
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:49
 [12] (::AdvancedPSLibtaskExt.var"#2#3"{NormalModel, Nothing, Bool, Int64})(i::Int64)
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:140
 [13] iterate
    @ ./generator.jl:48 [inlined]
 [14] _collect(c::UnitRange{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
    @ Base ./array.jl:811
 [15] collect_similar
    @ ./array.jl:720 [inlined]
 [16] map
    @ ./abstractarray.jl:3371 [inlined]
 [17] step(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, state::Nothing; kwargs::@Kwargs{})
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:134
 [18] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:0 [inlined]
 [19] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/logging.jl:16 [inlined]
 [20] mcmcsample(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, num_warmup::Int64, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:142
 [21] mcmcsample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:107 [inlined]
 [22] #sample#20
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:59 [inlined]
 [23] sample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:52 [inlined]
 [24] #sample#19
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:21 [inlined]
 [25] sample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:18 [inlined]
 [26] macro expansion
    @ ./timing.jl:581 [inlined]
 [27] top-level scope
    @ ./REPL[10]:1
Some type information was truncated. Use `show(err)` to see complete types.

Any idea whether I'm doing something wrong?

@willtebbutt (Member) commented

But, additionally, the latest version of the PR should address the various performance issues we previously had. There is one important change, though: you need to pass a type to Libtask.get_dynamic_scope, which should be the type of the thing it's going to return. We need this because there's no way to make the container typed (I assume the previous implementation had a similar limitation). The docstring has been updated to reflect the changes.
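The motivation for the type argument can be seen with a small self-contained analogue: an untyped container infers as `Any`, and a type assertion at the call site restores inference. The toy `Ref{Any}` below stands in for the real taped-globals container (an assumption about its internals, not the actual implementation):

```julia
using Test  # for Test.@inferred, which errors if inference fails

# Toy stand-in for an untyped task-global container.
const SCOPE = Ref{Any}(0.0)

get_untyped() = SCOPE[]                      # inferred as Any
get_typed(::Type{T}) where {T} = SCOPE[]::T  # type-asserted: inferred as T

Test.@inferred get_typed(Float64)   # passes, returns 0.0
# Test.@inferred get_untyped()      # would throw: inferred type is Any
```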

@FredericWantiez (Member, Author) commented Apr 15, 2025

@willtebbutt if you're testing against the released versions of Libtask/AdvancedPS, you need to explicitly pass the RNG in the model definition, something like this:

function (model::Model)(rng::Random.AbstractRNG) # Add the RNG as argument
  model.sig = rand(rng, Beta(1, 1)) 
  Libtask.produce(model.sig)

  model.mu = rand(rng, Normal())
  Libtask.produce(model.mu)
end

@yebai (Member) commented Apr 22, 2025

This now runs faster with AdvancedPS (dc5e594) and Libtask (8e7f784):

# run once to trigger compilation
julia> @time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false);
  2.986750 seconds (7.31 M allocations: 380.449 MiB, 0.88% gc time, 99.51% compilation time)

# second time runs faster
julia> @time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false);
  0.012714 seconds (32.85 k allocations: 18.581 MiB, 19.87% gc time)
Code:
(@temp) pkg> add AdvancedPS#fred/libtask-revamp
(@temp) pkg> add Libtask#wct/refactor

using Libtask
using AdvancedPS
using Distributions
using Random

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64

    NormalModel() = new()
end

function (m::NormalModel)()
    # First latent variable.
    T = AdvancedPS.TracedRNG{UInt64, 1, AdvancedPS.Random123.Philox2x{UInt64, 10}}; 
    rng = Libtask.get_taped_globals(T)
    m.a = a = rand(rng, Normal(4, 5))

    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)

    # Second latent variable.
    rng = Libtask.get_taped_globals(T)
    m.b = b = rand(rng, Normal(a, 1))

    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
end

@time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false)

@yebai yebai changed the title [WIP] New libtask interface New libtask interface May 9, 2025
@yebai (Member) commented May 9, 2025

New Libtask has now been released. CI failures reveal that a few more fixes are required from AdvancedPS.

github-actions bot (Contributor) commented May 9, 2025

AdvancedPS.jl documentation for PR #114 is available at:
https://TuringLang.github.io/AdvancedPS.jl/previews/PR114/

@mhauru (Member) commented Jun 2, 2025

Most of the test failures seemed to be caused by a renamed function (although the opaque error message "Unbound GlobalRef not allowed in value position" that they were yielding concerns me somewhat).

The remaining test failure is about addreference!/current_trace, which I'm a bit confused about. I thought we only wanted to store the RNG in the task-local storage/taped_globals. However, addreference! stores the whole trace object. In Turing.jl this is used to get the RNG from the trace (something we can replace with getting it from taped_globals), but also to get a VarInfo from the trace. See these lines in Turing.jl's mcmc/particle_mcmc.jl:

function trace_local_varinfo_maybe(varinfo)
    try
        trace = AdvancedPS.current_trace()
        return trace.model.f.varinfo
    catch e
        # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
        if e == KeyError(:__trace) || current_task().storage isa Nothing
            return varinfo
        else
            rethrow(e)
        end
    end
end

and

function DynamicPPL.assume(
    rng,
    spl::Sampler{<:Union{PG,SMC}},
    dist::Distribution,
    vn::VarName,
    _vi::AbstractVarInfo,
)
    vi = trace_local_varinfo_maybe(_vi)
    [...]

Why do we need to do this? If we do need to do this, do we have to rework our use of taped_globals to store not only an RNG but also a VarInfo?

@yebai (Member) commented Jun 2, 2025

Briefly, addreference!/current_trace can be safely removed. These are replaced by set_taped_globals!/get_taped_globals.

(A full explanation requires me explaining the history of Libtask...)

EDIT: (rng, trace.model.f.varinfo) can be saved to the TapedTask directly using set_taped_globals!. This replaces the old design of storing them in task-local storage and keeping a reference to the task in each TapedTask (i.e. addreference!).

taped_globals = trace.model.ctask.taped_globals
new_rng = deepcopy(taped_globals.rng)
trace.rng = new_rng
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(new_rng, taped_globals.other))

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(new_rng, taped_globals.other))
Libtask.set_taped_globals!(
    trace.model.ctask, TapedGlobals(new_rng, taped_globals.other)
)

@mhauru (Member) commented Jun 6, 2025

I've just started pushing to your PR, @FredericWantiez; I hope you don't mind. I can make a separate PR if you prefer.

Tests should now pass, but please don't review or merge yet. I don't trust that I've done this right until I see it work with Turing.jl. I'll try to get that done locally now.

codecov bot commented Jun 6, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Please upload report for BASE (main@1ad89ec). Learn more about missing BASE report.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #114   +/-   ##
=======================================
  Coverage        ?   96.47%           
=======================================
  Files           ?        8           
  Lines           ?      426           
  Branches        ?        0           
=======================================
  Hits            ?      411           
  Misses          ?       15           
  Partials        ?        0           

☔ View full report in Codecov by Sentry.

@@ -86,17 +97,18 @@ end
# PG requires keeping all randomness for the reference particle
# Create new task and copy randomness
function AdvancedPS.forkr(trace::LibtaskTrace)
taped_globals = trace.model.ctask.taped_globals
Member

@mhauru can we make a Libtask API for this line, to avoid depending on Libtask.TapedTask internals? One possibility is to define:

Libtask.get_taped_globals!(T::Any, task::TapedTask)

which returns taped_globals for a given task.
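A minimal sketch of such an accessor, assuming a `taped_globals` field on `TapedTask` and reusing the type-assertion convention of `get_taped_globals(T)` from earlier in the thread (both are assumptions, not a committed Libtask API):

```julia
# Sketch only: the field name `taped_globals` and the type-assertion
# convention mirror earlier usage in this thread.
function get_taped_globals(::Type{T}, task::Libtask.TapedTask) where {T}
    return task.taped_globals::T  # assert to T to keep call sites type-stable
end
```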
