35
35
_support_point ,
36
36
support_point ,
37
37
)
38
+ from pymc .distributions .mixture import _HurdleRV
38
39
from pymc .distributions .shape_utils import (
39
40
_change_dist_size ,
40
41
change_dist_size ,
@@ -79,7 +80,9 @@ def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None):
79
80
80
81
# Try to use specialized Op
81
82
try :
82
- return _truncated (dist .owner .op , lower , upper , size , * dist .owner .inputs )
83
+ return _truncated (
84
+ dist .owner .op , lower , upper , size , * dist .owner .inputs , max_n_steps = max_n_steps
85
+ )
83
86
except NotImplementedError :
84
87
pass
85
88
@@ -222,7 +225,7 @@ def update(self, node: Apply):
222
225
223
226
224
227
@singledispatch
225
- def _truncated (op : Op , lower , upper , size , * params ):
228
+ def _truncated (op : Op , lower , upper , size , * params , max_n_steps : int ):
226
229
"""Return the truncated equivalent of another `RandomVariable`."""
227
230
raise NotImplementedError (f"{ op } does not have an equivalent truncated version implemented" )
228
231
@@ -307,13 +310,14 @@ def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs)
307
310
f"Truncation dist must be a distribution created via the `.dist()` API, got { type (dist )} "
308
311
)
309
312
310
- if (
311
- isinstance (dist .owner .op , SymbolicRandomVariable )
312
- and "[size]" not in dist .owner .op .extended_signature
313
+ if isinstance (dist .owner .op , SymbolicRandomVariable ) and not (
314
+ "[size]" in dist .owner .op .extended_signature
315
+ # If there's a specific _truncated dispatch for this RV, that's also fine
316
+ or _truncated .dispatch (type (dist .owner .op )) is not _truncated .dispatch (object )
313
317
):
314
318
# Truncation needs to wrap the underlying dist, but not all SymbolicRandomVariables encapsulate the whole
315
319
# random graph and as such we don't know where the actual inputs begin. This happens mostly for
316
- # distribution factories like `Censored` and `Mixture` which would have a very complex signature if they
320
+ # distribution factories like `Censored` which would have a very complex signature if they
317
321
# encapsulated the random components instead of taking them as inputs like they do now.
318
322
# SymbolicRandomVariables that encapsulate the whole random graph can be identified for having a size parameter.
319
323
raise NotImplementedError (f"Truncation not implemented for { dist .owner .op } " )
@@ -462,7 +466,7 @@ def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
462
466
463
467
464
468
@_truncated .register (NormalRV )
465
- def _truncated_normal (op , lower , upper , size , rng , old_size , mu , sigma ):
469
+ def _truncated_normal (op , lower , upper , size , rng , old_size , mu , sigma , * , max_n_steps ):
466
470
return TruncatedNormal .dist (
467
471
mu = mu ,
468
472
sigma = sigma ,
@@ -472,3 +476,32 @@ def _truncated_normal(op, lower, upper, size, rng, old_size, mu, sigma):
472
476
size = size ,
473
477
dtype = op .dtype ,
474
478
)
479
+
480
+
481
+ @_truncated .register (_HurdleRV )
482
+ def _truncated_hurdle (
483
+ op : _HurdleRV , lower , upper , size , rng , weights , zero_dist , dist , max_n_steps
484
+ ):
485
+ # If the DiracDelta value is outside the truncation bounds, this is effectively a non-hurdle distribution
486
+ # We achieve this by adjusting the weights of the DiracDelta component, so it's never selected in that case
487
+ psi = weights [..., 1 ]
488
+
489
+ checks = np .array (True )
490
+ if lower is not None :
491
+ checks &= lower <= 0
492
+ if upper is not None :
493
+ checks &= 0 <= upper
494
+
495
+ adjusted_psi = pt .where (
496
+ checks ,
497
+ psi ,
498
+ 1 ,
499
+ )
500
+ adjusted_weights = pt .stack ([1 - adjusted_psi , adjusted_psi ], axis = - 1 )
501
+
502
+ # The only remaining step is to truncate the other distribution
503
+ truncated_dist = Truncated .dist (dist , lower = lower , upper = upper , max_n_steps = max_n_steps )
504
+
505
+ # Creating a hurdle with the adjusted weights and the truncated distribution
506
+ # Should be equivalent to truncating the original hurdle distribution
507
+ return op .rv_op (adjusted_weights , zero_dist , truncated_dist , size = size )
0 commit comments