From cc3c675f7f2b0624ce5b37a45d7158f8dcfbc71c Mon Sep 17 00:00:00 2001
From: Ricardo <ricardo.vieira1994@gmail.com>
Date: Mon, 6 Dec 2021 18:54:18 +0100
Subject: [PATCH 1/6] Pin aeppl version

---
 conda-envs/environment-dev-py37.yml          | 2 +-
 conda-envs/environment-dev-py38.yml          | 2 +-
 conda-envs/environment-dev-py39.yml          | 2 +-
 conda-envs/environment-test-py37.yml         | 2 +-
 conda-envs/environment-test-py38.yml         | 2 +-
 conda-envs/environment-test-py39.yml         | 2 +-
 conda-envs/windows-environment-dev-py38.yml  | 2 +-
 conda-envs/windows-environment-test-py38.yml | 2 +-
 requirements-dev.txt                         | 2 +-
 requirements.txt                             | 2 +-
 10 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/conda-envs/environment-dev-py37.yml b/conda-envs/environment-dev-py37.yml
index 7263fb4245..90e22cf6b7 100644
--- a/conda-envs/environment-dev-py37.yml
+++ b/conda-envs/environment-dev-py37.yml
@@ -4,7 +4,7 @@ channels:
 - conda-forge
 - defaults
 dependencies:
-- aeppl>=0.0.17
+- aeppl=0.0.17
 - aesara>=2.2.6
 - arviz>=0.11.4
 - cachetools>=4.2.1
diff --git a/conda-envs/environment-dev-py38.yml b/conda-envs/environment-dev-py38.yml
index 21858097ce..69fd28473b 100644
--- a/conda-envs/environment-dev-py38.yml
+++ b/conda-envs/environment-dev-py38.yml
@@ -4,7 +4,7 @@ channels:
 - conda-forge
 - defaults
 dependencies:
-- aeppl>=0.0.17
+- aeppl=0.0.17
 - aesara>=2.2.6
 - arviz>=0.11.4
 - cachetools>=4.2.1
diff --git a/conda-envs/environment-dev-py39.yml b/conda-envs/environment-dev-py39.yml
index 0b007a5aa2..6a0bfc917d 100644
--- a/conda-envs/environment-dev-py39.yml
+++ b/conda-envs/environment-dev-py39.yml
@@ -4,7 +4,7 @@ channels:
 - conda-forge
 - defaults
 dependencies:
-- aeppl>=0.0.17
+- aeppl=0.0.17
 - aesara>=2.2.6
 - arviz>=0.11.4
 - cachetools>=4.2.1
diff --git a/conda-envs/environment-test-py37.yml b/conda-envs/environment-test-py37.yml
index c5a4a14b47..964939f493 100644
--- a/conda-envs/environment-test-py37.yml
+++ b/conda-envs/environment-test-py37.yml
@@ -4,7 +4,7 @@ channels:
 - conda-forge
 - defaults
 dependencies:
-- aeppl>=0.0.17
+- aeppl=0.0.17
 - aesara>=2.2.6
 - arviz>=0.11.4
 - cachetools>=4.2.1
diff --git a/conda-envs/environment-test-py38.yml b/conda-envs/environment-test-py38.yml
index 5a2b1ee762..cec0b1ffa2 100644
--- a/conda-envs/environment-test-py38.yml
+++ b/conda-envs/environment-test-py38.yml
@@ -4,7 +4,7 @@ channels:
 - conda-forge
 - defaults
 dependencies:
-- aeppl>=0.0.17
+- aeppl=0.0.17
 - aesara>=2.2.6
 - arviz>=0.11.4
 - cachetools>=4.2.1
diff --git a/conda-envs/environment-test-py39.yml b/conda-envs/environment-test-py39.yml
index 942c05c03d..50d87383c6 100644
--- a/conda-envs/environment-test-py39.yml
+++ b/conda-envs/environment-test-py39.yml
@@ -4,7 +4,7 @@ channels:
 - conda-forge
 - defaults
 dependencies:
-- aeppl>=0.0.17
+- aeppl=0.0.17
 - aesara>=2.2.6
 - arviz>=0.11.4
 - cachetools
