diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 58a5918c12..ff103e9fc1 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -43,6 +43,7 @@
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to
+from pytensor.tensor.math import Sum, add
 from pytensor.tensor.math import all as at_all
 from pytensor.tensor.math import eq
 from pytensor.tensor.shape import Shape_i
@@ -956,6 +957,52 @@ def local_join_make_vector(fgraph, node):
         return [ret]
 
 
+@register_specialize
+@register_canonicalize
+@register_useless
+@node_rewriter([Sum])
+def local_sum_make_vector(fgraph, node):
+    """Rewrite the sum of a ``MakeVector`` into a direct scalar addition.
+
+    The intermediate vector is never built: the elements are cast to the
+    ``Sum`` op's accumulator dtype, added together, and the result is cast
+    to the output dtype, mirroring the dtypes ``Sum`` itself would use.
+    """
+    (array,) = node.inputs
+
+    # Only applies when the summed value is literally a MakeVector node.
+    if array.owner is None:
+        return
+
+    if not isinstance(array.owner.op, MakeVector):
+        return
+
+    # Summing over no axis is the identity: drop the useless Sum.
+    if node.op.axis == ():
+        return [array]
+
+    # A MakeVector output is 1-d, so a non-empty axis can only address
+    # that single dimension; anything else would be an invalid sum.
+    assert node.op.axis is None or node.op.axis in ((0,), (-1,))
+
+    elements = array.owner.inputs
+    acc_dtype = node.op.acc_dtype
+    out_dtype = node.op.dtype
+    if not elements:
+        # Sum of an empty vector is a zero scalar of the output dtype.
+        element_sum = zeros(dtype=out_dtype, shape=())
+    elif len(elements) == 1:
+        # A single element only needs the dtype conversion.
+        element_sum = cast(elements[0], out_dtype)
+    else:
+        # Accumulate in acc_dtype to match Sum's numerical behavior.
+        element_sum = cast(
+            add(*[cast(value, acc_dtype) for value in elements]), out_dtype
+        )
+
+    return [element_sum]
+
+
 @register_useless("local_remove_switch_const_cond")
 @register_canonicalize("fast_compile", "local_remove_switch_const_cond")
 @register_specialize
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index fe2b795907..3c3f917bc9 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -12,7 +12,7 @@
 from pytensor.compile.mode import get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp, deep_copy_op
 from pytensor.configdefaults import config
-from pytensor.graph.basic import equal_computations
+from pytensor.graph.basic import equal_computations, vars_between
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import check_stack_trace, out2in
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -31,6 +31,7 @@
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import (
+    Sum,
     add,
     bitwise_and,
     bitwise_or,
@@ -1300,6 +1301,35 @@ def test_local_join_make_vector():
     assert check_stack_trace(f, ops_to_check="all")
 
 
+def test_local_sum_make_vector():
+    """``sum(make_vector(...))`` is rewritten into a direct scalar sum."""
+
+    def assert_rewritten_without(inputs, output, banned_op):
+        # After rewriting, no `banned_op` node may remain between the
+        # graph inputs and the rewritten output.
+        output = rewrite_graph(output)
+        for var in vars_between(inputs, [output]):
+            assert var.owner is None or not isinstance(var.owner.op, banned_op)
+
+    # The intermediate vector is eliminated in favor of a scalar addition.
+    a, b, c = scalars("abc")
+    mv = MakeVector(config.floatX)
+    assert_rewritten_without([a, b, c], mv(a, b, c).sum(), MakeVector)
+
+    # Summing over no axis is the identity, so the Sum node disappears.
+    a, b, c = scalars("abc")
+    mv = MakeVector(config.floatX)
+    assert_rewritten_without([a, b, c], mv(a, b, c).sum(axis=[]), Sum)
+
+    # An empty MakeVector sums to a zero constant; no Sum remains.
+    mv = MakeVector(config.floatX)
+    assert_rewritten_without([], mv().sum(), Sum)
+
+    # A single-element MakeVector sums to (a cast of) the element itself.
+    mv = MakeVector(config.floatX)
+    assert_rewritten_without([a], mv(a).sum(), Sum)
+
+
 @pytest.mark.parametrize(
     "dtype",
     [