diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index b2cbff9b68..1759eb85b8 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -669,6 +669,7 @@ def sample_jax_nuts( random_seed=random_seed, initial_points=initial_points, nuts_kwargs=nuts_kwargs, + logp_fn=logp_fn, ) tic2 = datetime.now()