diff --git a/conda-envs/windows-environment-dev-py38.yml b/conda-envs/windows-environment-dev-py38.yml
index 513fe2e2af..ffac106acf 100644
--- a/conda-envs/windows-environment-dev-py38.yml
+++ b/conda-envs/windows-environment-dev-py38.yml
@@ -4,7 +4,7 @@ channels:
 - defaults
 dependencies:
  # base dependencies (see install guide for Windows)
-- aeppl>=0.0.17
+- aeppl=0.0.17
 - aesara>=2.2.6
 - arviz>=0.11.4
 - cachetools>=4.2.1
diff --git a/conda-envs/windows-environment-test-py38.yml b/conda-envs/windows-environment-test-py38.yml
index 47374aa425..6ca7b966dd 100644
--- a/conda-envs/windows-environment-test-py38.yml
+++ b/conda-envs/windows-environment-test-py38.yml
@@ -4,7 +4,7 @@ channels:
 - defaults
 dependencies:
  # base dependencies (see install guide for Windows)
-- aeppl>=0.0.17
+- aeppl=0.0.17
 - aesara>=2.2.6
 - arviz>=0.11.2
 - cachetools
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 81509ea957..c5fbff991b 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,7 +1,7 @@
 # This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify.
 # See that file for comments about the need/usage of each dependency.
 
-aeppl>=0.0.17
+aeppl==0.0.17
 aesara>=2.2.6
 arviz>=0.11.4
 cachetools>=4.2.1
diff --git a/requirements.txt b/requirements.txt
index d80b9c8a46..75da478008 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-aeppl>=0.0.17
+aeppl==0.0.17
 aesara>=2.2.6
 arviz>=0.11.4
 cachetools>=4.2.1

From eb925bcd97e111f9d6f272c1b72e476fa59bfe3e Mon Sep 17 00:00:00 2001
From: Ricardo <ricardo.vieira1994@gmail.com>
Date: Mon, 6 Dec 2021 18:16:11 +0100
Subject: [PATCH 2/6] Simplify DirichletMultinomial logp and remove restriction
 on dimensionality of n and a

Refactor vectorized logp tests
---
 pymc/distributions/multivariate.py |  29 ++-----
 pymc/tests/test_distributions.py   | 129 ++++++-----------------------
 2 files changed, 35 insertions(+), 123 deletions(-)

diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 9c3bd37a91..b2f5384f2c 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -626,17 +626,12 @@ class DirichletMultinomial(Discrete):
 
     Parameters
     ----------
-    n : int or array
-        Total counts in each replicate. If n is an array its shape must be (N,)
-        with N = a.shape[0]
+    n : int
+        Total counts in each replicate.
 
-    a : one- or two-dimensional array
-        Dirichlet parameter. Elements must be strictly positive.
-        The number of categories is given by the length of the last axis.
-
-    shape : integer tuple
-        Describes shape of distribution. For example if n=array([5, 10]), and
-        a=array([1, 1, 1]), shape should be (2, 3).
+    a : vector
+        Dirichlet alpha parameter. Elements must be strictly positive. The number of
+        categories is given by the length of the last axis.
     """
     rv_op = dirichlet_multinomial
 
@@ -661,15 +656,10 @@ def logp(value, n, a):
         -------
         TensorVariable
         """
-        if value.ndim >= 1:
-            n = at.shape_padright(n)
-            if a.ndim > 1:
-                a = at.shape_padleft(a)
-
-        sum_a = a.sum(axis=-1, keepdims=True)
+        sum_a = a.sum(axis=-1)
         const = (gammaln(n + 1) + gammaln(sum_a)) - gammaln(n + sum_a)
         series = gammaln(value + a) - (gammaln(value + 1) + gammaln(a))
-        result = const + series.sum(axis=-1, keepdims=True)
+        result = const + series.sum(axis=-1)
 
         # Bounds checking to confirm parameters and data meet all constraints
         # and that each observation value_i sums to n_i.
@@ -678,13 +668,10 @@ def logp(value, n, a):
             value >= 0,
             a > 0,
             n >= 0,
-            at.eq(value.sum(axis=-1, keepdims=True), n),
+            at.eq(value.sum(axis=-1), n),
             broadcast_conditions=False,
         )
 
