From 717ba1a73e0a3f5c043e8c8e65af9b05e11390e8 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt <adrian.seyboldt@gmail.com>
Date: Thu, 15 Jun 2023 21:56:00 -0500
Subject: [PATCH 1/3] Add rewrite for Sum(MakeVector)

---
 pytensor/tensor/rewriting/basic.py   | 25 +++++++++++++++++++++++++
 tests/tensor/rewriting/test_basic.py | 16 ++++++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 58a5918c12..96796a4dc0 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -43,6 +43,7 @@
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to
+from pytensor.tensor.math import Sum, add
 from pytensor.tensor.math import all as at_all
 from pytensor.tensor.math import eq
 from pytensor.tensor.shape import Shape_i
@@ -956,6 +957,30 @@ def local_join_make_vector(fgraph, node):
         return [ret]
 
 
+@register_specialize
+@register_canonicalize
+@register_useless
+@node_rewriter([Sum])
+def local_sum_make_vector(fgraph, node):
+    """Rewrite ``Sum(MakeVector(a, b, ...))`` as a direct ``add`` of the elements."""
+    (array,) = node.inputs
+
+    if array.owner is None:  # no owner node, hence no MakeVector to inspect
+        return
+
+    if not isinstance(array.owner.op, MakeVector):
+        return
+
+    if node.op.axis not in [None, 0, -1]:
+        return
+
+    elements = array.owner.inputs
+    dtype = node.op.acc_dtype
+    element_sum = add(*[cast(value, dtype) for value in elements])
+
+    return [as_tensor_variable(element_sum)]
+
+
 @register_useless("local_remove_switch_const_cond")
 @register_canonicalize("fast_compile", "local_remove_switch_const_cond")
 @register_specialize
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index fe2b795907..c6b0ee8c69 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -19,6 +19,7 @@
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.printing import debugprint, pprint
 from pytensor.raise_op import Assert, CheckAndRaise
+from pytensor.scalar.basic import Add
 from pytensor.tensor.basic import (
     Alloc,
     Join,
@@ -102,6 +103,7 @@
     values_eq_approx_remove_nan,
     vector,
 )
+from pytensor.tensor.var import TensorVariable
 from tests import unittest_tools as utt
 
 
@@ -1300,6 +1302,20 @@ def test_local_join_make_vector():
     assert check_stack_trace(f, ops_to_check="all")
 
 
