diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py
index a599bebea1..10c2bb14ad 100644
--- a/pymc/distributions/mixture.py
+++ b/pymc/distributions/mixture.py
@@ -524,10 +524,6 @@ class NormalMixture:
         the component standard deviations
     tau : tensor_like of float
         the component precisions
-    comp_shape : shape of the Normal component
-        notice that it should be different than the shape
-        of the mixture distribution, with the last axis representing
-        the number of components.
 
     Notes
     -----
@@ -554,16 +550,16 @@ class NormalMixture:
             y = pm.NormalMixture("y", w=weights, mu=μ, sigma=σ, observed=data)
     """
 
-    def __new__(cls, name, w, mu, sigma=None, tau=None, comp_shape=(), **kwargs):
+    def __new__(cls, name, w, mu, sigma=None, tau=None, **kwargs):
         _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
 
-        return Mixture(name, w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs)
+        return Mixture(name, w, Normal.dist(mu, sigma=sigma), **kwargs)
 
     @classmethod
-    def dist(cls, w, mu, sigma=None, tau=None, comp_shape=(), **kwargs):
+    def dist(cls, w, mu, sigma=None, tau=None, **kwargs):
         _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
 
-        return Mixture.dist(w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs)
+        return Mixture.dist(w, Normal.dist(mu, sigma=sigma), **kwargs)
 
 
 def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs):
diff --git a/tests/distributions/test_mixture.py b/tests/distributions/test_mixture.py
index d07a7be927..7ce6084d8d 100644
--- a/tests/distributions/test_mixture.py
+++ b/tests/distributions/test_mixture.py
@@ -820,10 +820,8 @@ def test_normal_mixture_nd(self, seeded_test, nd, ncomp):
             mus = Normal("mus", shape=comp_shape)
             taus = Gamma("taus", alpha=1, beta=1, shape=comp_shape)
             ws = Dirichlet("ws", np.ones(ncomp), shape=(ncomp,))
-            mixture0 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd, comp_shape=comp_shape)
-            obs0 = NormalMixture(
-                "obs", w=ws, mu=mus, tau=taus, comp_shape=comp_shape, observed=observed
-            )
+            mixture0 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd)
+            obs0 = NormalMixture("obs", w=ws, mu=mus, tau=taus, observed=observed)
 
         with Model() as model1:
             mus = Normal("mus", shape=comp_shape)
@@ -867,7 +865,6 @@ def ref_rand(size, w, mu, sigma):
                 "mu": Domain([[0.05, 2.5], [-5.0, 1.0]], edges=(None, None)),
                 "sigma": Domain([[1, 1], [1.5, 2.0]], edges=(None, None)),
             },
-            extra_args={"comp_shape": 2},
             size=1000,
             ref_rand=ref_rand,
         )
@@ -878,7 +875,6 @@ def ref_rand(size, w, mu, sigma):
                 "mu": Domain([[-5.0, 1.0, 2.5]], edges=(None, None)),
                 "sigma": Domain([[1.5, 2.0, 3.0]], edges=(None, None)),
             },
-            extra_args={"comp_shape": 3},
             size=1000,
             ref_rand=ref_rand,
         )
@@ -902,7 +898,6 @@ def test_scalar_components(self):
                 w=np.ones(npop) / npop,
                 mu=mus,
                 sigma=1e-5,
-                comp_shape=(nd, npop),
                 shape=nd,
             )
             z = Categorical("z", p=np.ones(npop) / npop, shape=nd)