
Commit afb4885

Use some global njit functions in numba
This allows numba to reuse previous typing and compilation results when the same function appears more than once, which also leads to smaller LLVM modules. For the tests to continue to work, those global functions are returned through a wrapper (`basic.global_numba_func`) so that tests can still disable compilation by mocking the wrapper. Also remove some inline="always" arguments that don't appear to be helpful.
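A minimal sketch of the resulting pattern (simplified: it uses plain `numba.njit` instead of the module's `numba_njit` helper, and `numba_funcify_Shape` stands in for any of the affected registrations; the mocking detail in the comment is only an illustration, not the literal test code):

    import numpy as np
    from numba import njit


    @njit
    def shape(x):
        # Defined once at module level, so numba can cache and reuse the
        # typing/compilation result each time the Op appears in a graph.
        return np.asarray(np.shape(x))


    def global_numba_func(func):
        """Return a global numba function; tests mock this to disable compilation."""
        return func


    def numba_funcify_Shape(op, **kwargs):
        # Returning through the wrapper gives tests a single patch point,
        # e.g. patching global_numba_func to return func.py_func instead.
        return global_numba_func(shape)
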
1 parent f4de2fd commit afb4885

File tree: 3 files changed (+152 −123 lines)

pytensor/link/numba/dispatch/basic.py

Lines changed: 64 additions & 50 deletions
@@ -2,6 +2,7 @@
 import sys
 import warnings
 from contextlib import contextmanager
+from copy import copy
 from functools import singledispatch
 from textwrap import dedent
 from typing import Union
@@ -15,7 +16,7 @@
 from numba import types
 from numba.core.errors import TypingError
 from numba.cpython.unsafe.tuple import tuple_setitem  # noqa: F401
-from numba.extending import box
+from numba.extending import box, overload

 from pytensor import config
 from pytensor.compile.builders import OpFromGraph
@@ -47,6 +48,14 @@
 from pytensor.tensor.type_other import MakeSlice, NoneConst


+def global_numba_func(func):
+    """Use to return global numba functions in numba_funcify_*.
+
+    This allows tests to remove the compilation using mock.
+    """
+    return func
+
+
 def numba_njit(*args, **kwargs):

     kwargs = kwargs.copy()
@@ -573,29 +582,36 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
     return numba_njit(incsubtensor_fn, boundscheck=True)


+@numba_njit(boundscheck=True)
+def advancedincsubtensor1_inplace_set(x, vals, idxs):
+    for idx, val in zip(idxs, vals):
+        x[idx] = val
+    return x
+
+
+@numba_njit(boundscheck=True)
+def advancedincsubtensor1_inplace_inc(x, vals, idxs):
+    for idx, val in zip(idxs, vals):
+        x[idx] += val
+    return x
+
+
 @numba_funcify.register(AdvancedIncSubtensor1)
 def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
     inplace = op.inplace
     set_instead_of_inc = op.set_instead_of_inc

     if set_instead_of_inc:
-
-        @numba_njit(boundscheck=True)
-        def advancedincsubtensor1_inplace(x, vals, idxs):
-            for idx, val in zip(idxs, vals):
-                x[idx] = val
-            return x
-
+        advancedincsubtensor1_inplace = global_numba_func(
+            advancedincsubtensor1_inplace_set
+        )
     else:
-
-        @numba_njit(boundscheck=True)
-        def advancedincsubtensor1_inplace(x, vals, idxs):
-            for idx, val in zip(idxs, vals):
-                x[idx] += val
-            return x
+        advancedincsubtensor1_inplace = global_numba_func(
+            advancedincsubtensor1_inplace_inc
+        )

     if inplace:
-        return advancedincsubtensor1_inplace
+        return global_numba_func(advancedincsubtensor1_inplace)
     else:

         @numba_njit
@@ -606,51 +622,48 @@ def advancedincsubtensor1(x, vals, idxs):
     return advancedincsubtensor1


-@numba_funcify.register(DeepCopyOp)
-def numba_funcify_DeepCopyOp(op, node, **kwargs):
+def deepcopyop(x):
+    return copy(x)

-    # Scalars are apparently returned as actual Python scalar types and not
-    # NumPy scalars, so we need two separate Numba functions for each case.

-    # The type can also be RandomType with no ndims
-    if not hasattr(node.outputs[0].type, "ndim") or node.outputs[0].type.ndim == 0:
-        # TODO: Do we really need to compile a pass-through function like this?
-        @numba_njit(inline="always")
-        def deepcopyop(x):
-            return x
+@overload(deepcopyop)
+def dispatch_deepcopyop(x):
+    if isinstance(x, types.Array):
+        return lambda x: np.copy(x)

-    else:
+    return lambda x: x

-        @numba_njit(inline="always")
-        def deepcopyop(x):
-            return x.copy()

+@numba_funcify.register(DeepCopyOp)
+def numba_funcify_DeepCopyOp(op, node, **kwargs):
     return deepcopyop


+@numba_njit
+def makeslice(*x):
+    return slice(*x)
+
+
 @numba_funcify.register(MakeSlice)
 def numba_funcify_MakeSlice(op, **kwargs):
-    @numba_njit
-    def makeslice(*x):
-        return slice(*x)
+    return global_numba_func(makeslice)

-    return makeslice
+
+@numba_njit
+def shape(x):
+    return np.asarray(np.shape(x))


 @numba_funcify.register(Shape)
 def numba_funcify_Shape(op, **kwargs):
-    @numba_njit(inline="always")
-    def shape(x):
-        return np.asarray(np.shape(x))
-
-    return shape
+    return global_numba_func(shape)


 @numba_funcify.register(Shape_i)
 def numba_funcify_Shape_i(op, **kwargs):
     i = op.i

-    @numba_njit(inline="always")
+    @numba_njit
     def shape_i(x):
         return np.shape(x)[i]

@@ -683,13 +696,13 @@ def numba_funcify_Reshape(op, **kwargs):

     if ndim == 0:

-        @numba_njit(inline="always")
+        @numba_njit
         def reshape(x, shape):
             return x.item()

     else:

-        @numba_njit(inline="always")
+        @numba_njit
         def reshape(x, shape):
             # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
             return np.reshape(
@@ -732,15 +745,15 @@ def int_to_float_fn(inputs, out_dtype):

         args_dtype = np.dtype(f"f{out_dtype.itemsize}")

-        @numba_njit(inline="always")
+        @numba_njit
         def inputs_cast(x):
             return x.astype(args_dtype)

     else:
         args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
         args_dtype = np.dtype(f"f{args_dtype_sz}")

-        @numba_njit(inline="always")
+        @numba_njit
         def inputs_cast(x):
             return x.astype(args_dtype)

@@ -755,7 +768,7 @@ def numba_funcify_Dot(op, node, **kwargs):
     out_dtype = node.outputs[0].type.numpy_dtype
     inputs_cast = int_to_float_fn(node.inputs, out_dtype)

-    @numba_njit(inline="always")
+    @numba_njit
     def dot(x, y):
         return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)

@@ -770,13 +783,14 @@ def numba_funcify_Softplus(op, node, **kwargs):
     @numba_njit
     def softplus(x):
         if x < -37.0:
-            return direct_cast(np.exp(x), x_dtype)
+            value = np.exp(x)
         elif x < 18.0:
-            return direct_cast(np.log1p(np.exp(x)), x_dtype)
+            value = np.log1p(np.exp(x))
         elif x < 33.3:
-            return direct_cast(x + np.exp(-x), x_dtype)
+            value = x + np.exp(-x)
         else:
-            return direct_cast(x, x_dtype)
+            value = x
+        return direct_cast(value, x_dtype)

     return softplus

@@ -791,7 +805,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):

     inputs_cast = int_to_float_fn(node.inputs, out_dtype)

-    @numba_njit(inline="always")
+    @numba_njit
     def cholesky(a):
         return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)

@@ -852,7 +866,7 @@ def solve(a, b):
     out_dtype = node.outputs[0].type.numpy_dtype
     inputs_cast = int_to_float_fn(node.inputs, out_dtype)

-    @numba_njit(inline="always")
+    @numba_njit
     def solve(a, b):
         return np.linalg.solve(
             inputs_cast(a),