-    def _distr_parameters_for_repr(self):
-        return ["n", "a"]
-
 
 class _OrderedMultinomial(Multinomial):
     r"""
diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py
index b19f24543b..602d38919e 100644
--- a/pymc/tests/test_distributions.py
+++ b/pymc/tests/test_distributions.py
@@ -50,7 +50,7 @@ def polyagamma_cdf(*args, **kwargs):
 from numpy import array, inf, log
 from numpy.testing import assert_allclose, assert_almost_equal, assert_equal
 from scipy import integrate
-from scipy.special import erf, logit
+from scipy.special import erf, gammaln, logit
 
 import pymc as pm
 
@@ -337,21 +337,21 @@ def multinomial_logpdf(value, n, p):
         return -inf
 
 
-def dirichlet_multinomial_logpmf(value, n, a):
-    value, n, a = (np.asarray(x) for x in [value, n, a])
-    assert value.ndim == 1
-    assert n.ndim == 0
-    assert a.shape == value.shape
-    gammaln = scipy.special.gammaln
+def _dirichlet_multinomial_logpmf(value, n, a):
     if value.sum() == n and (0 <= value).all() and (value <= n).all():
-        sum_a = a.sum(axis=-1)
+        sum_a = a.sum()
         const = gammaln(n + 1) + gammaln(sum_a) - gammaln(n + sum_a)
         series = gammaln(value + a) - gammaln(value + 1) - gammaln(a)
-        return const + series.sum(axis=-1)
+        return const + series.sum()
     else:
         return -inf
 
 
+dirichlet_multinomial_logpmf = np.vectorize(
+    _dirichlet_multinomial_logpmf, signature="(n),(),(n)->()"
+)
+
+
 def beta_mu_sigma(value, mu, sigma):
     kappa = mu * (1 - mu) / sigma ** 2 - 1
     if kappa > 0:
@@ -2314,105 +2314,30 @@ def test_dirichlet_multinomial_matches_beta_binomial(self):
             decimal=select_by_precision(float64=6, float32=3),
         )
 
-    def test_dirichlet_multinomial_vec(self):
-        vals = np.array([[2, 4, 4], [3, 3, 4]])
-        a = np.array([0.2, 0.3, 0.5])
-        n = 10
-
-        with Model() as model_single:
-            DirichletMultinomial("m", n=n, a=a)
-
-        with Model() as model_many:
-            DirichletMultinomial("m", n=n, a=a, size=2)
-
-        assert_almost_equal(
-            np.asarray([dirichlet_multinomial_logpmf(val, n, a) for val in vals]),
-            np.asarray([model_single.fastlogp({"m": val}) for val in vals]),
-            decimal=4,
-        )
-
-        assert_almost_equal(
-            np.asarray([dirichlet_multinomial_logpmf(val, n, a) for val in vals]),
-            logp(model_many.m, vals).eval().squeeze(),
-            decimal=4,
-        )
-
-        assert_almost_equal(
-            sum(model_single.fastlogp({"m": val}) for val in vals),
-            model_many.fastlogp({"m": vals}),
-            decimal=4,
-        )
-
-    def test_dirichlet_multinomial_vec_1d_n(self):
-        vals = np.array([[2, 4, 4], [4, 3, 4]])
-        a = np.array([0.2, 0.3, 0.5])
-        ns = np.array([10, 11])
-
-        with Model() as model:
-            DirichletMultinomial("m", n=ns, a=a)
-
-        assert_almost_equal(
-            sum(dirichlet_multinomial_logpmf(val, n, a) for val, n in zip(vals, ns)),
-            model.fastlogp({"m": vals}),
-            decimal=4,
-        )
-
-    def test_dirichlet_multinomial_vec_1d_n_2d_a(self):
-        vals = np.array([[2, 4, 4], [4, 3, 4]])
-        as_ = np.array([[0.2, 0.3, 0.5], [0.9, 0.09, 0.01]])
-        ns = np.array([10, 11])
-
-        with Model() as model:
-            DirichletMultinomial("m", n=ns, a=as_)
-
-        assert_almost_equal(
-            sum(dirichlet_multinomial_logpmf(val, n, a) for val, n, a in zip(vals, ns, as_)),
-            model.fastlogp({"m": vals}),
-            decimal=4,
-        )
-
-    def test_dirichlet_multinomial_vec_2d_a(self):
-        vals = np.array([[2, 4, 4], [3, 3, 4]])
-        as_ = np.array([[0.2, 0.3, 0.5], [0.3, 0.3, 0.4]])
-        n = 10
+    @pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
+    @pytest.mark.parametrize(
+        "a",
+        [
+            ([0.2, 0.3, 0.5]),
+            ([[0.2, 0.3, 0.5], [0.9, 0.09, 0.01]]),
+            (np.abs(np.random.randn(2, 2, 4))),
+        ],
+    )
+    @pytest.mark.parametrize("size", [1, 2, (2, 3)])
+    def test_dirichlet_multinomial_vectorized(self, n, a, size):
+        n = intX(np.array(n))
+        a = floatX(np.array(a))
 
-        with Model() as model:
-            DirichletMultinomial("m", n=n, a=as_)
+        dm = pm.DirichletMultinomial.dist(n=n, a=a, size=size)
+        vals = dm.eval()
 
         assert_almost_equal(
-            sum(dirichlet_multinomial_logpmf(val, n, a) for val, a in zip(vals, as_)),
-            model.fastlogp({"m": vals}),
+            dirichlet_multinomial_logpmf(vals, n, a),
+            pm.logp(dm, vals).eval(),
             decimal=4,
+            err_msg=f"vals={vals}",
         )
 
-    def test_batch_dirichlet_multinomial(self):
-        # Test that DM can handle a 3d array for `a`
-
-        # Create an almost deterministic DM by setting a to 0.001, everywhere
-        # except for one category / dimension which is given the value of 1000
-        n = 5
-        vals = np.zeros((4, 5, 3), dtype="int32")
-        a = np.zeros_like(vals, dtype=aesara.config.floatX) + 0.001
-        inds = np.random.randint(vals.shape[-1], size=vals.shape[:-1])[..., None]
-        np.put_along_axis(vals, inds, n, axis=-1)
-        np.put_along_axis(a, inds, 1000, axis=-1)
-
-        dist = DirichletMultinomial.dist(n=n, a=a)
-
-        # Logp should be approx -9.98004998e-06
-        dist_logp = logp(dist, vals).eval()
-        expected_logp = np.full_like(dist_logp, fill_value=-9.98004998e-06)
-        assert_almost_equal(
-            dist_logp,
-            expected_logp,
-            decimal=select_by_precision(float64=6, float32=3),
-        )
-
-        # Samples should be equal given the almost deterministic DM
-        dist = DirichletMultinomial.dist(n=n, a=a, size=2)
-        sample = dist.eval()
-        assert_allclose(sample, np.stack([vals, vals], axis=0))
-
     @aesara.config.change_flags(compute_test_value="raise")
     def test_categorical_bounds(self):
         with Model():

From 248131d4a3225252fb6dfa9ce7da0a2f5068308e Mon Sep 17 00:00:00 2001
From: Ricardo <ricardo.vieira1994@gmail.com>
Date: Mon, 6 Dec 2021 18:18:46 +0100
Subject: [PATCH 3/6] Remove legacy docstrings Multinomial restriction on
 dimensionality of n and p

Refactor vectorized logp tests
---
 pymc/distributions/multivariate.py |  12 ++-
 pymc/tests/test_distributions.py   | 129 ++++++-----------------------
 2 files changed, 31 insertions(+), 110 deletions(-)

diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index b2f5384f2c..d01662320d 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -507,13 +507,11 @@ class Multinomial(Discrete):
 
     Parameters
     ----------
-    n: int or array
-        Number of trials (n > 0). If n is an array its shape must be (N,) with
-        N = p.shape[0]
-    p: one- or two-dimensional array
-        Probability of each one of the different outcomes. Elements must
-        be non-negative and sum to 1 along the last axis. They will be
-        automatically rescaled otherwise.
+    n: int
+        Number of trials (n > 0)
+    p: vector
+        Probability of each one of the different outcomes. Elements must be non-negative
+        and sum to 1 along the last axis. They will be automatically rescaled otherwise.
     """
     rv_op = multinomial
 
diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py
index 602d38919e..ec67fcf7c7 100644
--- a/pymc/tests/test_distributions.py
+++ b/pymc/tests/test_distributions.py
@@ -48,7 +48,7 @@ def polyagamma_cdf(*args, **kwargs):
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.var import TensorVariable
 from numpy import array, inf, log
-from numpy.testing import assert_allclose, assert_almost_equal, assert_equal
+from numpy.testing import assert_almost_equal, assert_equal
 from scipy import integrate
 from scipy.special import erf, gammaln, logit
 
@@ -327,16 +327,6 @@ def f3(a, b, c):
         raise ValueError("Dont know how to integrate shape: " + str(shape))
 
 
-def multinomial_logpdf(value, n, p):
-    if value.sum() == n and (0 <= value).all() and (value <= n).all():
-        logpdf = scipy.special.gammaln(n + 1)
-        logpdf -= scipy.special.gammaln(value + 1).sum()
-        logpdf += logpow(p, value).sum()
-        return logpdf
-    else:
-        return -inf
-
-
 def _dirichlet_multinomial_logpmf(value, n, a):
     if value.sum() == n and (0 <= value).all() and (value <= n).all():
         sum_a = a.sum()
@@ -2157,7 +2147,10 @@ def test_dirichlet_2D(self):
     @pytest.mark.parametrize("n", [2, 3])
     def test_multinomial(self, n):
         self.check_logp(
-            Multinomial, Vector(Nat, n), {"p": Simplex(n), "n": Nat}, multinomial_logpdf
+            Multinomial,
+            Vector(Nat, n),
+            {"p": Simplex(n), "n": Nat},
+            lambda value, n, p: scipy.stats.multinomial.logpmf(value, n, p),
         )
 
     @pytest.mark.parametrize(
@@ -2187,106 +2180,36 @@ def test_multinomial_random(self, p, size, n):
 
         assert m.eval().shape == size + p.shape
 
-    def test_multinomial_vec(self):
-        vals = np.array([[2, 4, 4], [3, 3, 4]])
-        p = np.array([0.2, 0.3, 0.5])
-        n = 10
-
-        with Model() as model_single:
-            Multinomial("m", n=n, p=p)
-
-        with Model() as model_many:
-            Multinomial("m", n=n, p=p, size=2)
+    @pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
+    @pytest.mark.parametrize(
+        "p",
+        [
+            ([0.2, 0.3, 0.5]),
+            ([[0.2, 0.3, 0.5], [0.9, 0.09, 0.01]]),
+            (np.abs(np.random.randn(2, 2, 4))),
+        ],
+    )
+    @pytest.mark.parametrize("size", [1, 2, (2, 3)])
+    def test_multinomial_vectorized(self, n, p, size):
+        n = intX(np.array(n))
+        p = floatX(np.array(p))
+        p /= p.sum(axis=-1, keepdims=True)
 
-        assert_almost_equal(
-            scipy.stats.multinomial.logpmf(vals, n, p),
-            np.asarray([model_single.fastlogp({"m": val}) for val in vals]),
-            decimal=4,
-        )
+        mn = pm.Multinomial.dist(n=n, p=p, size=size)
+        vals = mn.eval()
 
         assert_almost_equal(
             scipy.stats.multinomial.logpmf(vals, n, p),
-            logp(model_many.m, vals).eval().squeeze(),
+            pm.logp(mn, vals).eval(),
             decimal=4,
+            err_msg=f"vals={vals}",
         )
 
-        assert_almost_equal(
-            sum(model_single.fastlogp({"m": val}) for val in vals),
-            model_many.fastlogp({"m": vals}),
-            decimal=4,
-        )
-
-    def test_multinomial_vec_1d_n(self):
-        vals = np.array([[2, 4, 4], [4, 3, 4]])
-        p = np.array([0.2, 0.3, 0.5])
-        ns = np.array([10, 11])
-
-        with Model() as model:
-            Multinomial("m", n=ns, p=p)
-
-        assert_almost_equal(
-            sum(multinomial_logpdf(val, n, p) for val, n in zip(vals, ns)),
-            model.fastlogp({"m": vals}),
-            decimal=4,
-        )
-
-    def test_multinomial_vec_1d_n_2d_p(self):
-        vals = np.array([[2, 4, 4], [4, 3, 4]])
-        ps = np.array([[0.2, 0.3, 0.5], [0.9, 0.09, 0.01]])
-        ns = np.array([10, 11])
-
-        with Model() as model:
-            Multinomial("m", n=ns, p=ps)
-
-        assert_almost_equal(
-            sum(multinomial_logpdf(val, n, p) for val, n, p in zip(vals, ns, ps)),
-            model.fastlogp({"m": vals}),
-            decimal=4,
-        )
-
-    def test_multinomial_vec_2d_p(self):
-        vals = np.array([[2, 4, 4], [3, 3, 4]])
-        ps = np.array([[0.2, 0.3, 0.5], [0.3, 0.3, 0.4]])
-        n = 10
-
-        with Model() as model:
-            Multinomial("m", n=n, p=ps)
-
-        assert_almost_equal(
-            sum(multinomial_logpdf(val, n, p) for val, p in zip(vals, ps)),
-            model.fastlogp({"m": vals}),
-            decimal=4,
-        )
-
-    def test_batch_multinomial(self):
-        n = 10
-        vals = intX(np.zeros((4, 5, 3)))
-        p = floatX(np.zeros_like(vals))
-        inds = np.random.randint(vals.shape[-1], size=vals.shape[:-1])[..., None]
-        np.put_along_axis(vals, inds, n, axis=-1)
-        np.put_along_axis(p, inds, 1, axis=-1)
-
-        dist = Multinomial.dist(n=n, p=p)
-        logp_mn = at.exp(pm.logp(dist, vals)).eval()
-        assert_almost_equal(
-            logp_mn,
-            np.ones(vals.shape[:-1]),
-            decimal=select_by_precision(float64=6, float32=3),
-        )
-
-        dist = Multinomial.dist(n=n, p=p, size=2)
-        sample = dist.eval()
-        assert_allclose(sample, np.stack([vals, vals], axis=0))
-
     def test_multinomial_zero_probs(self):
         # test multinomial accepts 0 probabilities / observations:
-        value = aesara.shared(np.array([0, 0, 100], dtype=int))
-        logp = pm.Multinomial.logp(value=value, n=100, p=at.constant([0.0, 0.0, 1.0]))
-        logp_fn = aesara.function(inputs=[], outputs=logp)
-        assert logp_fn() >= 0
-
-        value.set_value(np.array([50, 50, 0], dtype=int))
-        assert np.isneginf(logp_fn())
+        mn = pm.Multinomial.dist(n=100, p=[0.0, 0.0, 1.0])
+        assert pm.logp(mn, np.array([0, 0, 100])).eval() >= 0
+        assert pm.logp(mn, np.array([50, 50, 0])).eval() == -np.inf
 
     @pytest.mark.parametrize("n", [2, 3])
     def test_dirichlet_multinomial(self, n):

From 4ed104285c4612a2911ae5372c6ee8431a2e83e6 Mon Sep 17 00:00:00 2001
From: Ricardo <ricardo.vieira1994@gmail.com>
Date: Mon, 6 Dec 2021 18:39:25 +0100
Subject: [PATCH 4/6] Refactor dirichlet vectorized logp tests

---
 pymc/tests/test_distributions.py | 66 ++++++++++++++++----------------
 1 file changed, 32 insertions(+), 34 deletions(-)

diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py
index ec67fcf7c7..55403d26ed 100644
--- a/pymc/tests/test_distributions.py
+++ b/pymc/tests/test_distributions.py
@@ -463,8 +463,12 @@ def discrete_weibull_logpmf(value, q, beta):
     )
 
 
