diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 8e52c812d1..8670a0aa1e 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -13,6 +13,8 @@ # limitations under the License. import warnings +from typing import Optional + import numpy as np import pytensor.tensor as pt @@ -45,7 +47,7 @@ normal_lccdf, normal_lcdf, ) -from pymc.distributions.distribution import Discrete +from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Discrete from pymc.distributions.mixture import Mixture from pymc.distributions.shape_utils import rv_size_is_none from pymc.logprob.basic import logp @@ -122,7 +124,14 @@ class Binomial(Discrete): rv_op = binomial @classmethod - def dist(cls, n, p=None, logit_p=None, *args, **kwargs): + def dist( + cls, + n: DIST_PARAMETER_TYPES, + p: Optional[DIST_PARAMETER_TYPES] = None, + logit_p: Optional[DIST_PARAMETER_TYPES] = None, + *args, + **kwargs, + ): if p is not None and logit_p is not None: raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") elif p is None and logit_p is None: @@ -238,7 +247,14 @@ def BetaBinom(a, b, n, x): rv_op = betabinom @classmethod - def dist(cls, alpha, beta, n, *args, **kwargs): + def dist( + cls, + alpha: DIST_PARAMETER_TYPES, + beta: DIST_PARAMETER_TYPES, + n: DIST_PARAMETER_TYPES, + *args, + **kwargs, + ): alpha = pt.as_tensor_variable(floatX(alpha)) beta = pt.as_tensor_variable(floatX(beta)) n = pt.as_tensor_variable(intX(n)) @@ -344,7 +360,13 @@ class Bernoulli(Discrete): rv_op = bernoulli @classmethod - def dist(cls, p=None, logit_p=None, *args, **kwargs): + def dist( + cls, + p: Optional[DIST_PARAMETER_TYPES] = None, + logit_p: Optional[DIST_PARAMETER_TYPES] = None, + *args, + **kwargs, + ): if p is not None and logit_p is not None: raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") elif p is None and logit_p is None: @@ -460,7 +482,7 @@ def DiscreteWeibull(q, b, x): rv_op = discrete_weibull @classmethod - def dist(cls, q, beta, *args, **kwargs): + def dist(cls, q: DIST_PARAMETER_TYPES, beta: DIST_PARAMETER_TYPES, *args, **kwargs): q = pt.as_tensor_variable(floatX(q)) beta = pt.as_tensor_variable(floatX(beta)) return super().dist([q, beta], **kwargs) @@ -549,7 +571,7 @@ class Poisson(Discrete): rv_op = poisson @classmethod - def dist(cls, mu, *args, **kwargs): + def dist(cls, mu: DIST_PARAMETER_TYPES, *args, **kwargs): mu = pt.as_tensor_variable(floatX(mu)) return super().dist([mu], *args, **kwargs) @@ -671,7 +693,15 @@ def NegBinom(a, m, x): rv_op = nbinom @classmethod - def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs): + def dist( + cls, + mu: Optional[DIST_PARAMETER_TYPES] = None, + alpha: Optional[DIST_PARAMETER_TYPES] = None, + p: Optional[DIST_PARAMETER_TYPES] = None, + n: Optional[DIST_PARAMETER_TYPES] = None, + *args, + **kwargs, + ): n, p = cls.get_n_p(mu=mu, alpha=alpha, p=p, n=n) n = pt.as_tensor_variable(floatX(n)) p = pt.as_tensor_variable(floatX(p)) @@ -784,7 +814,7 @@ class Geometric(Discrete): rv_op = geometric @classmethod - def dist(cls, p, *args, **kwargs): + def dist(cls, p: DIST_PARAMETER_TYPES, *args, **kwargs): p = pt.as_tensor_variable(floatX(p)) return super().dist([p], *args, **kwargs) @@ -881,7 +911,14 @@ class HyperGeometric(Discrete): rv_op = hypergeometric @classmethod - def dist(cls, N, k, n, *args, **kwargs): + def dist( + cls, + N: DIST_PARAMETER_TYPES, + k: DIST_PARAMETER_TYPES, + n: DIST_PARAMETER_TYPES, + *args, + **kwargs, + ): good = pt.as_tensor_variable(intX(k)) bad = pt.as_tensor_variable(intX(N - k)) n = pt.as_tensor_variable(intX(n)) @@ -1018,7 +1055,7 @@ class DiscreteUniform(Discrete): rv_op = discrete_uniform @classmethod - def dist(cls, lower, upper, *args, **kwargs): + def dist(cls, lower: DIST_PARAMETER_TYPES, upper: DIST_PARAMETER_TYPES, *args, **kwargs): lower = intX(pt.floor(lower)) upper = intX(pt.floor(upper)) return super().dist([lower, upper], **kwargs) @@ -1108,7 +1145,12 @@ class Categorical(Discrete): rv_op = categorical @classmethod - def dist(cls, p=None, logit_p=None, **kwargs): + def dist( + cls, + p: Optional[DIST_PARAMETER_TYPES] = None, + logit_p: Optional[DIST_PARAMETER_TYPES] = None, + **kwargs, + ): if p is not None and logit_p is not None: raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") elif p is None and logit_p is None: @@ -1210,7 +1252,7 @@ class DiracDelta(Discrete): rv_op = diracdelta @classmethod - def dist(cls, c, *args, **kwargs): + def dist(cls, c: DIST_PARAMETER_TYPES, *args, **kwargs): c = pt.as_tensor_variable(c) if c.dtype in continuous_types: c = floatX(c) @@ -1328,7 +1370,7 @@ def __new__(cls, name, psi, mu, **kwargs): ) @classmethod - def dist(cls, psi, mu, **kwargs): + def dist(cls, psi: DIST_PARAMETER_TYPES, mu: DIST_PARAMETER_TYPES, **kwargs): return _zero_inflated_mixture( name=None, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=mu), **kwargs ) @@ -1393,7 +1435,9 @@ def __new__(cls, name, psi, n, p, **kwargs): ) @classmethod - def dist(cls, psi, n, p, **kwargs): + def dist( + cls, psi: DIST_PARAMETER_TYPES, n: DIST_PARAMETER_TYPES, p: DIST_PARAMETER_TYPES, **kwargs + ): return _zero_inflated_mixture( name=None, nonzero_p=psi, nonzero_dist=Binomial.dist(n=n, p=p), **kwargs ) @@ -1490,7 +1534,15 @@ def __new__(cls, name, psi, mu=None, alpha=None, p=None, n=None, **kwargs): ) @classmethod - def dist(cls, psi, mu=None, alpha=None, p=None, n=None, **kwargs): + def dist( + cls, + psi: DIST_PARAMETER_TYPES, + mu: Optional[DIST_PARAMETER_TYPES] = None, + alpha: Optional[DIST_PARAMETER_TYPES] = None, + p: Optional[DIST_PARAMETER_TYPES] = None, + n: Optional[DIST_PARAMETER_TYPES] = None, + **kwargs, + ): return _zero_inflated_mixture( name=None, nonzero_p=psi, @@ -1507,7 +1559,7 @@ class _OrderedLogistic(Categorical): rv_op = categorical @classmethod - def dist(cls, eta, cutpoints, *args, **kwargs): + def dist(cls, eta: DIST_PARAMETER_TYPES, cutpoints: DIST_PARAMETER_TYPES, *args, **kwargs): eta = pt.as_tensor_variable(floatX(eta)) cutpoints = pt.as_tensor_variable(cutpoints) @@ -1613,7 +1665,14 @@ class _OrderedProbit(Categorical): rv_op = categorical @classmethod - def dist(cls, eta, cutpoints, sigma=1, *args, **kwargs): + def dist( + cls, + eta: DIST_PARAMETER_TYPES, + cutpoints: DIST_PARAMETER_TYPES, + sigma: DIST_PARAMETER_TYPES = 1, + *args, + **kwargs, + ): eta = pt.as_tensor_variable(floatX(eta)) cutpoints = pt.as_tensor_variable(cutpoints)