Skip to content

Commit ac73b55

Browse files
committed
Allow truncation of hurdle distributions
1 parent 3f42edd commit ac73b55

File tree

3 files changed

+70
-8
lines changed

3 files changed

+70
-8
lines changed

pymc/distributions/mixture.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
)
4040
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, rv_size_is_none
4141
from pymc.distributions.transforms import _default_transform
42-
from pymc.distributions.truncated import Truncated
4342
from pymc.logprob.abstract import _logcdf, _logcdf_helper, _logprob
4443
from pymc.logprob.basic import logp
4544
from pymc.logprob.transforms import IntervalTransform
@@ -831,6 +830,8 @@ def _create(cls, *, name, nonzero_p, nonzero_dist, max_n_steps=10_000, **kwargs)
831830
832831
Note: this is invalid for discrete nonzero distributions with mass below 0, as we simply truncate[lower=1].
833832
"""
833+
from pymc.distributions.truncated import Truncated
834+
834835
dtype = nonzero_dist.dtype
835836

836837
if dtype.startswith("int"):

pymc/distributions/truncated.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
_support_point,
3636
support_point,
3737
)
38+
from pymc.distributions.mixture import _HurdleRV
3839
from pymc.distributions.shape_utils import (
3940
_change_dist_size,
4041
change_dist_size,
@@ -79,7 +80,9 @@ def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None):
7980

8081
# Try to use specialized Op
8182
try:
82-
return _truncated(dist.owner.op, lower, upper, size, *dist.owner.inputs)
83+
return _truncated(
84+
dist.owner.op, lower, upper, size, *dist.owner.inputs, max_n_steps=max_n_steps
85+
)
8386
except NotImplementedError:
8487
pass
8588

@@ -222,7 +225,7 @@ def update(self, node: Apply):
222225

223226

224227
@singledispatch
225-
def _truncated(op: Op, lower, upper, size, *params):
228+
def _truncated(op: Op, lower, upper, size, *params, max_n_steps: int):
226229
"""Return the truncated equivalent of another `RandomVariable`."""
227230
raise NotImplementedError(f"{op} does not have an equivalent truncated version implemented")
228231

@@ -307,13 +310,14 @@ def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs)
307310
f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
308311
)
309312

310-
if (
311-
isinstance(dist.owner.op, SymbolicRandomVariable)
312-
and "[size]" not in dist.owner.op.extended_signature
313+
if isinstance(dist.owner.op, SymbolicRandomVariable) and not (
314+
"[size]" in dist.owner.op.extended_signature
315+
# If there's a specific _truncated dispatch for this RV, that's also fine
316+
or _truncated.dispatch(type(dist.owner.op)) is not _truncated.dispatch(object)
313317
):
314318
# Truncation needs to wrap the underlying dist, but not all SymbolicRandomVariables encapsulate the whole
315319
# random graph and as such we don't know where the actual inputs begin. This happens mostly for
316-
# distribution factories like `Censored` and `Mixture` which would have a very complex signature if they
320+
# distribution factories like `Censored` which would have a very complex signature if they
317321
# encapsulated the random components instead of taking them as inputs like they do now.
318322
# SymbolicRandomVariables that encapsulate the whole random graph can be identified by having a size parameter.
319323
raise NotImplementedError(f"Truncation not implemented for {dist.owner.op}")
@@ -462,7 +466,7 @@ def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
462466

463467

464468
@_truncated.register(NormalRV)
465-
def _truncated_normal(op, lower, upper, size, rng, old_size, mu, sigma):
469+
def _truncated_normal(op, lower, upper, size, rng, old_size, mu, sigma, *, max_n_steps):
466470
return TruncatedNormal.dist(
467471
mu=mu,
468472
sigma=sigma,
@@ -472,3 +476,32 @@ def _truncated_normal(op, lower, upper, size, rng, old_size, mu, sigma):
472476
size=size,
473477
dtype=op.dtype,
474478
)
479+
480+
481+
@_truncated.register(_HurdleRV)
def _truncated_hurdle(
    op: _HurdleRV, lower, upper, size, rng, weights, zero_dist, dist, *, max_n_steps
):
    """Return the truncated equivalent of a hurdle distribution.

    The result is another hurdle rebuilt through ``op.rv_op``: the non-zero
    component is replaced by its truncated version, and the mixture weights are
    adjusted so that the point mass at zero is never selected when zero lies
    outside the truncation bounds.

    Parameters
    ----------
    op
        The hurdle Op being truncated; its ``rv_op`` rebuilds the distribution.
    lower, upper
        Truncation bounds. ``None`` leaves that side out of the zero-in-bounds check.
    size
        Forwarded to ``op.rv_op``.
    rng
        Not used here; it only occupies its slot among the op inputs that the
        dispatcher unpacks positionally.
    weights
        Mixture weights; ``weights[..., 1]`` is read as ``psi``, the weight of
        the non-zero component.
    zero_dist
        The point-mass component, passed through unchanged.
    dist
        The non-zero component, truncated with ``Truncated.dist``.
    max_n_steps
        Forwarded to ``Truncated.dist``. Keyword-only, matching the ``_truncated``
        dispatcher signature, so an op input can never be bound to it by position.
    """
    # If the DiracDelta value is outside the truncation bounds, this is effectively a non-hurdle distribution
    # We achieve this by adjusting the weights of the DiracDelta component, so it's never selected in that case
    psi = weights[..., 1]

    # Start from a 0-d boolean array and narrow it with each bound that is provided
    checks = np.array(True)
    if lower is not None:
        checks &= lower <= 0
    if upper is not None:
        checks &= 0 <= upper

    # psi == 1 puts all the mass on the non-zero component
    adjusted_psi = pt.where(checks, psi, 1)
    adjusted_weights = pt.stack([1 - adjusted_psi, adjusted_psi], axis=-1)

    # The only remaining step is to truncate the other distribution
    truncated_dist = Truncated.dist(dist, lower=lower, upper=upper, max_n_steps=max_n_steps)

    # Creating a hurdle with the adjusted weights and the truncated distribution
    # Should be equivalent to truncating the original hurdle distribution
    return op.rv_op(adjusted_weights, zero_dist, truncated_dist, size=size)

tests/distributions/test_mixture.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
Poisson,
5050
StickBreakingWeights,
5151
Triangular,
52+
Truncated,
5253
Uniform,
5354
ZeroInflatedBinomial,
5455
ZeroInflatedNegativeBinomial,
@@ -1710,3 +1711,30 @@ def logp_fn(value, psi, mu, sigma):
17101711
return np.log(psi) + st.lognorm.logpdf(value, sigma, 0, np.exp(mu))
17111712

17121713
check_logp(HurdleLogNormal, Rplus, {"psi": Unit, "mu": R, "sigma": Rplusbig}, logp_fn)
1714+
1715+
@pytest.mark.parametrize("lower", (-np.inf, 0, None, 1))
def test_truncated_hurdle_lognormal(self, lower):
    """Check draws and logp of a truncated HurdleLogNormal against its components."""
    psi = 0.7
    # Only `None` means "no lower bound". `lower or -np.inf` would also turn
    # lower=0 into -inf and leave the bound unchecked for that case.
    lower_bound = -np.inf if lower is None else lower

    x = HurdleLogNormal.dist(psi=psi, mu=3, sigma=1)
    x_trunc = Truncated.dist(x, lower=lower, upper=30, size=(1000,))

    x_trunc_draws = draw(x_trunc)
    assert ((x_trunc_draws >= lower_bound) & (x_trunc_draws <= 30)).all()

    x_trunc = Truncated.dist(x, lower=lower, upper=30, size=(4,))
    x_trunc_logp = logp(x_trunc, [0, 5.5, 30.0, 30.1]).eval()
    # When zero is below the lower bound the point mass is dropped entirely
    effective_psi = psi if lower_bound <= 0 else 1
    np.testing.assert_allclose(
        x_trunc_logp,
        [
            # Point mass at 0: weight 1 - psi while 0 is within the bounds,
            # log(0) = -inf once 0 is outside the support
            np.log(1 - effective_psi),
            *(
                np.log(effective_psi)
                + logp(
                    Truncated.dist(LogNormal.dist(mu=3, sigma=1), lower=lower, upper=30),
                    [5.5, 30.0],
                )
            ).eval(),
            -np.inf,  # 30.1 is outside the upper bound
        ],
    )

0 commit comments

Comments
 (0)