7 changes: 5 additions & 2 deletions pytensor/link/jax/dispatch/scan.py
@@ -90,7 +90,7 @@ def jax_args_to_inner_func_args(carry, x):
chain.from_iterable(
buffer[(i + np.array(taps))]
for buffer, taps in zip(
inner_mit_mot, info.mit_mot_in_slices, strict=True
inner_mit_mot, info.normalized_mit_mot_in_slices, strict=True
)
)
)
@@ -140,7 +140,10 @@ def inner_func_outs_to_jax_outs(
new_mit_mot = [
buffer.at[i + np.array(taps)].set(new_vals)
for buffer, new_vals, taps in zip(
old_mit_mot, new_mit_mot_vals, info.mit_mot_out_slices, strict=True
old_mit_mot,
new_mit_mot_vals,
info.normalized_mit_mot_out_slices,
strict=True,
)
]
# Discard oldest MIT-SOT and append newest value
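For context, the MIT-MOT buffer writes above rely on JAX's functional index-update API. A minimal standalone sketch of that pattern (the buffer, step index, tap offsets, and values below are made up for illustration, not taken from the PR):

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical buffer with two output taps written per step.
buffer = jnp.zeros(5)
taps = np.array([0, 2])            # illustrative output taps
i = 1                              # current step
new_vals = jnp.array([10.0, 20.0])

# JAX arrays are immutable: .at[...].set(...) returns an updated copy,
# which is how the dispatcher writes the inner-function outputs back.
buffer = buffer.at[i + taps].set(new_vals)
print(buffer)                      # [ 0. 10.  0. 20.  0.]
```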
11 changes: 4 additions & 7 deletions pytensor/link/numba/dispatch/basic.py
@@ -199,13 +199,10 @@ def creator(args, creator=creator, i=i):


def create_tuple_string(x):
args = ", ".join(x + ([""] if len(x) == 1 else []))
return f"({args})"


def create_arg_string(x):
args = ", ".join(x)
return args
if len(x) == 1:
return f"({x[0]},)"
else:
return f"({', '.join(x)})"


@numba.extending.intrinsic
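As a quick sanity check of the consolidated helper (reproducing the function above), the single-element case keeps the trailing comma that makes the generated string a tuple literal rather than a parenthesized expression:

```python
def create_tuple_string(x):
    if len(x) == 1:
        return f"({x[0]},)"
    else:
        return f"({', '.join(x)})"

print(create_tuple_string(["a"]))       # (a,)  -- still a tuple when evaluated
print(create_tuple_string(["a", "b"]))  # (a, b)
print(create_tuple_string([]))          # ()
```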
11 changes: 7 additions & 4 deletions pytensor/link/numba/dispatch/compile_ops.py
@@ -5,8 +5,8 @@
import numpy as np

from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.io import In
from pytensor.compile.function.types import add_supervisor_to_fgraph, insert_deepcopy
from pytensor.compile.io import In, Out
from pytensor.compile.mode import NUMBA
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.ifelse import IfElse
@@ -56,14 +56,17 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
# explicitly triggers the optimization of the inner graphs of OpFromGraph?
# The C-code defers it to the make_thunk phase
fgraph = op.fgraph
input_specs = [In(x, borrow=True, mutable=False) for x in fgraph.inputs]
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
input_specs=input_specs,
accept_inplace=True,
)
NUMBA.optimizer(fgraph)
output_specs = [Out(o, borrow=False) for o in fgraph.outputs]
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(
op.fgraph, squeeze_output=True, **kwargs
fgraph, squeeze_output=True, **kwargs
)

if fgraph_cache_key is None:
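The added `Out(o, borrow=False)` specs and `insert_deepcopy` call ensure the compiled inner graph never hands back one of its input buffers by reference. A plain-NumPy illustration of the aliasing hazard this guards against (just the concept, not PyTensor internals):

```python
import numpy as np

def passthrough_no_copy(x):
    return x            # result aliases the caller's buffer

def passthrough_with_copy(x):
    return x.copy()     # what a deep copy on the output guarantees

a = np.arange(3.0)
out = passthrough_no_copy(a)
out[0] = 99.0
print(a)                # [99.  1.  2.] -- caller's array silently mutated

a = np.arange(3.0)
out = passthrough_with_copy(a)
out[0] = 99.0
print(a)                # [0. 1. 2.] -- caller's array untouched
```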
152 changes: 80 additions & 72 deletions pytensor/link/numba/dispatch/scan.py
@@ -5,13 +5,12 @@
from numba import types
from numba.extending import overload

from pytensor import In
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.function.types import add_supervisor_to_fgraph, insert_deepcopy
from pytensor.compile.io import In, Out
from pytensor.compile.mode import NUMBA, get_mode
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_arg_string,
create_tuple_string,
numba_funcify_and_cache_key,
register_funcify_and_cache_key,
@@ -27,9 +26,8 @@ def idx_to_str(
idx_symbol: str = "i",
allow_scalar=False,
) -> str:
if offset < 0:
indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}"
elif offset > 0:
assert offset >= 0
if offset > 0:
indices = f"{idx_symbol} + {offset}"
else:
indices = idx_symbol
@@ -90,14 +88,15 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
if outer_mitsot.type.shape[0] == abs(min(taps))
]
destroyable = {*destroyable_sitsot, *destroyable_mitsot}
input_specs = [In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs]
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[
In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs
],
input_specs=input_specs,
accept_inplace=True,
)
rewriter(fgraph)
output_specs = [Out(x, borrow=False) for x in fgraph.outputs]
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)

scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(op.fgraph)

@@ -152,7 +151,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# Inner-inputs are ordered as follows:
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
# untraced-sit-sot-inputs + non-sequences.
temp_scalar_storage_alloc_stmts: list[str] = []
temp_0d_storage_alloc_stmts: list[str] = []
inner_in_exprs_scalar: list[str] = []
inner_in_exprs: list[str] = []

@@ -170,7 +169,7 @@ def add_inner_in_expr(
)
temp_storage = f"{storage_name}_temp_scalar_{tap_offset}"
storage_dtype = outer_in_var.type.numpy_dtype.name
temp_scalar_storage_alloc_stmts.append(
temp_0d_storage_alloc_stmts.append(
f"{temp_storage} = np.empty((), dtype=np.{storage_dtype})"
)
inner_in_exprs_scalar.append(
@@ -182,7 +181,7 @@
storage_name
if tap_offset is None
else idx_to_str(
storage_name, tap_offset, size=storage_size_var, allow_scalar=False
storage_name, tap_offset, size=storage_size_var, allow_scalar=True
)
)
inner_in_exprs.append(indexed_inner_in_str)
@@ -226,33 +225,16 @@ def add_inner_in_expr(
# storage array like a circular buffer, and that's why we need to track the
# storage size along with the taps length/indexing offset.
def add_output_storage_post_proc_stmt(
outer_in_name: str, tap_sizes: tuple[int, ...], storage_size: str
outer_in_name: str, max_offset: int, storage_size: str
):
tap_size = max(tap_sizes)

if op.info.as_while:
# While loops need to truncate the output storage to a length given
# by the number of iterations performed.
output_storage_post_proc_stmts.append(
dedent(
f"""
if i + {tap_size} < {storage_size}:
{storage_size} = i + {tap_size}
{outer_in_name} = {outer_in_name}[:{storage_size}]
"""
).strip()
)

# Rotate the storage so that the last computed value is at the end of
# the storage array.
# Rotate the storage so that the last computed value is at the end of the storage array.
# This is needed when the output storage array does not have a length
# equal to the number of taps plus `n_steps`.
# If the storage size only allows one entry, there's nothing to rotate
output_storage_post_proc_stmts.append(
dedent(
f"""
if 1 < {storage_size} < (i + {tap_size}):
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
if 1 < {storage_size} < (i + {max_offset}):
{outer_in_name}_shift = (i + {max_offset}) % ({storage_size})
if {outer_in_name}_shift > 0:
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
@@ -261,6 +243,29 @@ def add_output_storage_post_proc_stmt(
).strip()
)

if op.info.as_while:
# While loops need to truncate the output storage to a length given
# by the number of iterations performed.
output_storage_post_proc_stmts.append(
dedent(
f"""
elif {storage_size} > (i + {max_offset}):
{outer_in_name} = {outer_in_name}[:i + {max_offset}]
"""
).strip()
)
else:
# And regular loops should zero out unused entries of the output buffer
# These show up with truncated gradients of while loops
output_storage_post_proc_stmts.append(
dedent(
f"""
elif {storage_size} > (i + {max_offset}):
{outer_in_name}[i + {max_offset}:] = 0
"""
).strip()
)

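A toy NumPy illustration of the rotation generated above, with made-up numbers (`storage_size = 3`, a single lookback tap so `max_offset = 1`, and `i = 4` after the loop); the final concatenation is implied by the left/right split in the generated code:

```python
import numpy as np

# The buffer was written circularly at slot (step + max_offset) % storage_size,
# so after 4 steps it holds the step-2, step-3 and step-1 values in that order.
storage = np.array([2.0, 3.0, 1.0])
i, max_offset, storage_size = 4, 1, 3

if 1 < storage_size < (i + max_offset):
    shift = (i + max_offset) % storage_size        # 2
    if shift > 0:
        left, right = storage[:shift], storage[shift:]
        storage = np.concatenate((right, left))

print(storage)   # [1. 2. 3.] -- the last computed value now sits at the end
```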
# Special in-loop statements that create (nit-sot) storage arrays after a
# single iteration is performed. This is necessary because we don't know
# the exact shapes of the storage arrays that need to be allocated until
@@ -288,12 +293,11 @@ def add_output_storage_post_proc_stmt(
storage_size_name = f"{outer_in_name}_len"
storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]"
input_taps = inner_in_names_to_input_taps[outer_in_name]
tap_storage_size = -min(input_taps)
assert tap_storage_size >= 0
max_lookback_inp_tap = -min(0, min(input_taps))
assert max_lookback_inp_tap >= 0

for in_tap in input_taps:
tap_offset = in_tap + tap_storage_size
assert tap_offset >= 0
tap_offset = max_lookback_inp_tap + in_tap
is_vector = outer_in_var.ndim == 1
add_inner_in_expr(
outer_in_name,
Expand All @@ -302,22 +306,25 @@ def add_output_storage_post_proc_stmt(
vector_slice_opt=is_vector,
)

output_taps = inner_in_names_to_output_taps.get(
outer_in_name, [tap_storage_size]
)
inner_out_to_outer_in_stmts.extend(
idx_to_str(
storage_name,
out_tap,
size=storage_size_name,
allow_scalar=True,
output_taps = inner_in_names_to_output_taps.get(outer_in_name, [0])
for out_tap in output_taps:
tap_offset = max_lookback_inp_tap + out_tap
assert tap_offset >= 0
inner_out_to_outer_in_stmts.append(
idx_to_str(
storage_name,
tap_offset,
size=storage_size_name,
allow_scalar=True,
)
)
for out_tap in output_taps
)

add_output_storage_post_proc_stmt(
storage_name, output_taps, storage_size_name
)
if outer_in_name not in outer_in_mit_mot_names:
# MIT-SOT and SIT-SOT may require buffer rolling/truncation after the main loop
max_offset_out_tap = max(output_taps) + max_lookback_inp_tap
add_output_storage_post_proc_stmt(
storage_name, max_offset_out_tap, storage_size_name
)

else:
storage_size_stmt = ""
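A worked example of the new offset arithmetic above, for a hypothetical MIT-SOT with input taps `(-2, -1)` and the default output tap `0`:

```python
input_taps = [-2, -1]            # illustrative MIT-SOT lookback taps
output_taps = [0]                # default when no explicit output taps exist

max_lookback_inp_tap = -min(0, min(input_taps))                    # 2
assert max_lookback_inp_tap >= 0

input_offsets = [max_lookback_inp_tap + t for t in input_taps]     # [0, 1]
output_offsets = [max_lookback_inp_tap + t for t in output_taps]   # [2]
max_offset_out_tap = max(output_taps) + max_lookback_inp_tap       # 2

# At iteration i the inner function reads buffer[i] and buffer[i + 1], and the
# new value is written to buffer[i + 2], i.e. right after the lookback window.
print(input_offsets, output_offsets, max_offset_out_tap)
```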
@@ -351,31 +358,35 @@ def add_output_storage_post_proc_stmt(
inner_out_to_outer_in_stmts.append(
idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True)
)
add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name)
add_output_storage_post_proc_stmt(storage_name, 0, storage_size_name)

# In case of nit-sots we are provided the length of the array in
# the iteration dimension instead of actual arrays, hence we
# allocate space for the results accordingly.
curr_nit_sot_position = outer_in_nit_sot_names.index(outer_in_name)
curr_nit_sot = op.inner_nitsot_outs(op.inner_outputs)[curr_nit_sot_position]

storage_shape = create_tuple_string(
[storage_size_name] + ["0"] * curr_nit_sot.ndim
)
known_static_shape = all(dim is not None for dim in curr_nit_sot.type.shape)
if known_static_shape:
storage_shape = create_tuple_string(
(storage_size_name, *(map(str, curr_nit_sot.type.shape)))
)
else:
storage_shape = create_tuple_string(
(storage_size_name, *(["0"] * curr_nit_sot.ndim))
)
storage_dtype = curr_nit_sot.type.numpy_dtype.name

storage_alloc_stmts.append(
dedent(
f"""
{storage_size_name} = ({outer_in_name}).item()
{storage_size_name} = {outer_in_name}.item()
{storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype})
"""
).strip()
)

if curr_nit_sot.type.ndim > 0:
storage_alloc_stmts.append(f"{outer_in_name}_ready = False")

if not known_static_shape:
# In this case, we don't know the shape of the output storage
# array until we get some output from the inner-function.
# With the following we add delayed output storage initialization:
@@ -385,9 +396,8 @@ def add_output_storage_post_proc_stmt(
inner_out_post_processing_stmts.append(
dedent(
f"""
if not {outer_in_name}_ready:
if i == 0:
{storage_name} = np.empty(({storage_size_name},) + np.shape({inner_out_name}), dtype=np.{storage_dtype})
{outer_in_name}_ready = True
"""
).strip()
)
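A minimal runnable sketch (with made-up shapes) of the delayed-allocation branch: when a nit-sot output's static shape is unknown, a zero-sized placeholder is allocated up front and replaced on the first iteration, once the inner function reveals the real per-step shape:

```python
import numpy as np

n_steps = 5
y = np.empty((n_steps, 0, 0))             # placeholder; per-step shape unknown

for i in range(n_steps):
    step_out = np.full((4, 3), float(i))  # stand-in for the inner function
    if i == 0:
        y = np.empty((n_steps,) + np.shape(step_out), dtype=step_out.dtype)
    y[i] = step_out

print(y.shape)   # (5, 4, 3)
```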
assert len(inner_in_exprs) == len(op.fgraph.inputs)

inner_scalar_in_args_to_temp_storage = "\n".join(inner_in_exprs_scalar)
inner_in_args = create_arg_string(inner_in_exprs)
# Break inputs in new lines, just for readability of the source code
inner_in_args = f",\n{' ' * 12}".join(inner_in_exprs)
inner_outputs = create_tuple_string(inner_output_names)
input_storage_block = "\n".join(storage_alloc_stmts)
input_temp_scalar_storage_block = "\n".join(temp_scalar_storage_alloc_stmts)
input_temp_0d_storage_block = "\n".join(temp_0d_storage_alloc_stmts)
output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts)
inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)

Expand All @@ -419,32 +430,29 @@ def scan({", ".join(outer_in_names)}):

{indent(input_storage_block, " " * 4)}

{indent(input_temp_scalar_storage_block, " " * 4)}
{indent(input_temp_0d_storage_block, " " * 4)}

i = 0
cond = np.array(False)
while i < n_steps and not cond.item():
{indent(inner_scalar_in_args_to_temp_storage, " " * 8)}

{inner_outputs} = scan_inner_func({inner_in_args})
{inner_outputs} = scan_inner_func(
{inner_in_args}
)
{indent(inner_out_post_processing_block, " " * 8)}
{indent(inner_out_to_outer_out_stmts, " " * 8)}
i += 1

{indent(output_storage_post_processing_block, " " * 4)}

return {create_arg_string(outer_output_names)}
return {", ".join(outer_output_names)}
"""

global_env = {
"np": np,
"scan_inner_func": scan_inner_func,
}

scan_op_fn = compile_numba_function_src(
scan_op_src,
"scan",
{**globals(), **global_env},
globals() | {"np": np, "scan_inner_func": scan_inner_func},
)

if inner_func_cache_key is None:
7 changes: 2 additions & 5 deletions pytensor/link/numba/dispatch/shape.py
@@ -4,10 +4,7 @@
from numba.np.unsafe import ndarray as numba_ndarray

from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_arg_string,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
from pytensor.link.utils import compile_function_src
from pytensor.tensor import NoneConst
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@@ -48,7 +45,7 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):

func = dedent(
f"""
def specify_shape(x, {create_arg_string(shape_input_names)}):
def specify_shape(x, {", ".join(shape_input_names)}):
{"; ".join(func_conditions)}
return x
"""