
Commit 4186819

Numba OpFromGraph: Prevent alias of outputs
1 parent 57b344e commit 4186819
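
The hazard this commit guards against: an OpFromGraph can return several outputs that are views of the same underlying buffer, and a downstream inplace rewrite applied to one output would then silently corrupt the others. A minimal sketch of that aliasing problem in plain NumPy (independent of PyTensor internals, for illustration only):

    import numpy as np

    # Two "outputs" that are views of one buffer, mirroring what
    # OpFromGraph([x], [x, x.T]) can return without a deepcopy.
    x = np.zeros((2, 2))
    out1 = x        # first output: the buffer itself
    out2 = x.T      # second output: a transposed view of the same memory
    out1 += 1       # an inplace add applied to the first output only
    print(out2)     # [[1. 1.] [1. 1.]] -- the aliased second output changed too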

File tree

2 files changed: +32 -4 lines changed


pytensor/link/numba/dispatch/compile_ops.py

Lines changed: 7 additions & 4 deletions
@@ -5,8 +5,8 @@
 import numpy as np

 from pytensor.compile.builders import OpFromGraph
-from pytensor.compile.function.types import add_supervisor_to_fgraph
-from pytensor.compile.io import In
+from pytensor.compile.function.types import add_supervisor_to_fgraph, insert_deepcopy
+from pytensor.compile.io import In, Out
 from pytensor.compile.mode import NUMBA
 from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
 from pytensor.ifelse import IfElse
@@ -56,14 +56,17 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
     # explicitly triggers the optimization of the inner graphs of OpFromGraph?
     # The C-code defers it to the make_thunk phase
     fgraph = op.fgraph
+    input_specs = [In(x, borrow=True, mutable=False) for x in fgraph.inputs]
     add_supervisor_to_fgraph(
         fgraph=fgraph,
-        input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
+        input_specs=input_specs,
         accept_inplace=True,
     )
     NUMBA.optimizer(fgraph)
+    output_specs = [Out(o, borrow=False) for o in fgraph.outputs]
+    insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
     fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(
-        op.fgraph, squeeze_output=True, **kwargs
+        fgraph, squeeze_output=True, **kwargs
     )

     if fgraph_cache_key is None:
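
For orientation: per the diff above, outputs are wrapped in Out(o, borrow=False) and insert_deepcopy then adds copy nodes so no output aliases another. A rough conceptual sketch of that boundary-copy idea in plain NumPy (a hypothetical helper, not PyTensor's actual implementation):

    import numpy as np

    # Hypothetical helper sketching the borrow=False boundary copy:
    # any output that shares memory with an already-returned output
    # gets copied, so each result owns its own buffer.
    def copy_aliased_outputs(outputs):
        seen = []
        fresh = []
        for out in outputs:
            if any(np.shares_memory(out, prev) for prev in seen):
                out = out.copy()
            seen.append(out)
            fresh.append(out)
        return fresh

    x = np.zeros((2, 2))
    a, b = copy_aliased_outputs([x, x.T])
    a += 1
    print(b)  # still zeros: the transposed view was copied, not aliased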

tests/link/numba/test_compile_ops.py

Lines changed: 25 additions & 0 deletions
@@ -5,6 +5,9 @@
 from pytensor import tensor as pt
 from pytensor.compile import ViewOp
 from pytensor.raise_op import assert_op
+from pytensor.scalar import Add
+from pytensor.tensor import matrix
+from pytensor.tensor.elemwise import Elemwise
 from tests.link.numba.test_basic import compare_numba_and_py


@@ -146,6 +149,28 @@ def test_ofg_inner_inplace():
     np.testing.assert_allclose(res1, [1, np.e, np.e])


+def test_ofg_aliased_outputs():
+    x = matrix("x")
+    # Create multiple views of x
+    outs = OpFromGraph([x], [x, x.T, x[::-1]])(x)
+    # Add one to each output; the inplace adds shouldn't propagate across outputs
+    bumped_outs = [o + 1 for o in outs]
+    fn = function([x], bumped_outs, mode="NUMBA")
+    fn.dprint(print_destroy_map=True)
+    # Check our outputs are indeed inplace adds
+    assert all(
+        (
+            isinstance(o.owner.op, Elemwise)
+            and isinstance(o.owner.op.scalar_op, Add)
+            and o.owner.op.destroy_map
+        )
+        for o in fn.maker.fgraph.outputs
+    )
+    x_test = np.zeros((2, 2))
+    for res in fn(x_test):
+        np.testing.assert_allclose(res, np.ones((2, 2)))
+
+
 def test_check_and_raise():
     x = pt.vector()
     x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
