2
2
import sys
3
3
import warnings
4
4
from contextlib import contextmanager
5
+ from copy import copy
5
6
from functools import singledispatch
6
7
from textwrap import dedent
7
8
from typing import Union
15
16
from numba import types
16
17
from numba .core .errors import TypingError
17
18
from numba .cpython .unsafe .tuple import tuple_setitem # noqa: F401
18
- from numba .extending import box
19
+ from numba .extending import box , overload
19
20
20
21
from pytensor import config
21
22
from pytensor .compile .builders import OpFromGraph
47
48
from pytensor .tensor .type_other import MakeSlice , NoneConst
48
49
49
50
51
def global_numba_func(func):
    """Hand back a module-level numba function without modifying it.

    ``numba_funcify_*`` implementations route the global (module-level)
    numba functions they return through this hook.  Because every such
    return passes through one well-known name, tests can mock this function
    to remove the compilation step.

    Parameters
    ----------
    func
        The module-level numba function to return.

    Returns
    -------
    The very same ``func`` object that was passed in.
    """
    return func
57
+
58
+
50
59
def numba_njit (* args , ** kwargs ):
51
60
52
61
kwargs = kwargs .copy ()
@@ -573,29 +582,36 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
573
582
return numba_njit (incsubtensor_fn , boundscheck = True )
574
583
575
584
585
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace_set(x, vals, idxs):
    """Overwrite ``x[idxs[k]]`` with ``vals[k]`` in place and return ``x``.

    ``idxs`` and ``vals`` are walked in lockstep; iteration stops at the
    end of the shorter of the two.  Compiled with ``boundscheck=True``.
    """
    for position, value in zip(idxs, vals):
        x[position] = value
    return x
590
+
591
+
592
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace_inc(x, vals, idxs):
    """Add ``vals[k]`` onto ``x[idxs[k]]`` in place and return ``x``.

    The explicit loop applies one increment per ``(idx, val)`` pair, so an
    index that appears several times in ``idxs`` is incremented several
    times.  ``idxs`` and ``vals`` are walked in lockstep; iteration stops
    at the end of the shorter of the two.  Compiled with
    ``boundscheck=True``.
    """
    for position, delta in zip(idxs, vals):
        x[position] += delta
    return x
597
+
598
+
576
599
@numba_funcify .register (AdvancedIncSubtensor1 )
577
600
def numba_funcify_AdvancedIncSubtensor1 (op , node , ** kwargs ):
578
601
inplace = op .inplace
579
602
set_instead_of_inc = op .set_instead_of_inc
580
603
581
604
if set_instead_of_inc :
582
-
583
- @numba_njit (boundscheck = True )
584
- def advancedincsubtensor1_inplace (x , vals , idxs ):
585
- for idx , val in zip (idxs , vals ):
586
- x [idx ] = val
587
- return x
588
-
605
+ advancedincsubtensor1_inplace = global_numba_func (
606
+ advancedincsubtensor1_inplace_set
607
+ )
589
608
else :
590
-
591
- @numba_njit (boundscheck = True )
592
- def advancedincsubtensor1_inplace (x , vals , idxs ):
593
- for idx , val in zip (idxs , vals ):
594
- x [idx ] += val
595
- return x
609
+ advancedincsubtensor1_inplace = global_numba_func (
610
+ advancedincsubtensor1_inplace_inc
611
+ )
596
612
597
613
if inplace :
598
- return advancedincsubtensor1_inplace
614
+ return global_numba_func ( advancedincsubtensor1_inplace )
599
615
else :
600
616
601
617
@numba_njit
@@ -606,51 +622,48 @@ def advancedincsubtensor1(x, vals, idxs):
606
622
return advancedincsubtensor1
607
623
608
624
609
- @ numba_funcify . register ( DeepCopyOp )
610
- def numba_funcify_DeepCopyOp ( op , node , ** kwargs ):
625
def deepcopyop(x):
    """Return an independent deep copy of ``x`` (pure-Python implementation).

    This is the plain-Python body of the deep-copy operation; the numba
    implementation is attached to this function through ``@overload``.

    A shallow ``copy`` would leave nested or mutable members shared between
    the input and the result, which contradicts deep-copy semantics, so
    ``copy.deepcopy`` is used instead.

    Parameters
    ----------
    x
        The value to duplicate.

    Returns
    -------
    A deep copy of ``x``.
    """
    # Imported locally so this block does not depend on a module-level
    # ``deepcopy`` import being present.
    from copy import deepcopy

    return deepcopy(x)
611
627
612
- # Scalars are apparently returned as actual Python scalar types and not
613
- # NumPy scalars, so we need two separate Numba functions for each case.
614
628
615
- # The type can also be RandomType with no ndims
616
- if not hasattr (node .outputs [0 ].type , "ndim" ) or node .outputs [0 ].type .ndim == 0 :
617
- # TODO: Do we really need to compile a pass-through function like this?
618
- @numba_njit (inline = "always" )
619
- def deepcopyop (x ):
620
- return x
629
@overload(deepcopyop)
def dispatch_deepcopyop(x):
    """Select the numba implementation of ``deepcopyop`` from the type of ``x``.

    Array-typed arguments get an implementation that calls ``np.copy``;
    every other type gets an implementation that returns its argument
    unchanged (no copy is made).
    """

    def passthrough_impl(x):
        return x

    def array_copy_impl(x):
        return np.copy(x)

    if not isinstance(x, types.Array):
        return passthrough_impl
    return array_copy_impl
