Skip to content

Commit d6f63c4

Browse files
committed
Add initial distribution parameter types
1 parent 75ea2a8 commit d6f63c4

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

pymc/distributions/continuous.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,9 @@ def logcdf(value, mu, sigma):
583583
)
584584

585585

586+
from pymc.distributions.distribution import DIST_PARAMETER_TYPES
587+
588+
586589
class TruncatedNormalRV(RandomVariable):
587590
name = "truncated_normal"
588591
ndim_supp = 0
@@ -594,7 +597,7 @@ class TruncatedNormalRV(RandomVariable):
594597
def rng_fn(
595598
cls,
596599
rng: np.random.RandomState,
597-
mu: Union[np.ndarray, float],
600+
mu: DIST_PARAMETER_TYPES,
598601
sigma: Union[np.ndarray, float],
599602
lower: Union[np.ndarray, float],
600603
upper: Union[np.ndarray, float],

pymc/distributions/distribution.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919

2020
from abc import ABCMeta
2121
from functools import singledispatch
22-
from typing import Callable, Optional, Sequence
22+
from typing import Callable, Optional, Sequence, Union
2323

2424
import aesara
25+
import numpy as np
2526

2627
from aeppl.logprob import _logcdf, _logprob
2728
from aesara import tensor as at
@@ -56,6 +57,8 @@
5657
"NoDistribution",
5758
]
5859

60+
DIST_PARAMETER_TYPES = Union[np.ndarray, int, float, TensorVariable]
61+
5962
vectorized_ppc = contextvars.ContextVar(
6063
"vectorized_ppc", default=None
6164
) # type: contextvars.ContextVar[Optional[Callable]]

0 commit comments

Comments
 (0)