-def dirichlet_logpdf(value, a):
-    return floatX((-betafn(a) + logpow(value, a - 1).sum(-1)).sum())
+def _dirichlet_logpdf(value, a):
+    # scipy.stats.dirichlet.logpdf suffers from numerical precision issues
+    return -betafn(a) + logpow(value, a - 1).sum()
+
+
+dirichlet_logpdf = np.vectorize(_dirichlet_logpdf, signature="(n),(n)->()")
 
 
 def categorical_logpdf(value, p):
@@ -2101,32 +2105,34 @@ def test_lkj(self, x, eta, n, lp):
 
     @pytest.mark.parametrize("n", [1, 2, 3])
     def test_dirichlet(self, n):
-        self.check_logp(Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf)
-
-    @pytest.mark.parametrize("dist_shape", [(1, 2), (2, 4, 3)])
-    def test_dirichlet_with_batch_shapes(self, dist_shape):
-        a = np.ones(dist_shape)
-        with pm.Model() as model:
-            d = pm.Dirichlet("d", a=a)
-
-        # Generate sample points to test
-        d_value = d.tag.value_var
-        d_point = d.eval().astype("float64")
-        d_point /= d_point.sum(axis=-1)[..., None]
-
-        if hasattr(d_value.tag, "transform"):
-            d_point_trans = d_value.tag.transform.forward(
-                at.as_tensor(d_point), *d.owner.inputs
-            ).eval()
-        else:
-            d_point_trans = d_point
+        self.check_logp(
+            Dirichlet,
+            Simplex(n),
+            {"a": Vector(Rplus, n)},
+            dirichlet_logpdf,
+        )
 