623
635
624
- @numba_njit (inline = "always" )
625
- def deepcopyop (x ):
626
- return x .copy ()
627
636
637
@numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
    """Return the function implementing ``DeepCopyOp`` for the numba backend.

    ``op``, ``node`` and ``kwargs`` are accepted for the dispatch signature
    but are not consulted: the same module-level ``deepcopyop`` is returned
    for every node.
    """
    implementation = deepcopyop
    return implementation
629
640
630
641
642
@numba_njit
def makeslice(*x):
    """Build a ``slice`` object from the positional arguments in ``x``."""
    return slice(*x)
645
+
646
+
631
647
@numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
    """Return the numba function implementing ``MakeSlice``.

    The module-level ``makeslice`` function is returned through
    ``global_numba_func``; ``op`` and ``kwargs`` are not consulted.
    """
    implementation = global_numba_func(makeslice)
    return implementation
651
+
652
@numba_njit
def shape(x):
    """Return the shape of ``x`` as an array (``np.asarray`` of ``np.shape(x)``)."""
    dims = np.shape(x)
    return np.asarray(dims)
638
655
639
656
640
657
@numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs):
    """Return the numba function implementing ``Shape``.

    The module-level ``shape`` function is returned through
    ``global_numba_func``; ``op`` and ``kwargs`` are not consulted.
    """
    implementation = global_numba_func(shape)
    return implementation
647
660
648
661
649
662
@numba_funcify .register (Shape_i )
650
663
def numba_funcify_Shape_i (op , ** kwargs ):
651
664
i = op .i
652
665
653
- @numba_njit ( inline = "always" )
666
+ @numba_njit
654
667
def shape_i (x ):
655
668
return np .shape (x )[i ]
656
669
@@ -683,13 +696,13 @@ def numba_funcify_Reshape(op, **kwargs):
683
696
684
697
if ndim == 0 :
685
698
686
- @numba_njit ( inline = "always" )
699
+ @numba_njit
687
700
def reshape (x , shape ):
688
701
return x .item ()
689
702
690
703
else :
691
704
692
- @numba_njit ( inline = "always" )
705
+ @numba_njit
693
706
def reshape (x , shape ):
694
707
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
695
708
return np .reshape (
@@ -732,15 +745,15 @@ def int_to_float_fn(inputs, out_dtype):
732
745
733
746
args_dtype = np .dtype (f"f{ out_dtype .itemsize } " )
734
747
735
- @numba_njit ( inline = "always" )
748
+ @numba_njit
736
749
def inputs_cast (x ):
737
750
return x .astype (args_dtype )
738
751
739
752
else :
740
753
args_dtype_sz = max (_arg .type .numpy_dtype .itemsize for _arg in inputs )
741
754
args_dtype = np .dtype (f"f{ args_dtype_sz } " )
742
755
743
- @numba_njit ( inline = "always" )
756
+ @numba_njit
744
757
def inputs_cast (x ):
745
758
return x .astype (args_dtype )
746
759
@@ -755,7 +768,7 @@ def numba_funcify_Dot(op, node, **kwargs):
755
768
out_dtype = node .outputs [0 ].type .numpy_dtype
756
769
inputs_cast = int_to_float_fn (node .inputs , out_dtype )
757
770
758
- @numba_njit ( inline = "always" )
771
+ @numba_njit
759
772
def dot (x , y ):
760
773
return np .asarray (np .dot (inputs_cast (x ), inputs_cast (y ))).astype (out_dtype )
761
774
@@ -770,13 +783,14 @@ def numba_funcify_Softplus(op, node, **kwargs):
770
783
@numba_njit
771
784
def softplus (x ):
772
785
if x < - 37.0 :
773
- return direct_cast ( np .exp (x ), x_dtype )
786
+ value = np .exp (x )
774
787
elif x < 18.0 :
775
- return direct_cast ( np .log1p (np .exp (x )), x_dtype )
788
+ value = np .log1p (np .exp (x ))
776
789
elif x < 33.3 :
777
- return direct_cast ( x + np .exp (- x ), x_dtype )
790
+ value = x + np .exp (- x )
778
791
else :
779
- return direct_cast (x , x_dtype )
792
+ value = x
793
+ return direct_cast (value , x_dtype )
780
794
781
795
return softplus
782
796
@@ -791,7 +805,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
791
805
792
806
inputs_cast = int_to_float_fn (node .inputs , out_dtype )
793
807
794
- @numba_njit ( inline = "always" )
808
+ @numba_njit
795
809
def cholesky (a ):
796
810
return np .linalg .cholesky (inputs_cast (a )).astype (out_dtype )
797
811
@@ -852,7 +866,7 @@ def solve(a, b):
852
866
out_dtype = node .outputs [0 ].type .numpy_dtype
853
867
inputs_cast = int_to_float_fn (node .inputs , out_dtype )
854
868
855
- @numba_njit ( inline = "always" )
869
+ @numba_njit
856
870
def solve (a , b ):
857
871
return np .linalg .solve (
858
872
inputs_cast (a ),
0 commit comments