Skip to content

Commit 8cc489b

Browse files
committed
Use scalar variables on Numba Elemwise dispatch
1 parent 8267d0e commit 8cc489b

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,19 @@
3030
OR,
3131
XOR,
3232
Add,
33-
Composite,
3433
IntDiv,
3534
Mul,
3635
ScalarMaximum,
3736
ScalarMinimum,
3837
Sub,
3938
TrueDiv,
39+
get_scalar_type,
4040
scalar_maximum,
4141
)
4242
from pytensor.scalar.basic import add as add_as
4343
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
4444
from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
4545
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
46-
from pytensor.tensor.type import scalar
4746

4847

4948
@singledispatch
@@ -348,13 +347,8 @@ def axis_apply_fn(x):
348347

349348
@numba_funcify.register(Elemwise)
350349
def numba_funcify_Elemwise(op, node, **kwargs):
351-
# Creating a new scalar node is more involved and unnecessary
352-
# if the scalar_op is composite, as the fgraph already contains
353-
# all the necessary information.
354-
scalar_node = None
355-
if not isinstance(op.scalar_op, Composite):
356-
scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
357-
scalar_node = op.scalar_op.make_node(*scalar_inputs)
350+
scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs]
351+
scalar_node = op.scalar_op.make_node(*scalar_inputs)
358352

359353
scalar_op_fn = numba_funcify(
360354
op.scalar_op,

tests/link/numba/test_basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,11 +267,11 @@ def assert_fn(x, y):
267267
x, y
268268
)
269269

270-
if isinstance(fgraph, tuple):
271-
fn_inputs, fn_outputs = fgraph
272-
else:
270+
if isinstance(fgraph, FunctionGraph):
273271
fn_inputs = fgraph.inputs
274272
fn_outputs = fgraph.outputs
273+
else:
274+
fn_inputs, fn_outputs = fgraph
275275

276276
fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)]
277277

tests/link/numba/test_elemwise.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from pytensor.gradient import grad
1616
from pytensor.graph.basic import Constant
1717
from pytensor.graph.fg import FunctionGraph
18-
from pytensor.tensor.elemwise import CAReduce, DimShuffle
18+
from pytensor.scalar import float64
19+
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
1920
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
2021
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
2122
from tests.link.numba.test_basic import (
@@ -691,3 +692,17 @@ def test_numba_careduce_benchmark(axis, c_contiguous, benchmark):
691692
return careduce_benchmark_tester(
692693
axis, c_contiguous, mode="NUMBA", benchmark=benchmark
693694
)
695+
696+
697+
def test_scalar_loop():
698+
a = float64("a")
699+
scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
700+
701+
x = pt.tensor("x", shape=(3,))
702+
elemwise_loop = Elemwise(scalar_loop)(3, x)
703+
704+
with pytest.warns(UserWarning, match="object mode"):
705+
compare_numba_and_py(
706+
([x], [elemwise_loop]),
707+
(np.array([1, 2, 3], dtype="float64"),),
708+
)

0 commit comments

Comments (0)