+def test_local_sum_make_vector():
+    a, b, c = scalars("abc")
+    mv = MakeVector(config.floatX)
+    output = mv(a, b, c).sum()
+
+    func = function([a, b, c], output)
+
+    elemwise = func.maker.fgraph.outputs[0].owner
+    # The MakeVector op should be optimized away, so we just
+    # take the sum of the scalars.
+    assert elemwise.inputs[0].name == "a"
+    assert isinstance(elemwise.inputs[0], TensorVariable)
+
+
 @pytest.mark.parametrize(
     "dtype",
     [

From 287087478c566ad761e54fea8b694c70e15ccb40 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt <adrian.seyboldt@gmail.com>
Date: Fri, 16 Jun 2023 11:49:46 -0500
Subject: [PATCH 2/3] Improve local_sum_make_vector rewrite and its test

---
 pytensor/tensor/rewriting/basic.py   | 14 +++++++++-----
 tests/tensor/rewriting/test_basic.py | 24 +++++++++++++++---------
 2 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 96796a4dc0..8cd40469db 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -971,14 +971,18 @@ def local_sum_make_vector(fgraph, node):
     if not isinstance(array.owner.op, MakeVector):
         return
 
-    if node.op.axis not in [None, 0, -1]:
-        return
+    if node.op.axis == ():
+        return [array]
+
+    # If this is not the case the sum is invalid
+    assert node.op.axis is None or node.op.axis == (0,)
 
     elements = array.owner.inputs
-    dtype = node.op.acc_dtype
-    element_sum = add(*[cast(value, dtype) for value in elements])
+    acc_dtype = node.op.acc_dtype
+    out_dtype = node.op.dtype
+    element_sum = cast(add(*[cast(value, acc_dtype) for value in elements]), out_dtype)
 
-    return [as_tensor_variable(element_sum)]
+    return [element_sum]
 
 
 @register_useless("local_remove_switch_const_cond")
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index c6b0ee8c69..9b79db5f28 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -12,14 +12,13 @@
 from pytensor.compile.mode import get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp, deep_copy_op
 from pytensor.configdefaults import config
-from pytensor.graph.basic import equal_computations
+from pytensor.graph.basic import equal_computations, vars_between
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import check_stack_trace, out2in
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.printing import debugprint, pprint
 from pytensor.raise_op import Assert, CheckAndRaise
-from pytensor.scalar.basic import Add
 from pytensor.tensor.basic import (
     Alloc,
     Join,
@@ -32,6 +31,7 @@
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import (
+    Sum,
     add,
     bitwise_and,
     bitwise_or,
@@ -103,7 +103,6 @@
     values_eq_approx_remove_nan,
     vector,
 )
-from pytensor.tensor.var import TensorVariable
 from tests import unittest_tools as utt
 
 
@@ -1307,13 +1306,20 @@ def test_local_sum_make_vector():
     mv = MakeVector(config.floatX)
     output = mv(a, b, c).sum()
 
-    func = function([a, b, c], output)
+    output = rewrite_graph(output)
+    between = vars_between([a, b, c], [output])
+    for var in between:
+        assert (var.owner is None) or (not isinstance(var.owner.op, MakeVector))
 
-    elemwise = func.maker.fgraph.outputs[0].owner
-    # The MakeVector op should be optimized away, so we just
-    # take the sum of the scalars.
-    assert elemwise.inputs[0].name == "a"
-    assert isinstance(elemwise.inputs[0], TensorVariable)
+    # Check for empty sum
+    a, b, c = scalars("abc")
+    mv = MakeVector(config.floatX)
+    output = mv(a, b, c).sum(axis=[])
+
+    output = rewrite_graph(output)
+    between = vars_between([a, b, c], [output])
+    for var in between:
+        assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
 
 
 @pytest.mark.parametrize(

From 7429e945d4eaa5761fa4792079b4dee583197380 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt <adrian.seyboldt@gmail.com>
Date: Tue, 11 Jul 2023 20:03:31 -0500
Subject: [PATCH 3/3] fix(rewrite): Handle sum of empty make vector

---
 pytensor/tensor/rewriting/basic.py   | 11 +++++++++--
 tests/tensor/rewriting/test_basic.py | 17 +++++++++++++++++
 2 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 8cd40469db..ff103e9fc1 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -975,12 +975,24 @@ def local_sum_make_vector(fgraph, node):
-        return [array]
+        # Nothing is reduced, but the Sum's output dtype may still differ from
+        # the vector's dtype; cast so the replacement keeps the output type.
+        return [cast(array, node.op.dtype)]
 
     # If this is not the case the sum is invalid
-    assert node.op.axis is None or node.op.axis == (0,)
+    assert node.op.axis is None or node.op.axis == (0,) or node.op.axis == (-1,)
 
     elements = array.owner.inputs
     acc_dtype = node.op.acc_dtype
     out_dtype = node.op.dtype
-    element_sum = cast(add(*[cast(value, acc_dtype) for value in elements]), out_dtype)
+    if not elements:
+        # Summing an empty vector yields the additive identity.
+        element_sum = zeros(dtype=out_dtype, shape=())
+    elif len(elements) == 1:
+        # A single element is already the sum; only the dtype may change.
+        element_sum = cast(elements[0], out_dtype)
+    else:
+        # Accumulate in acc_dtype, then convert to the Sum's output dtype.
+        element_sum = cast(
+            add(*[cast(value, acc_dtype) for value in elements]), out_dtype
+        )
 
     return [element_sum]
 
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index 9b79db5f28..3c3f917bc9 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -1321,6 +1321,23 @@ def test_local_sum_make_vector():
     for var in between:
         assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
 
+    # Check empty MakeVector
+    mv = MakeVector(config.floatX)
+    output = mv().sum()
+
+    output = rewrite_graph(output)
+    between = vars_between([a, b, c], [output])
+    for var in between:
+        assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
+
+    mv = MakeVector(config.floatX)
+    output = mv(a).sum()
+
+    output = rewrite_graph(output)
+    between = vars_between([a, b, c], [output])
+    for var in between:
+        assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
+
 
 @pytest.mark.parametrize(
     "dtype",