33from numba .types import Array , Boolean , List , Number
44
55import pytensor .link .numba .dispatch .basic as numba_basic
6- from pytensor .link .numba .dispatch .basic import register_funcify_default_op_cache_key
6+ from pytensor .link .numba .dispatch .basic import (
7+ default_hash_key_from_props ,
8+ register_funcify_and_cache_key ,
9+ register_funcify_default_op_cache_key ,
10+ )
711from pytensor .link .numba .dispatch .compile_ops import numba_deepcopy
812from pytensor .tensor .type_other import SliceType
913from pytensor .typed_list import (
@@ -48,12 +52,14 @@ def all_equal(x, y):
4852 return False
4953 return True
5054
51- if isinstance (x , Array ) and isinstance (y , Array ):
52-
55+ if ( isinstance (x , Array ) and x . ndim > 0 ) and ( isinstance (y , Array ) and y . ndim > 0 ):
56+ # (x == y).all() fails for 0d arrays
5357 def all_equal (x , y ):
5458 return (x == y ).all ()
5559
56- if isinstance (x , Number | Boolean ) and isinstance (y , Number | Boolean ):
60+ if (isinstance (x , Number | Boolean ) or (isinstance (x , Array ) and x .ndim == 0 )) and (
61+ isinstance (y , Number | Boolean ) or (isinstance (y , Array ) and y .ndim == 0 )
62+ ):
5763
5864 def all_equal (x , y ):
5965 return x == y
@@ -71,6 +77,16 @@ def deepcopy_list(x):
7177 return deepcopy_list
7278
7379
80+ def cache_key_if_not_inplace (op , inplace : bool ):
81+ if inplace :
82+ # NUMBA is misbehaving with wrapped inplace ListType operations
83+ # which happens when we cache it in PyTensor
84+ # https://github.com/numba/numba/issues/10356
85+ return None
86+ else :
87+ return default_hash_key_from_props (op )
88+
89+
7490@register_funcify_default_op_cache_key (MakeList )
7591def numba_funcify_make_list (op , node , ** kwargs ):
7692 @numba_basic .numba_njit
@@ -108,7 +124,7 @@ def list_get_item_index(x, index):
108124 return list_get_item_index
109125
110126
111- @register_funcify_default_op_cache_key (Reverse )
127+ @register_funcify_and_cache_key (Reverse )
112128def numba_funcify_list_reverse (op , node , ** kwargs ):
113129 inplace = op .inplace
114130
@@ -121,10 +137,10 @@ def list_reverse(x):
121137 z .reverse ()
122138 return z
123139
124- return list_reverse
140+ return list_reverse , cache_key_if_not_inplace ( op , inplace )
125141
126142
127- @register_funcify_default_op_cache_key (Append )
143+ @register_funcify_and_cache_key (Append )
128144def numba_funcify_list_append (op , node , ** kwargs ):
129145 inplace = op .inplace
130146
@@ -137,10 +153,10 @@ def list_append(x, to_append):
137153 z .append (numba_deepcopy (to_append ))
138154 return z
139155
140- return list_append
156+ return list_append , cache_key_if_not_inplace ( op , inplace )
141157
142158
143- @register_funcify_default_op_cache_key (Extend )
159+ @register_funcify_and_cache_key (Extend )
144160def numba_funcify_list_extend (op , node , ** kwargs ):
145161 inplace = op .inplace
146162
@@ -153,10 +169,10 @@ def list_extend(x, to_append):
153169 z .extend (numba_deepcopy (to_append ))
154170 return z
155171
156- return list_extend
172+ return list_extend , cache_key_if_not_inplace ( op , inplace )
157173
158174
159- @register_funcify_default_op_cache_key (Insert )
175+ @register_funcify_and_cache_key (Insert )
160176def numba_funcify_list_insert (op , node , ** kwargs ):
161177 inplace = op .inplace
162178
@@ -169,7 +185,7 @@ def list_insert(x, index, to_insert):
169185 z .insert (index .item (), numba_deepcopy (to_insert ))
170186 return z
171187
172- return list_insert
188+ return list_insert , cache_key_if_not_inplace ( op , inplace )
173189
174190
175191@register_funcify_default_op_cache_key (Index )
@@ -197,7 +213,7 @@ def list_count(x, elem):
197213 return list_count
198214
199215
200- @register_funcify_default_op_cache_key (Remove )
216+ @register_funcify_and_cache_key (Remove )
201217def numba_funcify_list_remove (op , node , ** kwargs ):
202218 inplace = op .inplace
203219
@@ -217,4 +233,4 @@ def list_remove(x, to_remove):
217233 z .pop (index_to_remove )
218234 return z
219235
220- return list_remove
236+ return list_remove , cache_key_if_not_inplace ( op , inplace )
0 commit comments