diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py index 23b790ecbc..881eb8123b 100644 --- a/pytensor/link/jax/dispatch/scan.py +++ b/pytensor/link/jax/dispatch/scan.py @@ -90,7 +90,7 @@ def jax_args_to_inner_func_args(carry, x): chain.from_iterable( buffer[(i + np.array(taps))] for buffer, taps in zip( - inner_mit_mot, info.mit_mot_in_slices, strict=True + inner_mit_mot, info.normalized_mit_mot_in_slices, strict=True ) ) ) @@ -140,7 +140,10 @@ def inner_func_outs_to_jax_outs( new_mit_mot = [ buffer.at[i + np.array(taps)].set(new_vals) for buffer, new_vals, taps in zip( - old_mit_mot, new_mit_mot_vals, info.mit_mot_out_slices, strict=True + old_mit_mot, + new_mit_mot_vals, + info.normalized_mit_mot_out_slices, + strict=True, ) ] # Discard oldest MIT-SOT and append newest value diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 5f18b9561f..6e6965f374 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -199,13 +199,10 @@ def creator(args, creator=creator, i=i): def create_tuple_string(x): - args = ", ".join(x + ([""] if len(x) == 1 else [])) - return f"({args})" - - -def create_arg_string(x): - args = ", ".join(x) - return args + if len(x) == 1: + return f"({x[0]},)" + else: + return f"({', '.join(x)})" @numba.extending.intrinsic diff --git a/pytensor/link/numba/dispatch/compile_ops.py b/pytensor/link/numba/dispatch/compile_ops.py index 8eb73d0111..821ae66481 100644 --- a/pytensor/link/numba/dispatch/compile_ops.py +++ b/pytensor/link/numba/dispatch/compile_ops.py @@ -5,8 +5,8 @@ import numpy as np from pytensor.compile.builders import OpFromGraph -from pytensor.compile.function.types import add_supervisor_to_fgraph -from pytensor.compile.io import In +from pytensor.compile.function.types import add_supervisor_to_fgraph, insert_deepcopy +from pytensor.compile.io import In, Out from pytensor.compile.mode import NUMBA from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.ifelse import IfElse @@ -56,14 +56,17 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs): # explicitly triggers the optimization of the inner graphs of OpFromGraph? 
# The C-code defers it to the make_thunk phase fgraph = op.fgraph + input_specs = [In(x, borrow=True, mutable=False) for x in fgraph.inputs] add_supervisor_to_fgraph( fgraph=fgraph, - input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs], + input_specs=input_specs, accept_inplace=True, ) NUMBA.optimizer(fgraph) + output_specs = [Out(o, borrow=False) for o in fgraph.outputs] + insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key( - op.fgraph, squeeze_output=True, **kwargs + fgraph, squeeze_output=True, **kwargs ) if fgraph_cache_key is None: diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index d4064cb6c1..ffbbb296ed 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -5,13 +5,12 @@ from numba import types from numba.extending import overload -from pytensor import In -from pytensor.compile.function.types import add_supervisor_to_fgraph +from pytensor.compile.function.types import add_supervisor_to_fgraph, insert_deepcopy +from pytensor.compile.io import In, Out from pytensor.compile.mode import NUMBA, get_mode from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( - create_arg_string, create_tuple_string, numba_funcify_and_cache_key, register_funcify_and_cache_key, @@ -27,9 +26,8 @@ def idx_to_str( idx_symbol: str = "i", allow_scalar=False, ) -> str: - if offset < 0: - indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}" - elif offset > 0: + assert offset >= 0 + if offset > 0: indices = f"{idx_symbol} + {offset}" else: indices = idx_symbol @@ -90,14 +88,15 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): if outer_mitsot.type.shape[0] == abs(min(taps)) ] destroyable = {*destroyable_sitsot, *destroyable_mitsot} + input_specs = [In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs] add_supervisor_to_fgraph( fgraph=fgraph, - input_specs=[ - In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs - ], + input_specs=input_specs, accept_inplace=True, ) rewriter(fgraph) + output_specs = [Out(x, borrow=False) for x in fgraph.outputs] + insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(op.fgraph) @@ -152,7 +151,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): # Inner-inputs are ordered as follows: # sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + # untraced-sit-sot-inputs + non-sequences. 
- temp_scalar_storage_alloc_stmts: list[str] = [] + temp_0d_storage_alloc_stmts: list[str] = [] inner_in_exprs_scalar: list[str] = [] inner_in_exprs: list[str] = [] @@ -170,7 +169,7 @@ def add_inner_in_expr( ) temp_storage = f"{storage_name}_temp_scalar_{tap_offset}" storage_dtype = outer_in_var.type.numpy_dtype.name - temp_scalar_storage_alloc_stmts.append( + temp_0d_storage_alloc_stmts.append( f"{temp_storage} = np.empty((), dtype=np.{storage_dtype})" ) inner_in_exprs_scalar.append( @@ -182,7 +181,7 @@ def add_inner_in_expr( storage_name if tap_offset is None else idx_to_str( - storage_name, tap_offset, size=storage_size_var, allow_scalar=False + storage_name, tap_offset, size=storage_size_var, allow_scalar=True ) ) inner_in_exprs.append(indexed_inner_in_str) @@ -226,33 +225,16 @@ def add_inner_in_expr( # storage array like a circular buffer, and that's why we need to track the # storage size along with the taps length/indexing offset. def add_output_storage_post_proc_stmt( - outer_in_name: str, tap_sizes: tuple[int, ...], storage_size: str + outer_in_name: str, max_offset: int, storage_size: str ): - tap_size = max(tap_sizes) - - if op.info.as_while: - # While loops need to truncate the output storage to a length given - # by the number of iterations performed. - output_storage_post_proc_stmts.append( - dedent( - f""" - if i + {tap_size} < {storage_size}: - {storage_size} = i + {tap_size} - {outer_in_name} = {outer_in_name}[:{storage_size}] - """ - ).strip() - ) - - # Rotate the storage so that the last computed value is at the end of - # the storage array. + # Rotate the storage so that the last computed value is at the end of the storage array. # This is needed when the output storage array does not have a length # equal to the number of taps plus `n_steps`. - # If the storage size only allows one entry, there's nothing to rotate output_storage_post_proc_stmts.append( dedent( f""" - if 1 < {storage_size} < (i + {tap_size}): - {outer_in_name}_shift = (i + {tap_size}) % ({storage_size}) + if 1 < {storage_size} < (i + {max_offset}): + {outer_in_name}_shift = (i + {max_offset}) % ({storage_size}) if {outer_in_name}_shift > 0: {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift] {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:] @@ -261,6 +243,29 @@ def add_output_storage_post_proc_stmt( ).strip() ) + if op.info.as_while: + # While loops need to truncate the output storage to a length given + # by the number of iterations performed. + output_storage_post_proc_stmts.append( + dedent( + f""" + elif {storage_size} > (i + {max_offset}): + {outer_in_name} = {outer_in_name}[:i + {max_offset}] + """ + ).strip() + ) + else: + # And regular loops should zero out unused entries of the output buffer + # These show up with truncated gradients of while loops + output_storage_post_proc_stmts.append( + dedent( + f""" + elif {storage_size} > (i + {max_offset}): + {outer_in_name}[i + {max_offset}:] = 0 + """ + ).strip() + ) + # Special in-loop statements that create (nit-sot) storage arrays after a # single iteration is performed. 
This is necessary because we don't know # the exact shapes of the storage arrays that need to be allocated until @@ -288,12 +293,11 @@ def add_output_storage_post_proc_stmt( storage_size_name = f"{outer_in_name}_len" storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]" input_taps = inner_in_names_to_input_taps[outer_in_name] - tap_storage_size = -min(input_taps) - assert tap_storage_size >= 0 + max_lookback_inp_tap = -min(0, min(input_taps)) + assert max_lookback_inp_tap >= 0 for in_tap in input_taps: - tap_offset = in_tap + tap_storage_size - assert tap_offset >= 0 + tap_offset = max_lookback_inp_tap + in_tap is_vector = outer_in_var.ndim == 1 add_inner_in_expr( outer_in_name, @@ -302,22 +306,25 @@ def add_output_storage_post_proc_stmt( vector_slice_opt=is_vector, ) - output_taps = inner_in_names_to_output_taps.get( - outer_in_name, [tap_storage_size] - ) - inner_out_to_outer_in_stmts.extend( - idx_to_str( - storage_name, - out_tap, - size=storage_size_name, - allow_scalar=True, + output_taps = inner_in_names_to_output_taps.get(outer_in_name, [0]) + for out_tap in output_taps: + tap_offset = max_lookback_inp_tap + out_tap + assert tap_offset >= 0 + inner_out_to_outer_in_stmts.append( + idx_to_str( + storage_name, + tap_offset, + size=storage_size_name, + allow_scalar=True, + ) ) - for out_tap in output_taps - ) - add_output_storage_post_proc_stmt( - storage_name, output_taps, storage_size_name - ) + if outer_in_name not in outer_in_mit_mot_names: + # MIT-SOT and SIT-SOT may require buffer rolling/truncation after the main loop + max_offset_out_tap = max(output_taps) + max_lookback_inp_tap + add_output_storage_post_proc_stmt( + storage_name, max_offset_out_tap, storage_size_name + ) else: storage_size_stmt = "" @@ -351,7 +358,7 @@ def add_output_storage_post_proc_stmt( inner_out_to_outer_in_stmts.append( idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True) ) - add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name) + add_output_storage_post_proc_stmt(storage_name, 0, storage_size_name) # In case of nit-sots we are provided the length of the array in # the iteration dimension instead of actual arrays, hence we @@ -359,23 +366,27 @@ def add_output_storage_post_proc_stmt( curr_nit_sot_position = outer_in_nit_sot_names.index(outer_in_name) curr_nit_sot = op.inner_nitsot_outs(op.inner_outputs)[curr_nit_sot_position] - storage_shape = create_tuple_string( - [storage_size_name] + ["0"] * curr_nit_sot.ndim - ) + known_static_shape = all(dim is not None for dim in curr_nit_sot.type.shape) + if known_static_shape: + storage_shape = create_tuple_string( + (storage_size_name, *(map(str, curr_nit_sot.type.shape))) + ) + else: + storage_shape = create_tuple_string( + (storage_size_name, *(["0"] * curr_nit_sot.ndim)) + ) storage_dtype = curr_nit_sot.type.numpy_dtype.name storage_alloc_stmts.append( dedent( f""" - {storage_size_name} = ({outer_in_name}).item() + {storage_size_name} = {outer_in_name}.item() {storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype}) """ ).strip() ) - if curr_nit_sot.type.ndim > 0: - storage_alloc_stmts.append(f"{outer_in_name}_ready = False") - + if not known_static_shape: # In this case, we don't know the shape of the output storage # array until we get some output from the inner-function. 
# With the following we add delayed output storage initialization: @@ -385,9 +396,8 @@ def add_output_storage_post_proc_stmt( inner_out_post_processing_stmts.append( dedent( f""" - if not {outer_in_name}_ready: + if i == 0: {storage_name} = np.empty(({storage_size_name},) + np.shape({inner_out_name}), dtype=np.{storage_dtype}) - {outer_in_name}_ready = True """ ).strip() ) @@ -402,10 +412,11 @@ def add_output_storage_post_proc_stmt( assert len(inner_in_exprs) == len(op.fgraph.inputs) inner_scalar_in_args_to_temp_storage = "\n".join(inner_in_exprs_scalar) - inner_in_args = create_arg_string(inner_in_exprs) + # Break inputs in new lines, just for readability of the source code + inner_in_args = f",\n{' ' * 12}".join(inner_in_exprs) inner_outputs = create_tuple_string(inner_output_names) input_storage_block = "\n".join(storage_alloc_stmts) - input_temp_scalar_storage_block = "\n".join(temp_scalar_storage_alloc_stmts) + input_temp_0d_storage_block = "\n".join(temp_0d_storage_alloc_stmts) output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts) inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts) @@ -419,32 +430,29 @@ def scan({", ".join(outer_in_names)}): {indent(input_storage_block, " " * 4)} -{indent(input_temp_scalar_storage_block, " " * 4)} +{indent(input_temp_0d_storage_block, " " * 4)} i = 0 cond = np.array(False) while i < n_steps and not cond.item(): {indent(inner_scalar_in_args_to_temp_storage, " " * 8)} - {inner_outputs} = scan_inner_func({inner_in_args}) + {inner_outputs} = scan_inner_func( + {inner_in_args} + ) {indent(inner_out_post_processing_block, " " * 8)} {indent(inner_out_to_outer_out_stmts, " " * 8)} i += 1 {indent(output_storage_post_processing_block, " " * 4)} - return {create_arg_string(outer_output_names)} + return {", ".join(outer_output_names)} """ - global_env = { - "np": np, - "scan_inner_func": scan_inner_func, - } - scan_op_fn = compile_numba_function_src( scan_op_src, "scan", - {**globals(), **global_env}, + globals() | {"np": np, "scan_inner_func": scan_inner_func}, ) if inner_func_cache_key is None: diff --git a/pytensor/link/numba/dispatch/shape.py b/pytensor/link/numba/dispatch/shape.py index ed286dd889..b6a5533809 100644 --- a/pytensor/link/numba/dispatch/shape.py +++ b/pytensor/link/numba/dispatch/shape.py @@ -4,10 +4,7 @@ from numba.np.unsafe import ndarray as numba_ndarray from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import ( - create_arg_string, - register_funcify_default_op_cache_key, -) +from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key from pytensor.link.utils import compile_function_src from pytensor.tensor import NoneConst from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape @@ -48,7 +45,7 @@ def numba_funcify_SpecifyShape(op, node, **kwargs): func = dedent( f""" - def specify_shape(x, {create_arg_string(shape_input_names)}): + def specify_shape(x, {", ".join(shape_input_names)}): {"; ".join(func_conditions)} return x """ diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 7521fd3828..6efeddc8bb 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -288,6 +288,26 @@ def n_outer_outputs(self): + self.n_untraced_sit_sot_outs ) + @property + def normalized_mit_mot_in_slices(self) -> tuple[tuple[int, ...], ...]: + """Return mit_mot_in slices normalized as an offset from the oldest tap""" + # TODO: Make this the canonical representation + res = [] + for in_slice in 
self.mit_mot_in_slices: + min_tap = -(min(0, min(in_slice))) + res.append(tuple(tap + min_tap for tap in in_slice)) + return tuple(res) + + @property + def normalized_mit_mot_out_slices(self) -> tuple[tuple[int, ...], ...]: + """Return mit_mot_out slices normalized as an offset from the oldest tap""" + # TODO: Make this the canonical representation + res = [] + for out_slice in self.mit_mot_out_slices: + min_tap = -(min(0, min(out_slice))) + res.append(tuple(tap + min_tap for tap in out_slice)) + return tuple(res) + TensorConstructorType = Callable[ [Iterable[bool | int | None], str | np.generic], TensorType diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py index e180d34fd7..01a12914fd 100644 --- a/tests/link/jax/test_scan.py +++ b/tests/link/jax/test_scan.py @@ -15,6 +15,7 @@ from pytensor.tensor.math import gammaln, log from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, vector from tests.link.jax.test_basic import compare_jax_and_py +from tests.scan.test_basic import ScanCompatibilityTests jax = pytest.importorskip("jax") @@ -626,3 +627,7 @@ def block_until_ready(*inputs, jax_fn=jax_fn): block_until_ready(*test_input_vals) # Warmup benchmark.pedantic(block_until_ready, test_input_vals, rounds=200, iterations=1) + + +def test_higher_order_derivatives(): + ScanCompatibilityTests.check_higher_order_derivative(mode="JAX") diff --git a/tests/link/numba/test_compile_ops.py b/tests/link/numba/test_compile_ops.py index b51b359a08..918b4324d3 100644 --- a/tests/link/numba/test_compile_ops.py +++ b/tests/link/numba/test_compile_ops.py @@ -5,6 +5,9 @@ from pytensor import tensor as pt from pytensor.compile import ViewOp from pytensor.raise_op import assert_op +from pytensor.scalar import Add +from pytensor.tensor import matrix +from pytensor.tensor.elemwise import Elemwise from tests.link.numba.test_basic import compare_numba_and_py @@ -146,6 +149,28 @@ def test_ofg_inner_inplace(): np.testing.assert_allclose(res1, [1, np.e, np.e]) +def test_ofg_aliased_outputs(): + x = matrix("x") + # Create multiple views of x + outs = OpFromGraph([x], [x, x.T, x[::-1]])(x) + # Add one to each x, which when inplace shouldn't propagate across outputs + bumped_outs = [o + 1 for o in outs] + fn = function([x], bumped_outs, mode="NUMBA") + fn.dprint(print_destroy_map=True) + # Check our outputs are indeed inplace adds + assert all( + ( + isinstance(o.owner.op, Elemwise) + and isinstance(o.owner.op.scalar_op, Add) + and o.owner.op.destroy_map + ) + for o in fn.maker.fgraph.outputs + ) + x_test = np.zeros((2, 2)) + for res in fn(x_test): + np.testing.assert_allclose(res, np.ones((2, 2))) + + def test_check_and_raise(): x = pt.vector() x_test_value = np.array([1.0, 2.0], dtype=config.floatX) diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 77ceebbcf7..b4448ad06c 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -16,6 +16,7 @@ from pytensor.tensor.random.utils import RandomStream from tests import unittest_tools as utt from tests.link.numba.test_basic import compare_numba_and_py +from tests.scan.test_basic import ScanCompatibilityTests @pytest.mark.parametrize( @@ -652,3 +653,15 @@ def test_mit_sot_buffer(self, constant_n_steps, n_steps_val): def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark): self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark) + + +def test_higher_order_derivatives(): + ScanCompatibilityTests.check_higher_order_derivative(mode="NUMBA") + + +def 
test_grad_until_and_truncate_sequence_taps(): + ScanCompatibilityTests.check_grad_until_and_truncate_sequence_taps(mode="NUMBA") + + +def test_aliased_inner_outputs(): + ScanCompatibilityTests.check_aliased_inner_outputs(static_shape=True, mode="NUMBA") diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index d4f3e1bde1..f1bbdb6d95 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -2621,22 +2621,7 @@ def test_grad_until_and_truncate(self): utt.assert_allclose(pytensor_gradient, self.numpy_gradient) def test_grad_until_and_truncate_sequence_taps(self): - n = 3 - r = scan( - lambda x, y, u: (x * y, until(y > u)), - sequences=dict(input=self.x, taps=[-2, 0]), - non_sequences=[self.threshold], - truncate_gradient=n, - return_updates=False, - ) - g = grad(r.sum(), self.x) - f = function([self.x, self.threshold], [r, g]) - _pytensor_output, pytensor_gradient = f(self.seq, 6) - - # Gradient computed by hand: - numpy_grad = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0]) - numpy_grad = numpy_grad.astype(config.floatX) - utt.assert_allclose(pytensor_gradient, numpy_grad) + ScanCompatibilityTests.check_grad_until_and_truncate_sequence_taps(mode=None) def test_mintap_onestep(): @@ -3196,40 +3181,9 @@ def onestep(x, x_tm4): f = function([seq], results[1]) assert np.all(exp_out == f(inp)) - def test_shared_borrow(self): - """ - This tests two things. The first is a bug occurring when scan wrongly - used the borrow flag. The second thing it that Scan's infer_shape() - method will be able to remove the Scan node from the graph in this - case. - """ - - inp = np.arange(10).reshape(-1, 1).astype(config.floatX) - exp_out = np.zeros((10, 1)).astype(config.floatX) - exp_out[4:] = inp[:-4] - - def onestep(x, x_tm4): - return x, x_tm4 - - seq = matrix() - initial_value = shared(np.zeros((4, 1), dtype=config.floatX)) - outputs_info = [{"initial": initial_value, "taps": [-4]}, None] - results = scan( - fn=onestep, sequences=seq, outputs_info=outputs_info, return_updates=False - ) - sharedvar = shared(np.zeros((1, 1), dtype=config.floatX)) - updates = {sharedvar: results[0][-1:]} - - f = function([seq], results[1], updates=updates) - - # This fails if scan uses wrongly the borrow flag - assert np.all(exp_out == f(inp)) - - # This fails if Scan's infer_shape() is unable to remove the Scan - # node from the graph. - f_infershape = function([seq], results[1].shape, mode="FAST_RUN") - scan_nodes_infershape = scan_nodes_from_fct(f_infershape) - assert len(scan_nodes_infershape) == 0 + @pytest.mark.parametrize("static_shape", (True, False)) + def test_aliased_inner_outputs(self, static_shape): + ScanCompatibilityTests.check_aliased_inner_outputs(static_shape, mode=None) def test_memory_reuse_with_outputs_as_inputs(self): """ @@ -4082,6 +4036,9 @@ def test_grad_multiple_outs_some_disconnected_2(self): # Also, the purpose of this test is not clear. 
        self._grad_mout_helper(1, None)
 
+    def test_higher_order_derivatives(self):
+        ScanCompatibilityTests.check_higher_order_derivative(mode=None)
+
 
 @pytest.mark.parametrize(
     "fn, sequences, outputs_info, non_sequences, n_steps, op_check",
@@ -4398,3 +4355,118 @@ def test_scan_mode_compatibility(scan_mode):
 
     # Expected value computed by running correct Scan once
     np.testing.assert_allclose(fn(*numerical_inputs), [44, 38])
+
+
+class ScanCompatibilityTests:
+    """Collection of tests of subtle required behaviors of Scan that can be reused by different backends."""
+
+    @staticmethod
+    def check_higher_order_derivative(mode):
+        """This tests mit-mot taps of different signs."""
+        x = pt.dscalar("x")
+
+        # xs[-1] is equivalent to x ** 16
+        xs = scan(
+            fn=lambda xtm1: xtm1**2,
+            outputs_info=[x],
+            n_steps=4,
+            return_updates=False,
+        )
+        r = xs[-1]
+        g = grad(r, x)
+        gg = grad(g, x)
+        ggg = grad(gg, x)
+
+        fn = function([x], [r, g, gg, ggg], mode=mode)
+        x_test = np.array(0.95, dtype=x.type.dtype)
+        r_res, g_res, gg_res, _ggg_res = fn(x_test)
+        np.testing.assert_allclose(r_res, x_test**16)
+        np.testing.assert_allclose(g_res, 16 * x_test**15)
+        np.testing.assert_allclose(gg_res, (16 * 15) * x_test**14)
+        # FIXME: All implementations of Scan seem to get this one wrong!
+        # np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)
+
+    @staticmethod
+    def check_grad_until_and_truncate_sequence_taps(mode):
+        """Test a case that requires Scan's special behavior of zeroing out sequences."""
+        x = pt.vector("x")
+        threshold = pt.scalar(name="threshold", dtype="int64")
+
+        r = scan(
+            lambda x, y, u: (x * y, until(y > u)),
+            sequences=dict(input=x, taps=[-2, 0]),
+            non_sequences=[threshold],
+            truncate_gradient=3,
+            return_updates=False,
+        )
+        g = grad(r.sum(), x)
+        f = function([x, threshold], [r, g], mode=mode)
+        _, grad_res = f(np.arange(15, dtype=x.dtype), 6)
+
+        # Gradient computed by hand:
+        grad_expected = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0])
+        grad_expected = grad_expected.astype(config.floatX)
+        np.testing.assert_allclose(grad_res, grad_expected)
+
+    @staticmethod
+    def check_aliased_inner_outputs(static_shape, mode):
+        """
+        This tests two things. The first is a bug occurring when scan wrongly
+        used the borrow flag. The second is that Scan's infer_shape()
+        method will be able to remove the Scan node from the graph in this
+        case.
+
+        Here is a pure Python equivalent of the problem we want to avoid:
+        ```python
+        def scan(seq, initval):
+            # Due to memory optimization we overwrite values of mitsot as we iterate
+            # That's why mitsot has shape (4, 1) and not (14, 1)
+            mitsot = np.zeros((4, 1))
+            mitsot[:4] = initval
+            nitsot = np.zeros((10, 1))
+            for i, s in enumerate(seq):
+                # Incorrect results
+                mitsot[(i+4) % 4], nitsot[i] = s, mitsot[i % 4]
+                # Correct results
+                # mitsot[(i + 4) % 4], nitsot[i] = s, mitsot[i % 4].copy()
+
+            return mitsot[(i + 4) % 4: (i+4 + 1) % 4], nitsot
+
+        scan(np.arange(10), np.zeros((4, 1)))
+        ```
+        """
+
+        def onestep(seq, seq_tm4):
+            # The recurring output is just each value of seq,
+            # and we also return the -4 tap as a new output
+            return seq, seq_tm4
+
+        # Outer tensors must be at least matrices, so that we have vectors in the inner loop
+        # Otherwise we would be working with scalars and memory aliasing wouldn't be a concern
+        seq = matrix(shape=(10, 1) if static_shape else (None, None), name="seq")
+        init = matrix(shape=(4, 1) if static_shape else (None, None), name="init")
+        outputs_info = [{"initial": init, "taps": [-4]}, None]
+        [out_seq, out_seq_tm4] = scan(
+            fn=onestep,
+            sequences=seq,
+            outputs_info=outputs_info,
+            return_updates=False,
+        )
+
+        f = function([seq, init], [out_seq[-1].ravel(), out_seq_tm4.ravel()], mode=mode)
+
+        seq_test_val = np.arange(10, dtype=config.floatX)[:, None]
+        init_test_val = np.zeros((4, 1), dtype=config.floatX)
+
+        res0, res1 = f(seq_test_val, init_test_val)
+        expected_res0 = np.array([9], dtype=config.floatX)
+        expected_res1 = np.zeros(10, dtype=config.floatX)
+        expected_res1[4:] = np.arange(6)
+        np.testing.assert_array_equal(res0, expected_res0)
+        np.testing.assert_array_equal(res1, expected_res1)
+
+        # This fails if Scan's infer_shape() is unable to remove the Scan
+        # node from the graph.
+        f_infershape = function([seq, init], out_seq_tm4[1].shape)
+        scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
+        assert len(scan_nodes_infershape) == 0
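Note on the generated post-loop handling: the statements emitted by `add_output_storage_post_proc_stmt` are easier to follow as plain NumPy. The sketch below is only an illustration of the rotate/truncate/zero logic under assumed names (`storage` for the circular output buffer, `i` for the number of steps performed, `max_offset` for the largest normalized output tap, `as_while` mirroring `op.info.as_while`); it is not code from this diff.

```python
import numpy as np


def post_process_output_buffer(storage, i, max_offset, as_while):
    """NumPy sketch of the post-loop buffer handling for SIT-SOT/MIT-SOT/NIT-SOT outputs."""
    size = storage.shape[0]
    if 1 < size < i + max_offset:
        # The buffer was written circularly: rotate it so the newest entry ends up last.
        shift = (i + max_offset) % size
        if shift > 0:
            storage = np.concatenate([storage[shift:], storage[:shift]])
    elif size > i + max_offset:
        if as_while:
            # While loops truncate to the number of steps actually performed.
            storage = storage[: i + max_offset]
        else:
            # Regular loops zero out entries that were never written
            # (these show up with truncated gradients of while loops).
            storage[i + max_offset:] = 0
    return storage
```

For example, a SIT-SOT whose buffer was truncated to 3 entries after 7 steps (`post_process_output_buffer(buf, i=7, max_offset=1, as_while=False)`) is rotated so that `buf[-1]` holds the last computed value.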