Summary
pm.sample(callback=...) currently raises a ValueError when using nuts_sampler="nutpie" (and other external NUTS samplers). We already pass a callback to nutpie for the progress bar (progress_callback=pb_manager.update), so we could multiplex the user's callback alongside it.
This would let users hook into the sampling loop for custom logging, early stopping, or live monitoring even when using nutpie.
Proposed approach
- Remove the ValueError for nutpie when callback is provided.
- Wrap both pb_manager.update and the user callback into a single function passed as progress_callback to nutpie.sample.
- When progressbar=False, the progress bar update is already a no-op, so only the user callback fires.
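A minimal sketch of the wrapping step above. combine_progress_callbacks is a hypothetical helper name, and pb_update stands in for pb_manager.update. The only assumption about nutpie is the one this issue already relies on: progress_callback is called with a single argument, the list of per-chain progress objects.

```python
from typing import Any, Callable, Optional, Sequence

# Signature assumed for nutpie's progress_callback: one positional argument.
ProgressCallback = Callable[[Sequence[Any]], None]


def combine_progress_callbacks(
    pb_update: ProgressCallback,
    user_callback: Optional[ProgressCallback] = None,
) -> ProgressCallback:
    """Return one callback that forwards each progress update to both targets."""
    if user_callback is None:
        # No user callback: keep the current code path unchanged.
        return pb_update

    def combined(progress: Sequence[Any]) -> None:
        # Update the progress bar first so an exception in the user
        # callback does not leave the display stale.
        pb_update(progress)
        user_callback(progress)

    return combined
```

Inside pm.sample this would be passed as progress_callback=combine_progress_callbacks(pb_manager.update, callback). How exceptions raised by the user callback on nutpie's background thread surface in the main thread is not addressed by this sketch.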
Limitations from nutpie
nutpie's progress_callback receives a list[nutpie.ChainProgress] — not the (trace, Draw) pair that PyMC's callback protocol expects. ChainProgress exposes only aggregate stats per chain:
- finished_draws, total_draws
- divergences, divergent_draws
- tuning, started
- latest_num_steps, total_num_steps
- step_size
- runtime_ms
Notably, there is no point (parameter values) or trace access. The callback also fires periodically on a background thread (at progress_rate ms intervals), not once per draw.
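To illustrate what a nutpie-side user callback could do with only aggregate stats, here is a sketch of a live divergence monitor. ChainProgressStub is a hypothetical stand-in with a subset of the fields listed above, so the example runs without nutpie; it assumes divergences is a per-chain count, and the real ChainProgress may expose these values differently. The lock is there because the callback fires on a background thread, so state read from the main thread should be guarded.

```python
import threading
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class ChainProgressStub:
    # Hypothetical stand-in for nutpie.ChainProgress (not the real class),
    # using a subset of the fields listed in this issue.
    finished_draws: int
    total_draws: int
    divergences: int


class DivergenceMonitor:
    """Records total divergences across chains on each progress update."""

    def __init__(self, max_divergences: int) -> None:
        self.max_divergences = max_divergences
        self._lock = threading.Lock()  # the callback runs on a background thread
        self.history: List[int] = []
        self.exceeded = False

    def __call__(self, progress: Sequence[ChainProgressStub]) -> None:
        total_div = sum(chain.divergences for chain in progress)
        with self._lock:
            self.history.append(total_div)
            if total_div > self.max_divergences:
                # Only a flag: nothing here stops the sampler.
                self.exceeded = True

    def snapshot(self) -> List[int]:
        # Copy under the lock so the main thread sees a consistent list.
        with self._lock:
            return list(self.history)


def fraction_done(progress: Sequence[ChainProgressStub]) -> float:
    """Fraction of all requested draws finished, summed over chains."""
    total = sum(chain.total_draws for chain in progress)
    finished = sum(chain.finished_draws for chain in progress)
    return finished / total if total else 0.0
```

The monitor only records a flag. Whether a nutpie progress callback can abort sampling is not established here, so early stopping would likely need a separate mechanism.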
This means the user callback for nutpie would have a different signature and semantics than the PyMC sampler callback. We'd need to either:
- Accept this mismatch and document it, or
- Define a separate callback protocol for external samplers
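If the second option is chosen, one possible shape is an explicit wrapper type, so a callback written for PyMC's (trace, Draw) protocol is never silently handed a progress list. ExternalSamplerCallback and resolve_callback are hypothetical names, not existing PyMC API; the sampler strings are the nuts_sampler values pm.sample accepts.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Sequence


@dataclass(frozen=True)
class ExternalSamplerCallback:
    """Hypothetical marker type: wraps a function that receives the list of
    per-chain progress objects an external sampler reports periodically."""

    fn: Callable[[Sequence[Any]], None]

    def __call__(self, progress: Sequence[Any]) -> None:
        self.fn(progress)


def resolve_callback(
    nuts_sampler: str, callback: Optional[Callable[..., None]]
) -> Optional[Callable[..., None]]:
    """Check that the callback's protocol matches the chosen sampler."""
    if callback is None:
        return None
    if nuts_sampler == "pymc":
        # PyMC's own sampler keeps the existing (trace, draw) protocol.
        if isinstance(callback, ExternalSamplerCallback):
            raise TypeError("ExternalSamplerCallback is only for external samplers")
        return callback
    if nuts_sampler == "nutpie":
        # Require an explicit opt-in so a (trace, draw) callback is not
        # called with a progress list by mistake.
        if not isinstance(callback, ExternalSamplerCallback):
            raise TypeError("wrap the callback in ExternalSamplerCallback for nutpie")
        return callback
    # numpyro / blackjax: no progress_callback hook to attach to yet.
    raise ValueError(f"callback is not supported with nuts_sampler={nuts_sampler!r}")
```

The first option corresponds to dropping the isinstance check for nutpie and documenting that the callback receives a progress list instead of (trace, Draw).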
Related
The same ValueError is raised for numpyro and blackjax (#7426 tracks progress bar issues with numpyro). Those backends don't have a progress_callback mechanism like nutpie's, so supporting user callbacks there would require a different approach.