File tree Expand file tree Collapse file tree 2 files changed +25
-1
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 2 files changed +25
-1
lines changed Original file line number Diff line number Diff line change @@ -980,7 +980,14 @@ def local_sum_make_vector(fgraph, node):
980980 elements = array .owner .inputs
981981 acc_dtype = node .op .acc_dtype
982982 out_dtype = node .op .dtype
983- element_sum = cast (add (* [cast (value , acc_dtype ) for value in elements ]), out_dtype )
983+ if len (elements ) == 0 :
984+ element_sum = zeros (dtype = out_dtype , shape = ())
985+ elif len (elements ) == 1 :
986+ element_sum = cast (elements [0 ], out_dtype )
987+ else :
988+ element_sum = cast (
989+ add (* [cast (value , acc_dtype ) for value in elements ]), out_dtype
990+ )
984991
985992 return [element_sum ]
986993
Original file line number Diff line number Diff line change @@ -1321,6 +1321,23 @@ def test_local_sum_make_vector():
13211321 for var in between :
13221322 assert (var .owner is None ) or (not isinstance (var .owner .op , Sum ))
13231323
1324+ # Check empty MakeVector
1325+ mv = MakeVector (config .floatX )
1326+ output = mv ().sum ()
1327+
1328+ output = rewrite_graph (output )
1329+ between = vars_between ([a , b , c ], [output ])
1330+ for var in between :
1331+ assert (var .owner is None ) or (not isinstance (var .owner .op , Sum ))
1332+
1333+ mv = MakeVector (config .floatX )
1334+ output = mv (a ).sum ()
1335+
1336+ output = rewrite_graph (output )
1337+ between = vars_between ([a , b , c ], [output ])
1338+ for var in between :
1339+ assert (var .owner is None ) or (not isinstance (var .owner .op , Sum ))
1340+
13241341
13251342@pytest .mark .parametrize (
13261343 "dtype" ,
You can’t perform that action at this time.
0 commit comments