Skip to content

Commit 14e7262

Browse files
committed
Numba Scan: prevent alias of outputs
Also simplified test. Shared variables aren't needed for the test and clobber it
1 parent 7a46ac2 commit 14e7262

File tree

3 files changed

+58
-29
lines changed

3 files changed

+58
-29
lines changed

pytensor/link/numba/dispatch/compile_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
6161
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
6262
accept_inplace=True,
6363
)
64+
# TODO: Prevent output aliasing like we do for Scan/outer function
6465
NUMBA.optimizer(fgraph)
6566
fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(
6667
op.fgraph, squeeze_output=True, **kwargs

pytensor/link/numba/dispatch/scan.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
from numba import types
66
from numba.extending import overload
77

8-
from pytensor import In
9-
from pytensor.compile.function.types import add_supervisor_to_fgraph
8+
from pytensor.compile.function.types import add_supervisor_to_fgraph, insert_deepcopy
9+
from pytensor.compile.io import In, Out
1010
from pytensor.compile.mode import NUMBA, get_mode
1111
from pytensor.link.numba.cache import compile_numba_function_src
1212
from pytensor.link.numba.dispatch import basic as numba_basic
1313
from pytensor.link.numba.dispatch.basic import (
14-
create_arg_string,
1514
create_tuple_string,
1615
numba_funcify_and_cache_key,
1716
register_funcify_and_cache_key,
@@ -89,14 +88,15 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
8988
if outer_mitsot.type.shape[0] == abs(min(taps))
9089
]
9190
destroyable = {*destroyable_sitsot, *destroyable_mitsot}
91+
input_specs = [In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs]
9292
add_supervisor_to_fgraph(
9393
fgraph=fgraph,
94-
input_specs=[
95-
In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs
96-
],
94+
input_specs=input_specs,
9795
accept_inplace=True,
9896
)
9997
rewriter(fgraph)
98+
output_specs = [Out(x, borrow=False) for x in fgraph.outputs]
99+
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
100100

101101
scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(op.fgraph)
102102

tests/scan/test_basic.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3194,38 +3194,66 @@ def onestep(x, x_tm4):
31943194
f = function([seq], results[1])
31953195
assert np.all(exp_out == f(inp))
31963196

3197-
def test_shared_borrow(self):
3197+
@pytest.mark.parametrize("static_shape", (True, False)[:1])
3198+
def test_aliased_inner_outputs(self, static_shape):
31983199
"""
3199-
This tests two things. The first is a bug occurring when scan wrongly
3200-
used the borrow flag. The second thing it that Scan's infer_shape()
3201-
method will be able to remove the Scan node from the graph in this
3202-
case.
3200+
This tests two things. The first is a bug occurring when scan wrongly
3201+
used the borrow flag. The second thing it that Scan's infer_shape()
3202+
method will be able to remove the Scan node from the graph in this
3203+
case.
3204+
3205+
Here is pure python equivalent of the problem we want to avoid:
3206+
```python
3207+
def scan(seq, initval):
3208+
# Due to memory optimization we override values of mitsot as we iterate
3209+
# That's why mitsot has shape (4, 1) and not (14, 1)
3210+
mitsot = np.zeros((4, 1))
3211+
mitsot[:4] = initval
3212+
nitsot = np.zeros((10, 1))
3213+
for i, s in enumerate(seq):
3214+
# Incorrect results
3215+
mitsot[(i+4) % 4], nitsot[i] = s, mitsot[i % 4]
3216+
# Correct results
3217+
# mitsot[(i + 4) % 4], nitsot[i] = s, mitsot[i % 4].copy()
3218+
3219+
return mitsot[(i + 4) % 4: (i+4 + 1) % 4], nitsot
3220+
3221+
scan(np.arange(10), np.zeros((4, 1)))
3222+
```
32033223
"""
32043224

3205-
inp = np.arange(10).reshape(-1, 1).astype(config.floatX)
3206-
exp_out = np.zeros((10, 1)).astype(config.floatX)
3207-
exp_out[4:] = inp[:-4]
3208-
3209-
def onestep(x, x_tm4):
3210-
return x, x_tm4
3211-
3212-
seq = matrix()
3213-
initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
3214-
outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
3215-
results = scan(
3216-
fn=onestep, sequences=seq, outputs_info=outputs_info, return_updates=False
3225+
def onestep(seq, seq_tm4):
3226+
# Recurring output is just each value of seq
3227+
# And we further map the tap -4 as a new output
3228+
return seq, seq_tm4
3229+
3230+
# Outer tensors must be atleast matrix, so that they we have vectors in the inner loop
3231+
# Otherwise we would be working with scalars and memory alias wouldn't be a concern
3232+
seq = matrix(shape=(10, 1) if static_shape else (None, None), name="seq")
3233+
init = matrix(shape=(4, 1) if static_shape else (None, None), name="init")
3234+
outputs_info = [{"initial": init, "taps": [-4]}, None]
3235+
[out_seq, out_seq_tm4] = scan(
3236+
fn=onestep,
3237+
sequences=seq,
3238+
outputs_info=outputs_info,
3239+
return_updates=False,
32173240
)
3218-
sharedvar = shared(np.zeros((1, 1), dtype=config.floatX))
3219-
updates = {sharedvar: results[0][-1:]}
32203241

3221-
f = function([seq], results[1], updates=updates)
3242+
f = function([seq, init], [out_seq[-1].ravel(), out_seq_tm4.ravel()])
32223243

3223-
# This fails if scan uses wrongly the borrow flag
3224-
assert np.all(exp_out == f(inp))
3244+
seq_test_val = np.arange(10, dtype=config.floatX)[:, None]
3245+
init_test_val = np.zeros((4, 1), dtype=config.floatX)
3246+
3247+
res0, res1 = f(seq_test_val, init_test_val)
3248+
expected_res0 = np.array([9], dtype=config.floatX)
3249+
expected_res1 = np.zeros(10, dtype=config.floatX)
3250+
expected_res1[4:] = np.arange(6)
3251+
np.testing.assert_array_equal(res0, expected_res0)
3252+
np.testing.assert_array_equal(res1, expected_res1)
32253253

32263254
# This fails if Scan's infer_shape() is unable to remove the Scan
32273255
# node from the graph.
3228-
f_infershape = function([seq], results[1].shape, mode="FAST_RUN")
3256+
f_infershape = function([seq, init], out_seq_tm4[1].shape)
32293257
scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
32303258
assert len(scan_nodes_infershape) == 0
32313259

0 commit comments

Comments
 (0)