-        pymc_res = logpt(d, d_point_trans, jacobian=False, sum=False).eval()
-        scipy_res = np.empty_like(pymc_res)
-        for idx in np.ndindex(a.shape[:-1]):
-            scipy_res[idx] = scipy.stats.dirichlet(a[idx]).logpdf(d_point[idx])
+    @pytest.mark.parametrize(
+        "a",
+        [
+            ([2, 3, 5]),
+            ([[2, 3, 5], [9, 19, 3]]),
+            (np.abs(np.random.randn(2, 2, 4)) + 1),
+        ],
+    )
+    @pytest.mark.parametrize("size", [2, (1, 2), (2, 4, 3)])
+    def test_dirichlet_vectorized(self, a, size):
+        a = floatX(np.array(a))
+
+        dir = pm.Dirichlet.dist(a=a, size=size)
+        vals = dir.eval()
 
-        assert_almost_equal(pymc_res, scipy_res)
+        assert_almost_equal(
+            dirichlet_logpdf(vals, a),
+            pm.logp(dir, vals).eval(),
+            decimal=4,
+            err_msg=f"vals={vals}",
+        )
 
     def test_dirichlet_shape(self):
         a = at.as_tensor_variable(np.r_[1, 2])
@@ -2136,14 +2142,6 @@ def test_dirichlet_shape(self):
         with pytest.warns(DeprecationWarning), aesara.change_flags(compute_test_value="ignore"):
             dir_rv = Dirichlet.dist(at.vector())
 
