11import numba
22import numpy as np
3+ from numba .types import Array , Boolean , List , Number
34
45import pytensor .link .numba .dispatch .basic as numba_basic
56from pytensor .link .numba .dispatch .basic import register_funcify_default_op_cache_key
@@ -37,7 +38,7 @@ def numba_all_equal(x, y):
3738def list_all_equal (x , y ):
3839 all_equal = None
3940
40- if isinstance (x , numba . types . List ) and isinstance (y , numba . types . List ):
41+ if isinstance (x , List ) and isinstance (y , List ):
4142
4243 def all_equal (x , y ):
4344 if len (x ) != len (y ):
@@ -47,12 +48,12 @@ def all_equal(x, y):
4748 return False
4849 return True
4950
50- if isinstance (x , numba . types . Array ) and isinstance (y , numba . types . Array ):
51+ if isinstance (x , Array ) and isinstance (y , Array ):
5152
5253 def all_equal (x , y ):
5354 return (x == y ).all ()
5455
55- if isinstance (x , numba . types . Number ) and isinstance (y . numba . types . Number ):
56+ if isinstance (x , Number | Boolean ) and isinstance (y , Number | Boolean ):
5657
5758 def all_equal (x , y ):
5859 return x == y
@@ -62,7 +63,7 @@ def all_equal(x, y):
6263
6364@numba .extending .overload (numba_deepcopy )
6465def numba_deepcopy_list (x ):
65- if isinstance (x , numba . types . List ):
66+ if isinstance (x , List ):
6667
6768 def deepcopy_list (x ):
6869 return [numba_deepcopy (xi ) for xi in x ]
0 commit comments