Commit 04b0137

Numba list Ops: do not cache when inplace
numba/numba#10356
1 parent 3cfd6e4 commit 04b0137

File tree

3 files changed: +94 -42 lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 32 additions & 26 deletions
@@ -368,6 +368,36 @@ def dispatch_func_wrapper(*args, **kwargs):
     return decorator
 
 
+def default_hash_key_from_props(op, **extra_fields):
+    props_dict = op._props_dict()
+    if not props_dict:
+        # Simple op, just use the type string as key
+        hash = sha256(
+            f"({type(op)}, {tuple(extra_fields.items())})".encode()
+        ).hexdigest()
+    else:
+        # Simple props, can use string representation of props as key
+        simple_types = (str, bool, int, type(None), float)
+        container_types = (tuple, frozenset)
+        if all(
+            isinstance(v, simple_types)
+            or (
+                isinstance(v, container_types)
+                and all(isinstance(i, simple_types) for i in v)
+            )
+            for v in props_dict.values()
+        ):
+            hash = sha256(
+                f"({type(op)}, {tuple(props_dict.items())}, {tuple(extra_fields.items())})".encode()
+            ).hexdigest()
+        else:
+            # Complex props, use pickle to serialize them
+            hash = hash_from_pickle_dump(
+                (str(type(op)), tuple(props_dict.items()), tuple(extra_fields.items())),
+            )
+    return hash
+
+
 @singledispatch
 def numba_funcify_and_cache_key(op, node=None, **kwargs) -> tuple[Callable, str | None]:
     """Funcify an Op and return a unique cache key that can be used by numba caching.
@@ -411,36 +441,12 @@ def numba_funcify_and_cache_key(op, node=None, **kwargs) -> tuple[Callable, str
     else:
         func, integer_str = func_and_int, "None"
 
-    try:
-        props_dict = op._props_dict()
-    except AttributeError:
+    if not hasattr(op, "__props__"):
         raise ValueError(
             "The function wrapped by `numba_funcify_default_op_cache_key` can only be used with Ops with `_props`, "
             f"but {op} of type {type(op)} has no _props defined (not even empty)."
         )
-    if not props_dict:
-        # Simple op, just use the type string as key
-        hash = sha256(f"({type(op)}, {integer_str})".encode()).hexdigest()
-    else:
-        # Simple props, can use string representation of props as key
-        simple_types = (str, bool, int, type(None), float)
-        container_types = (tuple, frozenset)
-        if all(
-            isinstance(v, simple_types)
-            or (
-                isinstance(v, container_types)
-                and all(isinstance(i, simple_types) for i in v)
-            )
-            for v in props_dict.values()
-        ):
-            hash = sha256(
-                f"({type(op)}, {tuple(props_dict.items())}, {integer_str})".encode()
-            ).hexdigest()
-        else:
-            # Complex props, use pickle to serialize them
-            hash = hash_from_pickle_dump(
-                (str(type(op)), tuple(props_dict.items()), integer_str),
-            )
+    hash = default_hash_key_from_props(op, cache_version=integer_str)
     return func, hash
 
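The extracted `default_hash_key_from_props` helper keys numba's on-disk cache on an op's `__props__`. Below is a minimal, standalone sketch of the middle tier (all prop values are simple scalars, so the key is a sha256 of their string representation); `ToyOp` is a hypothetical stand-in, not PyTensor's real `Op` class:

    from hashlib import sha256

    class ToyOp:
        """Hypothetical stand-in for an Op with simple __props__."""

        __props__ = ("inplace",)

        def __init__(self, inplace):
            self.inplace = inplace

        def _props_dict(self):
            return {"inplace": self.inplace}

    op = ToyOp(inplace=False)
    extra_fields = {"cache_version": "None"}  # mirrors the integer_str passed above
    # All prop values are simple (bool here), so sha256 of their repr suffices;
    # unlike Python's randomized hash(), this is stable across processes.
    key = sha256(
        f"({type(op)}, {tuple(op._props_dict().items())}, {tuple(extra_fields.items())})".encode()
    ).hexdigest()
    print(key[:16])

Two ops of the same type with equal simple props map to the same key, so a cached compilation can be reused; props that are not simple scalars or containers of them fall back to the pickle-based hash, since their repr may not be unique or stable.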
pytensor/link/numba/dispatch/typed_list.py

Lines changed: 30 additions & 14 deletions
@@ -3,7 +3,11 @@
 from numba.types import Array, Boolean, List, Number
 
 import pytensor.link.numba.dispatch.basic as numba_basic
-from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
+from pytensor.link.numba.dispatch.basic import (
+    default_hash_key_from_props,
+    register_funcify_and_cache_key,
+    register_funcify_default_op_cache_key,
+)
 from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
 from pytensor.tensor.type_other import SliceType
 from pytensor.typed_list import (
@@ -48,12 +52,14 @@ def all_equal(x, y):
                 return False
             return True
 
-    if isinstance(x, Array) and isinstance(y, Array):
-
+    if (isinstance(x, Array) and x.ndim > 0) and (isinstance(y, Array) and y.ndim > 0):
+        # (x == y).all() fails for 0d arrays
         def all_equal(x, y):
             return (x == y).all()
 
-    if isinstance(x, Number | Boolean) and isinstance(y, Number | Boolean):
+    if (isinstance(x, Number | Boolean) or (isinstance(x, Array) and x.ndim == 0)) and (
+        isinstance(y, Number | Boolean) or (isinstance(y, Array) and y.ndim == 0)
+    ):
 
         def all_equal(x, y):
             return x == y
@@ -71,6 +77,16 @@ def deepcopy_list(x):
     return deepcopy_list
 
 
+def cache_key_if_not_inplace(op, inplace: bool):
+    if inplace:
+        # NUMBA is misbehaving with wrapped inplace ListType operations
+        # which happens when we cache it in PyTensor
+        # https://github.com/numba/numba/issues/10356
+        return None
+    else:
+        return default_hash_key_from_props(op)
+
+
 @register_funcify_default_op_cache_key(MakeList)
 def numba_funcify_make_list(op, node, **kwargs):
     @numba_basic.numba_njit
@@ -108,7 +124,7 @@ def list_get_item_index(x, index):
     return list_get_item_index
 
 
-@register_funcify_default_op_cache_key(Reverse)
+@register_funcify_and_cache_key(Reverse)
 def numba_funcify_list_reverse(op, node, **kwargs):
     inplace = op.inplace
 
@@ -121,10 +137,10 @@ def list_reverse(x):
         z.reverse()
         return z
 
-    return list_reverse
+    return list_reverse, cache_key_if_not_inplace(op, inplace)
 
 
-@register_funcify_default_op_cache_key(Append)
+@register_funcify_and_cache_key(Append)
 def numba_funcify_list_append(op, node, **kwargs):
     inplace = op.inplace
 
@@ -137,10 +153,10 @@ def list_append(x, to_append):
         z.append(numba_deepcopy(to_append))
         return z
 
-    return list_append
+    return list_append, cache_key_if_not_inplace(op, inplace)
 
 
-@register_funcify_default_op_cache_key(Extend)
+@register_funcify_and_cache_key(Extend)
 def numba_funcify_list_extend(op, node, **kwargs):
     inplace = op.inplace
 
@@ -153,10 +169,10 @@ def list_extend(x, to_append):
         z.extend(numba_deepcopy(to_append))
         return z
 
-    return list_extend
+    return list_extend, cache_key_if_not_inplace(op, inplace)
 
 
-@register_funcify_default_op_cache_key(Insert)
+@register_funcify_and_cache_key(Insert)
 def numba_funcify_list_insert(op, node, **kwargs):
     inplace = op.inplace
 
@@ -169,7 +185,7 @@ def list_insert(x, index, to_insert):
         z.insert(index.item(), numba_deepcopy(to_insert))
         return z
 
-    return list_insert
+    return list_insert, cache_key_if_not_inplace(op, inplace)
 
 
 @register_funcify_default_op_cache_key(Index)
@@ -197,7 +213,7 @@ def list_count(x, elem):
     return list_count
 
 
-@register_funcify_default_op_cache_key(Remove)
+@register_funcify_and_cache_key(Remove)
 def numba_funcify_list_remove(op, node, **kwargs):
     inplace = op.inplace
 
@@ -217,4 +233,4 @@ def list_remove(x, to_remove):
         z.pop(index_to_remove)
         return z
 
-    return list_remove
+    return list_remove, cache_key_if_not_inplace(op, inplace)
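Two independent fixes land in this file. First, 0d arrays are now routed to the scalar `x == y` branch, since, per the inline comment, `(x == y).all()` does not work for 0d inputs under numba. A small sketch of that routing (the failing call is left commented out; the behavior is taken from the diff's comment rather than re-verified here):

    import numpy as np
    from numba import njit

    @njit
    def all_equal_nd(x, y):
        # the branch used for arrays with ndim > 0
        return (x == y).all()

    @njit
    def all_equal_0d(x, y):
        # the branch 0d arrays are now routed to
        return x == y

    a = np.array(1.0)  # a 0d array
    print(all_equal_0d(a, a))
    # all_equal_nd(a, a)  # per the diff's comment, this fails for 0d inputs

Second, every registration for an op with an `inplace` flag switches from `register_funcify_default_op_cache_key` to `register_funcify_and_cache_key` and returns its key via `cache_key_if_not_inplace`, which yields `None` (disabling numba's on-disk caching) for the inplace variants affected by numba/numba#10356 while keeping the props-based key for out-of-place ops.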

tests/link/numba/test_typed_list.py

Lines changed: 32 additions & 2 deletions
@@ -1,7 +1,8 @@
 import numpy as np
 
-from pytensor.tensor import matrix
-from pytensor.typed_list import make_list
+from pytensor import In
+from pytensor.tensor import as_tensor, lscalar, matrix
+from pytensor.typed_list import TypedListType, make_list
 from tests.link.numba.test_basic import compare_numba_and_py
 
 
@@ -44,3 +45,32 @@ def test_make_list_find_ops():
     x_test = np.arange(12).reshape(3, 4)
     test_y = x_test[2]
     compare_numba_and_py([x, y], [l.ind(y), l.count(y), l.remove(y)], [x_test, test_y])
+
+
+def test_inplace_ops():
+    int64_list = TypedListType(lscalar)
+    ls = [int64_list(f"list[{i}]") for i in range(5)]
+    to_extend = lscalar("to_extend")
+
+    ls_test = [np.arange(3, dtype="int64").tolist() for _ in range(5)]
+    to_extend_test = np.array(99, dtype="int64")
+
+    def as_lscalar(x):
+        return as_tensor(x, ndim=0, dtype="int64")
+
+    fn, _ = compare_numba_and_py(
+        [*(In(l, mutable=True) for l in ls), to_extend],
+        [
+            ls[0].reverse(),
+            ls[1].append(as_lscalar(99)),
+            # This fails because it gets constant folded
+            # ls_to_extend = make_list([as_lscalar(99), as_lscalar(100)])
+            ls[2].extend(make_list([to_extend, to_extend + 1])),
+            ls[3].insert(as_lscalar(1), as_lscalar(99)),
+            ls[4].remove(as_lscalar(2)),
+        ],
+        [*ls_test, to_extend_test],
+        numba_mode="NUMBA",  # So it triggers inplace
+    )
+    for out in fn.maker.fgraph.outputs:
+        assert out.owner.op.destroy_map
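The new test marks every list input with `In(..., mutable=True)` and compiles in the NUMBA mode so the inplace rewrites fire, then asserts that each output op has a `destroy_map`. A stripped-down sketch of the same setup for a single op (assuming, as the test does, that the inplace rewrite fires under this mode):

    import pytensor
    from pytensor import In
    from pytensor.tensor import lscalar
    from pytensor.typed_list import TypedListType

    x = TypedListType(lscalar)("x")
    # mutable=True lets the rewriter destroy the input, enabling the inplace Reverse
    fn = pytensor.function([In(x, mutable=True)], x.reverse(), mode="NUMBA")
    assert fn.maker.fgraph.outputs[0].owner.op.destroy_map

Because the funcified inplace op now returns a `None` cache key, this compilation is never written to numba's on-disk cache, which is exactly the workaround this commit introduces.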