-    def test_dirichlet_2D(self):
-        self.check_logp(
-            Dirichlet,
-            MultiSimplex(2, 2),
-            {"a": Vector(Vector(Rplus, 2), 2)},
-            dirichlet_logpdf,
-        )
-
     @pytest.mark.parametrize("n", [2, 3])
     def test_multinomial(self, n):
         self.check_logp(

From 82dc0f5123107dcb7e52896248a20e1dc43cdfef Mon Sep 17 00:00:00 2001
From: Ricardo <ricardo.vieira1994@gmail.com>
Date: Wed, 8 Dec 2021 09:48:14 +0100
Subject: [PATCH 5/6] Harmonize Multinomial, Dirichlet and DirichletMultinomial
 docstrings

---
 pymc/distributions/multivariate.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index d01662320d..43a66fe1e5 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -406,8 +406,9 @@ class Dirichlet(Continuous):
 
     Parameters
     ----------
-    a: array
-        Concentration parameters (a > 0).
+    a: float array
+        Concentration parameters (a > 0). The number of categories is given by the
+        length of the last axis.
     """
     rv_op = dirichlet
 
@@ -508,10 +509,11 @@ class Multinomial(Discrete):
     Parameters
     ----------
     n: int
-        Number of trials (n > 0)
-    p: vector
-        Probability of each one of the different outcomes. Elements must be non-negative
-        and sum to 1 along the last axis. They will be automatically rescaled otherwise.
+        Total counts in each replicate (n > 0).
+    p: float array
+        Probability of each one of the different outcomes (0 <= p <= 1). The number of
+        categories is given by the length of the last axis. Elements are expected to sum
+        to 1 along the last axis, and they will be automatically rescaled otherwise.
     """
     rv_op = multinomial
 
@@ -625,11 +627,11 @@ class DirichletMultinomial(Discrete):
     Parameters
     ----------
     n : int
-        Total counts in each replicate.
+        Total counts in each replicate (n > 0).
 
-    a : vector
-        Dirichlet alpha parameter. Elements must be strictly positive. The number of
-        categories is given by the length of the last axis.
+    a : float array
+        Dirichlet concentration parameters (a > 0). The number of categories is given by
+        the length of the last axis.
     """
     rv_op = dirichlet_multinomial
 

From 0549f9373eb5a6d6b9f654a6b39a6314e8afd7d6 Mon Sep 17 00:00:00 2001
From: Ricardo <ricardo.vieira1994@gmail.com>
Date: Mon, 6 Dec 2021 18:40:30 +0100
Subject: [PATCH 6/6] Remove stale/redundant random tests

---
 pymc/tests/test_distributions.py        | 35 -------------------------
 pymc/tests/test_distributions_random.py | 12 ---------
 2 files changed, 47 deletions(-)

diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py
index 55403d26ed..a4c1ee472d 100644
--- a/pymc/tests/test_distributions.py
+++ b/pymc/tests/test_distributions.py
@@ -2134,14 +2134,6 @@ def test_dirichlet_vectorized(self, a, size):
             err_msg=f"vals={vals}",
         )
 
-    def test_dirichlet_shape(self):
-        a = at.as_tensor_variable(np.r_[1, 2])
-        dir_rv = Dirichlet.dist(a)
-        assert dir_rv.shape.eval() == (2,)
-
-        with pytest.warns(DeprecationWarning), aesara.change_flags(compute_test_value="ignore"):
-            dir_rv = Dirichlet.dist(at.vector())
-
     @pytest.mark.parametrize("n", [2, 3])
     def test_multinomial(self, n):
         self.check_logp(
@@ -2151,33 +2143,6 @@ def test_multinomial(self, n):
             lambda value, n, p: scipy.stats.multinomial.logpmf(value, n, p),
         )
 
-    @pytest.mark.parametrize(
-        "p, size, n",
-        [
-            [[0.25, 0.25, 0.25, 0.25], (4,), 2],
-            [[0.25, 0.25, 0.25, 0.25], (1, 4), 3],
-            # 3: expect to fail
-            # [[.25, .25, .25, .25], (10, 4)],
-            [[0.25, 0.25, 0.25, 0.25], (10, 1, 4), 5],
-            # 5: expect to fail
-            # [[[.25, .25, .25, .25]], (2, 4), [7, 11]],
-            [[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (2, 4), 13],
-            [[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (1, 2, 4), [23, 29]],
-            [
-                [[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]],
-                (10, 2, 4),
-                [31, 37],
-            ],
-            [[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (2, 4), [17, 19]],
-        ],
-    )
-    def test_multinomial_random(self, p, size, n):
-        p = np.asarray(p)
-        with Model() as model:
-            m = Multinomial("m", n=n, p=p, size=size)
-
-        assert m.eval().shape == size + p.shape
-
     @pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
     @pytest.mark.parametrize(
         "p",
diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py
index 425f2ccc5e..0216de75fe 100644
--- a/pymc/tests/test_distributions_random.py
+++ b/pymc/tests/test_distributions_random.py
@@ -255,18 +255,6 @@ class TestGaussianRandomWalk(BaseTestCases.BaseTestCase):
     default_shape = (1,)
 
 
-@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
-class TestZeroInflatedNegativeBinomial(BaseTestCases.BaseTestCase):
-    distribution = pm.ZeroInflatedNegativeBinomial
-    params = {"mu": 1.0, "alpha": 1.0, "psi": 0.3}
-
-
-@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
-class TestZeroInflatedBinomial(BaseTestCases.BaseTestCase):
-    distribution = pm.ZeroInflatedBinomial
-    params = {"n": 10, "p": 0.6, "psi": 0.3}
-
-
 class BaseTestDistribution(SeededTest):
     """
     This class provides a base for tests that new RandomVariables are correctly