
Add return type overload for sample_posterior_predictive #7710


Merged
30 changes: 30 additions & 0 deletions pymc/sampling/forward.py
@@ -493,6 +493,36 @@ def sample_prior_predictive(
    return pm.to_inference_data(prior=prior, **ikwargs)


@overload
def sample_posterior_predictive(
    trace,
    model: Model | None = None,
    var_names: list[str] | None = None,
    sample_dims: list[str] | None = None,
    random_seed: RandomState = None,
    progressbar: bool = True,
    progressbar_theme: Theme | None = default_progress_theme,
    return_inferencedata: Literal[True] = True,
    extend_inferencedata: bool = False,
    predictions: bool = False,
    idata_kwargs: dict | None = None,
    compile_kwargs: dict | None = None,
) -> InferenceData: ...


@overload
def sample_posterior_predictive(
    trace,
    model: Model | None = None,
    var_names: list[str] | None = None,
    sample_dims: list[str] | None = None,
    random_seed: RandomState = None,
    progressbar: bool = True,
    progressbar_theme: Theme | None = default_progress_theme,
    return_inferencedata: Literal[False] = False,
    extend_inferencedata: bool = False,
    predictions: bool = False,
    idata_kwargs: dict | None = None,
    compile_kwargs: dict | None = None,
) -> dict[str, np.ndarray]: ...
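For readers unfamiliar with the pattern: pairing typing.overload signatures whose only difference is a Literal[True]/Literal[False] flag is the standard way to make a function's return type depend on an argument's value. A minimal, self-contained sketch of the same idea, using a hypothetical function f rather than PyMC's actual code:

from typing import Literal, overload


@overload
def f(flag: Literal[True] = True) -> str: ...
@overload
def f(flag: Literal[False] = False) -> bytes: ...
def f(flag: bool = True) -> str | bytes:
    # Runtime implementation; the overloads above are typing-only.
    # A checker matches calls against the overloads in order, so f() and
    # f(flag=True) resolve to str, while f(flag=False) resolves to bytes.
    return "text" if flag else b"raw"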
Member:
Do we have a MultiTrace type alias?

@nataziel (Contributor, Author), Mar 4, 2025:
We could import the class from backends/base. Is the return type actually a MultiTrace? Or do you want a local alias MultiTrace = dict[str, np.ndarray]?

Member:
I thought it was a MultiTrace, but maybe it's a dict indeed. Can you double-check?

@nataziel (Contributor, Author), Mar 7, 2025:
It's definitely dict[str, np.ndarray]. I can confirm with this script:

import pymc as pm
import numpy as np


def main():
    data = [5.1, 5.2, 4.9, 4.8]
    with pm.Model() as model:
        target_value = pm.Data(name="target_y", value=data, dims="x")
        a = pm.Normal("a")

        y_hat = pm.Deterministic("y_hat", var=a + 5)

        y_like = pm.Normal("y_like", mu=y_hat, observed=target_value)

        my_model = model

    fit_trace = pm.sample(model=my_model, tune=10, draws=10, chains=4)

    print(fit_trace)
    print(type(fit_trace))

    y_posterior_trace = pm.sample_posterior_predictive(
        model=my_model, trace=fit_trace, return_inferencedata=False
    )

    print(y_posterior_trace)
    print(type(y_posterior_trace))


if __name__ == "__main__":
    main()
Running it prints:

Only 10 samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a]

  Progress                                   Draws   Divergences   Step size   Grad evals   Sampling Speed    Elapsed   Remaining
 ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   20      0             0.84        3            1824.91 draws/s   0:00:00   0:00:00
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   20      0             0.80        3            51.02 draws/s     0:00:00   0:00:00
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   20      0             0.28        1            50.07 draws/s     0:00:00   0:00:00
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   20      0             0.15        7            14.61 draws/s     0:00:01   0:00:00

Sampling 4 chains for 10 tune and 10 draw iterations (40 + 40 draws total) took 9 seconds.
The number of samples is too small to check convergence reliably.
Inference data with groups:
        > posterior
        > sample_stats
        > observed_data
        > constant_data
<class 'arviz.data.inference_data.InferenceData'>
Sampling: [y_like]
Sampling ... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:00:00
{'y_like': array([[[5.47525026, 6.7561118 , 6.75108787, 6.44631066],
        [3.08866384, 4.44911318, 2.11949771, 3.58852593],
        [5.78696067, 4.4653521 , 4.18029713, 4.25236639],
        [5.60767872, 4.8688315 , 3.44183177, 5.38761565],
        [4.84709411, 4.66473999, 5.96239288, 5.14296245],
        [4.45261563, 5.01725227, 6.03797216, 4.10837762],
        [6.46749518, 4.55590472, 6.02580747, 4.45176423],
        [3.97070401, 4.67813175, 4.55127384, 5.58474455],
        [5.23346119, 5.09860006, 5.40714872, 3.155104  ],
        [5.01437042, 4.73579148, 4.69969069, 4.73011345]],

       [[6.95786381, 5.44863886, 4.67877711, 6.07620478],
        [6.1479291 , 3.63454873, 4.26082608, 5.05364968],
        [4.94426014, 3.2264388 , 3.49036617, 3.51758425],
        [4.46372822, 4.97982756, 4.71369595, 4.28042535],
        [3.67790877, 4.61166178, 4.859494  , 3.90743623],
        [5.30186549, 5.51060686, 5.72649511, 5.44831013],
        [5.6824159 , 5.00966824, 5.81942202, 7.10113269],
        [6.69604693, 6.22185714, 4.66787917, 6.93183407],
        [2.69863789, 3.15122392, 3.93177678, 5.56284008],
        [5.42486489, 5.06666397, 5.31683066, 3.38231024]],

       [[7.55508934, 6.12556656, 5.7270704 , 7.26913077],
        [3.99491984, 4.29239014, 5.54759873, 4.23275301],
        [2.66913545, 4.87218678, 5.62928131, 5.30816862],
        [4.82329652, 5.23746588, 4.49146018, 5.63991739],
        [4.92370427, 5.95452363, 6.03808899, 5.94568277],
        [5.10762579, 7.12951932, 5.75341415, 5.84696291],
        [3.37979901, 3.62339437, 4.65771825, 3.7923768 ],
        [3.62867257, 2.94668575, 3.61802996, 4.1612957 ],
        [3.05870501, 4.31849566, 4.2617358 , 5.47176507],
        [4.20927715, 3.17447563, 3.20495462, 3.75464106]],

       [[5.35709662, 6.54309611, 3.40905019, 4.99514321],
        [4.92966164, 4.4250744 , 3.70201706, 4.31879397],
        [2.78342991, 4.88726609, 6.05864548, 4.551451  ],
        [3.37619167, 4.62058509, 5.80385571, 3.6307474 ],
        [3.92182737, 5.22549831, 5.04820779, 4.96424299],
        [5.59346296, 5.64233644, 6.81086527, 5.35190658],
        [5.13364673, 6.42014636, 5.2321678 , 3.9226347 ],
        [3.67539437, 3.52974516, 6.81766322, 5.42133207],
        [3.65774333, 4.45101107, 4.96830193, 3.83520201],
        [4.21230467, 3.95171007, 5.6626775 , 6.54087073]]])}
<class 'dict'>
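With these overloads merged, a static checker should be able to narrow the return type from the literal value of return_inferencedata, with no runtime isinstance check. A hedged illustration, continuing from the script above (typing.reveal_type requires Python 3.11+; earlier checkers provide reveal_type as a built-in):

from typing import reveal_type

idata = pm.sample_posterior_predictive(trace=fit_trace, model=my_model)
raw = pm.sample_posterior_predictive(
    trace=fit_trace, model=my_model, return_inferencedata=False
)
reveal_type(idata)  # checker output: arviz.InferenceData
reveal_type(raw)    # checker output: dict[str, np.ndarray]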

def sample_posterior_predictive(
    trace,
    model: Model | None = None,
    ...