diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index d3f7620882..eaa484a13f 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -87,7 +87,7 @@ RunType: TypeAlias = Run HAS_MCB = True except ImportError: - TraceOrBackend = BaseTrace # type: ignore[misc] + TraceOrBackend = BaseTrace # type: ignore[assignment, misc] RunType = type(None) # type: ignore[assignment, misc] diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 6fb80284fd..bd2425c9f5 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -36,8 +36,6 @@ from arviz import InferenceData, dict_to_dataset from arviz.data.base import make_attrs from pytensor.graph.basic import Variable -from rich.console import Console -from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.theme import Theme from threadpoolctl import threadpool_limits from typing_extensions import Protocol @@ -67,7 +65,8 @@ from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared from pymc.step_methods.hmc import quadpotential from pymc.util import ( - CustomProgress, + ProgressBarManager, + ProgressBarType, RandomSeed, RandomState, _get_seeds_per_chain, @@ -278,7 +277,7 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None: else: varnames = ", ".join( [ - get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name + get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name # type: ignore[misc] for v in s.vars ] ) @@ -424,7 +423,7 @@ def sample( chains: int | None = None, cores: int | None = None, random_seed: RandomState = None, - progressbar: bool = True, + progressbar: bool | ProgressBarType = True, progressbar_theme: Theme | None = default_progress_theme, step=None, var_names: Sequence[str] | None = None, @@ -456,7 +455,7 @@ def sample( chains: int | None = None, cores: int | None = None, random_seed: RandomState = None, - progressbar: bool = True, + progressbar: bool | ProgressBarType = True, progressbar_theme: Theme | None = default_progress_theme, step=None, var_names: Sequence[str] | None = None, @@ -488,8 +487,8 @@ def sample( chains: int | None = None, cores: int | None = None, random_seed: RandomState = None, - progressbar: bool = True, - progressbar_theme: Theme | None = default_progress_theme, + progressbar: bool | ProgressBarType = True, + progressbar_theme: Theme | None = None, step=None, var_names: Sequence[str] | None = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", @@ -539,11 +538,18 @@ def sample( A ``TypeError`` will be raised if a legacy :py:class:`~numpy.random.RandomState` object is passed. We no longer support ``RandomState`` objects because their seeding mechanism does not allow easy spawning of new independent random streams that are needed by the step methods. - progressbar : bool, optional default=True - Whether or not to display a progress bar in the command line. The bar shows the percentage - of completion, the sampling speed in samples per second (SPS), and the estimated remaining - time until completion ("expected time of arrival"; ETA). - Only applicable to the pymc nuts sampler. + progressbar: bool or ProgressType, optional + How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask + for one of the following: + - "combined": A single progress bar that displays the total progress across all chains. Only timing + information is shown. + - "split": A separate progress bar for each chain. Only timing information is shown. + - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all + chains. Aggregate sample statistics are also displayed. + - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain + are also displayed. + + If True, the default is "split+stats" is used. step : function or iterable of functions A step function or collection of functions. If there are variables without step methods, step methods for those variables will be assigned automatically. By default the NUTS step @@ -709,6 +715,10 @@ def sample( if isinstance(trace, list): raise ValueError("Please use `var_names` keyword argument for partial traces.") + # progressbar might be a string, which is used by the ProgressManager in the pymc samplers. External samplers and + # ADVI initialization expect just a bool. + progress_bool = bool(progressbar) + model = modelcontext(model) if not model.free_RVs: raise SamplingError( @@ -805,7 +815,7 @@ def joined_blas_limiter(): initvals=initvals, model=model, var_names=var_names, - progressbar=progressbar, + progressbar=progress_bool, idata_kwargs=idata_kwargs, compute_convergence_checks=compute_convergence_checks, nuts_sampler_kwargs=nuts_sampler_kwargs, @@ -824,7 +834,7 @@ def joined_blas_limiter(): n_init=n_init, model=model, random_seed=random_seed_list, - progressbar=progressbar, + progressbar=progress_bool, jitter_max_retries=jitter_max_retries, tune=tune, initvals=initvals, @@ -1138,25 +1148,35 @@ def _sample_many( Step function """ initial_step_state = step.sampling_state - for i in range(chains): - step.sampling_state = initial_step_state - _sample( - draws=draws, - chain=i, - start=start[i], - step=step, - trace=traces[i], - rng=rngs[i], - callback=callback, - **kwargs, - ) + progress_manager = ProgressBarManager( + step_method=step, + chains=chains, + draws=draws - kwargs.get("tune", 0), + tune=kwargs.get("tune", 0), + progressbar=kwargs.get("progressbar", True), + progressbar_theme=kwargs.get("progressbar_theme", default_progress_theme), + ) + + with progress_manager: + for i in range(chains): + step.sampling_state = initial_step_state + _sample( + draws=draws, + chain=i, + start=start[i], + step=step, + trace=traces[i], + rng=rngs[i], + callback=callback, + progress_manager=progress_manager, + **kwargs, + ) return def _sample( *, chain: int, - progressbar: bool, rng: np.random.Generator, start: PointType, draws: int, @@ -1164,8 +1184,8 @@ def _sample( trace: IBaseTrace, tune: int, model: Model | None = None, - progressbar_theme: Theme | None = default_progress_theme, callback=None, + progress_manager: ProgressBarManager, **kwargs, ) -> None: """Sample one chain (singleprocess). @@ -1176,27 +1196,23 @@ def _sample( ---------- chain : int Number of the chain that the samples will belong to. - progressbar : bool - Whether or not to display a progress bar in the command line. The bar shows the percentage - of completion, the sampling speed in samples per second (SPS), and the estimated remaining - time until completion ("expected time of arrival"; ETA). - random_seed : single random seed + random_seed : Generator + Single random seed start : dict Starting point in parameter space (or partial point) draws : int The number of samples to draw - step : function - Step function + step : Step + Step class instance used to generate samples. trace A chain backend to record draws and stats. tune : int Number of iterations to tune. - model : Model (optional if in ``with`` context) - progressbar_theme : Theme - Optional custom theme for the progress bar. + model : Model, optional + PyMC model. If None, the model is taken from the current context. + progress_manager: ProgressBarManager + Helper class used to handle progress bar styling and updates """ - skip_first = kwargs.get("skip_first", 0) - sampling_gen = _iter_sample( draws=draws, step=step, @@ -1208,32 +1224,19 @@ def _sample( rng=rng, callback=callback, ) - _pbar_data = {"chain": chain, "divergences": 0} - _desc = "Sampling chain {chain:d}, {divergences:,d} divergences" - - progress = CustomProgress( - "[progress.description]{task.description}", - BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - TimeRemainingColumn(), - TextColumn("/"), - TimeElapsedColumn(), - console=Console(theme=progressbar_theme), - disable=not progressbar, - ) + try: + for it, stats in enumerate(sampling_gen): + progress_manager.update( + chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune + ) - with progress: - try: - task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws) - for it, diverging in enumerate(sampling_gen): - if it >= skip_first and diverging: - _pbar_data["divergences"] += 1 - progress.update(task, description=_desc.format(**_pbar_data), completed=it) - progress.update( - task, description=_desc.format(**_pbar_data), completed=draws, refresh=True + if not progress_manager.combined_progress or chain == progress_manager.chains - 1: + progress_manager.update( + chain_idx=chain, is_last=True, draw=it, stats=stats, tuning=False ) - except KeyboardInterrupt: - pass + + except KeyboardInterrupt: + pass def _iter_sample( @@ -1247,7 +1250,7 @@ def _iter_sample( rng: np.random.Generator, model: Model | None = None, callback: SamplingIteratorCallback | None = None, -) -> Iterator[bool]: +) -> Iterator[list[dict[str, Any]]]: """Sample one chain with a generator (singleprocess). Parameters @@ -1270,8 +1273,8 @@ def _iter_sample( Yields ------ - diverging : bool - Indicates if the draw is divergent. Only available with some samplers. + stats : list of dict + Dictionary of statistics returned by step sampler """ draws = int(draws) @@ -1293,22 +1296,25 @@ def _iter_sample( step.iter_count = 0 if i == tune: step.stop_tuning() + point, stats = step.step(point) trace.record(point, stats) log_warning_stats(stats) - diverging = i > tune and len(stats) > 0 and (stats[0].get("diverging") is True) + if callback is not None: callback( trace=trace, draw=Draw(chain, i == draws, i, i < tune, stats, point), ) - yield diverging + yield stats + except (KeyboardInterrupt, BaseException): if isinstance(trace, ZarrChain): trace.record_sampling_state(step=step) trace.close() raise + else: if isinstance(trace, ZarrChain): trace.record_sampling_state(step=step) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 3c2a8c9a36..af2106ce6f 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -27,8 +27,6 @@ import cloudpickle import numpy as np -from rich.console import Console -from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.theme import Theme from threadpoolctl import threadpool_limits @@ -36,7 +34,7 @@ from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.util import ( - CustomProgress, + ProgressBarManager, RandomGeneratorState, default_progress_theme, get_state_from_generator, @@ -485,23 +483,14 @@ def __init__( self._max_active = cores self._in_context = False - - self._progress = CustomProgress( - "[progress.description]{task.description}", - BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - TimeRemainingColumn(), - TextColumn("/"), - TimeElapsedColumn(), - console=Console(theme=progressbar_theme), - disable=not progressbar, + self._progress = ProgressBarManager( + step_method=step_method, + chains=chains, + draws=draws, + tune=tune, + progressbar=progressbar, + progressbar_theme=progressbar_theme, ) - self._show_progress = progressbar - self._divergences = 0 - self._completed_draws = 0 - self._total_draws = chains * (draws + tune) - self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences" - self._chains = chains def _make_active(self): while self._inactive and len(self._active) < self._max_active: @@ -516,24 +505,13 @@ def __iter__(self): raise ValueError("Use ParallelSampler as context manager.") self._make_active() - with self._progress as progress: - task = progress.add_task( - self._desc.format(self), - completed=self._completed_draws, - total=self._total_draws, - ) - + with self._progress: while self._active: draw = ProcessAdapter.recv_draw(self._active) proc, is_last, draw, tuning, stats = draw - self._completed_draws += 1 - if not tuning and stats and stats[0].get("diverging"): - self._divergences += 1 - progress.update( - task, - completed=self._completed_draws, - total=self._total_draws, - description=self._desc.format(self), + + self._progress.update( + chain_idx=proc.chain, is_last=is_last, draw=draw, tuning=tuning, stats=stats ) if is_last: @@ -541,7 +519,6 @@ def __iter__(self): self._active.remove(proc) self._finished.append(proc) self._make_active() - progress.update(task, description=self._desc.format(self), refresh=True) # We could also yield proc.shared_point_view directly, # and only call proc.write_next() after the yield returns. diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index ff3f9c66a5..d07b070f0f 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -181,6 +181,20 @@ def __new__(cls, *args, **kwargs): step.__newargs = (vars, *args), kwargs return step + @staticmethod + def _progressbar_config(n_chains=1): + columns = [] + stats = {} + + return columns, stats + + @staticmethod + def _make_update_stats_function(): + def update_stats(stats, step_stats, chain_idx): + return stats + + return update_stats + # Hack for creating the class correctly when unpickling. def __getnewargs_ex__(self): return self.__newargs @@ -297,6 +311,38 @@ def set_rng(self, rng: RandomGenerator): for method, _rng in zip(self.methods, _rngs): method.set_rng(_rng) + def _progressbar_config(self, n_chains=1): + from functools import reduce + + column_lists, stat_dict_list = zip( + *[method._progressbar_config(n_chains) for method in self.methods] + ) + flat_list = reduce(lambda left_list, right_list: left_list + right_list, column_lists) + + columns = [] + headers = [] + + for col in flat_list: + name = col.get_table_column().header + if name not in headers: + headers.append(name) + columns.append(col) + + stats = reduce(lambda left_dict, right_dict: left_dict | right_dict, stat_dict_list) + + return columns, stats + + def _make_update_stats_function(self): + update_fns = [method._make_update_stats_function() for method in self.methods] + + def update_stats(stats, step_stats, chain_idx): + for step_stat, update_fn in zip(step_stats, update_fns): + stats = update_fn(stats, step_stat, chain_idx) + + return stats + + return update_stats + def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]: """Flatten a hierarchy of step methods to a list.""" diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index bbda728e80..18707c3592 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -20,6 +20,8 @@ import numpy as np from pytensor import config +from rich.progress import TextColumn +from rich.table import Column from pymc.stats.convergence import SamplerWarning from pymc.step_methods.compound import Competence @@ -229,6 +231,37 @@ def competence(var, has_grad): return Competence.PREFERRED return Competence.INCOMPATIBLE + @staticmethod + def _progressbar_config(n_chains=1): + columns = [ + TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)), + TextColumn("{task.fields[step_size]:0.2f}", table_column=Column("Step size", ratio=1)), + TextColumn("{task.fields[tree_size]}", table_column=Column("Grad evals", ratio=1)), + ] + + stats = { + "divergences": [0] * n_chains, + "step_size": [0] * n_chains, + "tree_size": [0] * n_chains, + } + + return columns, stats + + @staticmethod + def _make_update_stats_function(): + def update_stats(stats, step_stats, chain_idx): + if isinstance(step_stats, list): + step_stats = step_stats[0] + + if not step_stats["tune"]: + stats["divergences"][chain_idx] += step_stats["diverging"] + + stats["step_size"][chain_idx] = step_stats["step_size"] + stats["tree_size"][chain_idx] = step_stats["tree_size"] + return stats + + return update_stats + # A proposal for the next position Proposal = namedtuple("Proposal", "q, q_grad, energy, logp, index_in_trajectory") diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 8e22218a13..70c650653d 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -24,6 +24,8 @@ from pytensor import tensor as pt from pytensor.graph.fg import MissingInputError from pytensor.tensor.random.basic import BernoulliRV, CategoricalRV +from rich.progress import TextColumn +from rich.table import Column import pymc as pm @@ -325,6 +327,38 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: def competence(var, has_grad): return Competence.COMPATIBLE + @staticmethod + def _progressbar_config(n_chains=1): + columns = [ + TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), + TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)), + TextColumn( + "{task.fields[accept_rate]:0.2f}", table_column=Column("Accept Rate", ratio=1) + ), + ] + + stats = { + "tune": [True] * n_chains, + "scaling": [0] * n_chains, + "accept_rate": [0.0] * n_chains, + } + + return columns, stats + + @staticmethod + def _make_update_stats_function(): + def update_stats(stats, step_stats, chain_idx): + if isinstance(step_stats, list): + step_stats = step_stats[0] + + stats["tune"][chain_idx] = step_stats["tune"] + stats["accept_rate"][chain_idx] = step_stats["accept"] + stats["scaling"][chain_idx] = step_stats["scaling"] + + return stats + + return update_stats + def tune(scale, acc_rate): """ diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index ecc7967614..9c10acfdf4 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -17,6 +17,9 @@ import numpy as np +from rich.progress import TextColumn +from rich.table import Column + from pymc.blocking import RaveledVars, StatsType from pymc.initial_point import PointType from pymc.model import modelcontext @@ -195,3 +198,29 @@ def competence(var, has_grad): return Competence.PREFERRED return Competence.COMPATIBLE return Competence.INCOMPATIBLE + + @staticmethod + def _progressbar_config(n_chains=1): + columns = [ + TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), + TextColumn("{task.fields[nstep_out]}", table_column=Column("Steps out", ratio=1)), + TextColumn("{task.fields[nstep_in]}", table_column=Column("Steps in", ratio=1)), + ] + + stats = {"tune": [True] * n_chains, "nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains} + + return columns, stats + + @staticmethod + def _make_update_stats_function(): + def update_stats(stats, step_stats, chain_idx): + if isinstance(step_stats, list): + step_stats = step_stats[0] + + stats["tune"][chain_idx] = step_stats["tune"] + stats["nstep_out"][chain_idx] = step_stats["nstep_out"] + stats["nstep_in"][chain_idx] = step_stats["nstep_in"] + + return stats + + return update_stats diff --git a/pymc/util.py b/pymc/util.py index 8dc7d16804..979b3beebf 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -17,9 +17,9 @@ import warnings from collections import namedtuple -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from copy import deepcopy -from typing import NewType, cast +from typing import TYPE_CHECKING, Literal, NewType, cast import arviz import cloudpickle @@ -30,11 +30,35 @@ from pytensor import Variable from pytensor.compile import SharedVariable from pytensor.graph.utils import ValidatingScratchpad -from rich.progress import Progress +from rich.box import SIMPLE_HEAD +from rich.console import Console +from rich.progress import ( + BarColumn, + Progress, + Task, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) +from rich.style import Style +from rich.table import Column, Table from rich.theme import Theme from pymc.exceptions import BlockModelAccessError +if TYPE_CHECKING: + from pymc.step_methods.compound import BlockedStep, CompoundStep + + +ProgressBarType = Literal[ + "combined", + "split", + "combined+stats", + "stats+combined", + "split+stats", + "stats+split", +] + def __getattr__(name): if name == "dataset_to_point_list": @@ -55,6 +79,8 @@ def __getattr__(name): { "bar.complete": "#1764f4", "bar.finished": "green", + "progress.remaining": "none", + "progress.elapsed": "none", } ) @@ -556,8 +582,10 @@ class CustomProgress(Progress): it's `True`. """ - def __init__(self, *args, **kwargs): - self.is_enabled = kwargs.get("disable", None) is not True + def __init__(self, *args, disable=False, include_headers=False, **kwargs): + self.is_enabled = not disable + self.include_headers = include_headers + if self.is_enabled: super().__init__(*args, **kwargs) @@ -607,6 +635,318 @@ def update( ) return None + def make_tasks_table(self, tasks: Iterable[Task]) -> Table: + """Get a table to render the Progress display. + + Unlike the parent method, this one returns a full table (not a grid), allowing for column headings. + + Parameters + ---------- + tasks: Iterable[Task] + An iterable of Task instances, one per row of the table. + + Returns + ------- + table: Table + A table instance. + """ + + def call_column(column, task): + # Subclass rich.BarColumn and add a callback method to dynamically update the display + if hasattr(column, "callbacks"): + column.callbacks(task) + + return column(task) + + table_columns = ( + ( + Column(no_wrap=True) + if isinstance(_column, str) + else _column.get_table_column().copy() + ) + for _column in self.columns + ) + if self.include_headers: + table = Table( + *table_columns, + padding=(0, 1), + expand=self.expand, + show_header=True, + show_edge=True, + box=SIMPLE_HEAD, + ) + else: + table = Table.grid(*table_columns, padding=(0, 1), expand=self.expand) + + for task in tasks: + if task.visible: + table.add_row( + *( + ( + column.format(task=task) + if isinstance(column, str) + else call_column(column, task) + ) + for column in self.columns + ) + ) + + return table + + +class DivergenceBarColumn(BarColumn): + """Rich colorbar that changes color when a chain has detected a divergence.""" + + def __init__(self, *args, diverging_color="red", **kwargs): + from matplotlib.colors import to_rgb + + self.diverging_color = diverging_color + self.diverging_rgb = [int(x * 255) for x in to_rgb(self.diverging_color)] + + super().__init__(*args, **kwargs) + + self.non_diverging_style = self.complete_style + self.non_diverging_finished_style = self.finished_style + + def callbacks(self, task: "Task"): + divergences = task.fields.get("divergences", 0) + if isinstance(divergences, float | int) and divergences > 0: + self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) + self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) + else: + self.complete_style = self.non_diverging_style + self.finished_style = self.non_diverging_finished_style + + +class ProgressBarManager: + """Manage progress bars displayed during sampling.""" + + def __init__( + self, + step_method: "BlockedStep | CompoundStep", + chains: int, + draws: int, + tune: int, + progressbar: bool | ProgressBarType = True, + progressbar_theme: Theme | None = None, + ): + """ + Manage progress bars displayed during sampling. + + When sampling, Step classes are responsible for computing and exposing statistics that can be reported on + progress bars. Each Step implements two class methods: :meth:`pymc.step_methods.BlockedStep._progressbar_config` + and :meth:`pymc.step_methods.BlockedStep._make_update_stats_function`. `_progressbar_config` reports which + columns should be displayed on the progress bar, and `_make_update_stats_function` computes the statistics + that will be displayed on the progress bar. + + Parameters + ---------- + step_method: BlockedStep or CompoundStep + The step method being used to sample + chains: int + Number of chains being sampled + draws: int + Number of draws per chain + tune: int + Number of tuning steps per chain + progressbar: bool or ProgressType, optional + How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask + for one of the following: + - "combined": A single progress bar that displays the total progress across all chains. Only timing + information is shown. + - "split": A separate progress bar for each chain. Only timing information is shown. + - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all + chains. Aggregate sample statistics are also displayed. + - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain + are also displayed. + + If True, the default is "split+stats" is used. + + progressbar_theme: Theme, optional + The theme to use for the progress bar. Defaults to the default theme. + """ + if progressbar_theme is None: + progressbar_theme = default_progress_theme + + match progressbar: + case True: + self.combined_progress = False + self.full_stats = True + show_progress = True + case False: + self.combined_progress = False + self.full_stats = True + show_progress = False + case "combined": + self.combined_progress = True + self.full_stats = False + show_progress = True + case "split": + self.combined_progress = False + self.full_stats = False + show_progress = True + case "combined+stats" | "stats+combined": + self.combined_progress = True + self.full_stats = True + show_progress = True + case "split+stats" | "stats+split": + self.combined_progress = False + self.full_stats = True + show_progress = True + case _: + raise ValueError( + "Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), " + "one of 'combined', 'split', 'split+stats', or 'combined+stats." + ) + + progress_columns, progress_stats = step_method._progressbar_config(chains) + + self._progress = self.create_progress_bar( + progress_columns, + progressbar=progressbar, + progressbar_theme=progressbar_theme, + ) + + self.progress_stats = progress_stats + self.update_stats = step_method._make_update_stats_function() + + self._show_progress = show_progress + self.divergences = 0 + self.completed_draws = 0 + self.total_draws = draws + tune + self.desc = "Sampling chain" + self.chains = chains + + self._tasks: list[Task] | None = None # type: ignore[annotation-unchecked] + + def __enter__(self): + self._initialize_tasks() + + return self._progress.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return self._progress.__exit__(exc_type, exc_val, exc_tb) + + def _initialize_tasks(self): + if self.combined_progress: + self.tasks = [ + self._progress.add_task( + self.desc.format(self), + completed=0, + draws=0, + total=self.total_draws * self.chains - 1, + chain_idx=0, + sampling_speed=0, + speed_unit="draws/s", + **{stat: value[0] for stat, value in self.progress_stats.items()}, + ) + ] + + else: + self.tasks = [ + self._progress.add_task( + self.desc.format(self), + completed=0, + draws=0, + total=self.total_draws - 1, + chain_idx=chain_idx, + sampling_speed=0, + speed_unit="draws/s", + **{stat: value[chain_idx] for stat, value in self.progress_stats.items()}, + ) + for chain_idx in range(self.chains) + ] + + def compute_draw_speed(self, chain_idx, draws): + elapsed = self._progress.tasks[chain_idx].elapsed + speed = draws / max(elapsed, 1e-6) + + if speed > 1 or speed == 0: + unit = "draws/s" + else: + unit = "s/draws" + speed = 1 / speed + + return speed, unit + + def update(self, chain_idx, is_last, draw, tuning, stats): + if not self._show_progress: + return + + self.completed_draws += 1 + if self.combined_progress: + draw = self.completed_draws + chain_idx = 0 + + speed, unit = self.compute_draw_speed(chain_idx, draw) + + if not tuning and stats and stats[0].get("diverging"): + self.divergences += 1 + + self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx) + more_updates = ( + {stat: value[chain_idx] for stat, value in self.progress_stats.items()} + if self.full_stats + else {} + ) + + self._progress.update( + self.tasks[chain_idx], + completed=draw, + draws=draw, + sampling_speed=speed, + speed_unit=unit, + **more_updates, + ) + + if is_last: + self._progress.update( + self.tasks[chain_idx], + draws=draw + 1 if not self.combined_progress else draw, + **more_updates, + refresh=True, + ) + + def create_progress_bar(self, step_columns, progressbar, progressbar_theme): + columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] + + if self.full_stats: + columns += step_columns + + columns += [ + TextColumn( + "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", + table_column=Column("Sampling Speed", ratio=1), + ), + TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), + TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), + ] + + return CustomProgress( + DivergenceBarColumn( + table_column=Column("Progress", ratio=2), + diverging_color="tab:red", + complete_style=Style.parse("rgb(31,119,180)"), # tab:blue + finished_style=Style.parse("rgb(31,119,180)"), # tab:blue + ), + *columns, + console=Console(theme=progressbar_theme), + disable=not progressbar, + include_headers=True, + ) + + +def compute_draw_speed(elapsed, draws): + speed = draws / max(elapsed, 1e-6) + + if speed > 1 or speed == 0: + unit = "draws/s" + else: + unit = "s/draws" + speed = 1 / speed + + return speed, unit + RandomGeneratorState = namedtuple("RandomGeneratorState", ["bit_generator_state", "seed_seq_state"])