Skip to content

Commit c0db447

Browse files
committed
Numba Scan: make codegen more readable
1 parent 14e7262 commit c0db447

File tree

3 files changed

+30
-35
lines changed

3 files changed

+30
-35
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,10 @@ def creator(args, creator=creator, i=i):
199199

200200

201201
def create_tuple_string(x):
202-
args = ", ".join(x + ([""] if len(x) == 1 else []))
203-
return f"({args})"
204-
205-
206-
def create_arg_string(x):
207-
args = ", ".join(x)
208-
return args
202+
if len(x) == 1:
203+
return f"({x[0]},)"
204+
else:
205+
return f"({', '.join(x)})"
209206

210207

211208
@numba.extending.intrinsic

pytensor/link/numba/dispatch/scan.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
151151
# Inner-inputs are ordered as follows:
152152
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
153153
# untraced-sit-sot-inputs + non-sequences.
154-
temp_scalar_storage_alloc_stmts: list[str] = []
154+
temp_0d_storage_alloc_stmts: list[str] = []
155155
inner_in_exprs_scalar: list[str] = []
156156
inner_in_exprs: list[str] = []
157157

@@ -169,7 +169,7 @@ def add_inner_in_expr(
169169
)
170170
temp_storage = f"{storage_name}_temp_scalar_{tap_offset}"
171171
storage_dtype = outer_in_var.type.numpy_dtype.name
172-
temp_scalar_storage_alloc_stmts.append(
172+
temp_0d_storage_alloc_stmts.append(
173173
f"{temp_storage} = np.empty((), dtype=np.{storage_dtype})"
174174
)
175175
inner_in_exprs_scalar.append(
@@ -181,7 +181,7 @@ def add_inner_in_expr(
181181
storage_name
182182
if tap_offset is None
183183
else idx_to_str(
184-
storage_name, tap_offset, size=storage_size_var, allow_scalar=False
184+
storage_name, tap_offset, size=storage_size_var, allow_scalar=True
185185
)
186186
)
187187
inner_in_exprs.append(indexed_inner_in_str)
@@ -366,23 +366,27 @@ def add_output_storage_post_proc_stmt(
366366
curr_nit_sot_position = outer_in_nit_sot_names.index(outer_in_name)
367367
curr_nit_sot = op.inner_nitsot_outs(op.inner_outputs)[curr_nit_sot_position]
368368

369-
storage_shape = create_tuple_string(
370-
[storage_size_name] + ["0"] * curr_nit_sot.ndim
371-
)
369+
known_static_shape = all(dim is not None for dim in curr_nit_sot.type.shape)
370+
if known_static_shape:
371+
storage_shape = create_tuple_string(
372+
(storage_size_name, *(map(str, curr_nit_sot.type.shape)))
373+
)
374+
else:
375+
storage_shape = create_tuple_string(
376+
(storage_size_name, *(["0"] * curr_nit_sot.ndim))
377+
)
372378
storage_dtype = curr_nit_sot.type.numpy_dtype.name
373379

374380
storage_alloc_stmts.append(
375381
dedent(
376382
f"""
377-
{storage_size_name} = ({outer_in_name}).item()
383+
{storage_size_name} = {outer_in_name}.item()
378384
{storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype})
379385
"""
380386
).strip()
381387
)
382388

383-
if curr_nit_sot.type.ndim > 0:
384-
storage_alloc_stmts.append(f"{outer_in_name}_ready = False")
385-
389+
if not known_static_shape:
386390
# In this case, we don't know the shape of the output storage
387391
# array until we get some output from the inner-function.
388392
# With the following we add delayed output storage initialization:
@@ -392,9 +396,8 @@ def add_output_storage_post_proc_stmt(
392396
inner_out_post_processing_stmts.append(
393397
dedent(
394398
f"""
395-
if not {outer_in_name}_ready:
399+
if i == 0:
396400
{storage_name} = np.empty(({storage_size_name},) + np.shape({inner_out_name}), dtype=np.{storage_dtype})
397-
{outer_in_name}_ready = True
398401
"""
399402
).strip()
400403
)
@@ -409,10 +412,11 @@ def add_output_storage_post_proc_stmt(
409412
assert len(inner_in_exprs) == len(op.fgraph.inputs)
410413

411414
inner_scalar_in_args_to_temp_storage = "\n".join(inner_in_exprs_scalar)
412-
inner_in_args = create_arg_string(inner_in_exprs)
415+
# Break inputs in new lines, just for readability of the source code
416+
inner_in_args = f",\n{' ' * 12}".join(inner_in_exprs)
413417
inner_outputs = create_tuple_string(inner_output_names)
414418
input_storage_block = "\n".join(storage_alloc_stmts)
415-
input_temp_scalar_storage_block = "\n".join(temp_scalar_storage_alloc_stmts)
419+
input_temp_0d_storage_block = "\n".join(temp_0d_storage_alloc_stmts)
416420
output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts)
417421
inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)
418422

@@ -426,32 +430,29 @@ def scan({", ".join(outer_in_names)}):
426430
427431
{indent(input_storage_block, " " * 4)}
428432
429-
{indent(input_temp_scalar_storage_block, " " * 4)}
433+
{indent(input_temp_0d_storage_block, " " * 4)}
430434
431435
i = 0
432436
cond = np.array(False)
433437
while i < n_steps and not cond.item():
434438
{indent(inner_scalar_in_args_to_temp_storage, " " * 8)}
435439
436-
{inner_outputs} = scan_inner_func({inner_in_args})
440+
{inner_outputs} = scan_inner_func(
441+
{inner_in_args}
442+
)
437443
{indent(inner_out_post_processing_block, " " * 8)}
438444
{indent(inner_out_to_outer_out_stmts, " " * 8)}
439445
i += 1
440446
441447
{indent(output_storage_post_processing_block, " " * 4)}
442448
443-
return {create_arg_string(outer_output_names)}
449+
return {", ".join(outer_output_names)}
444450
"""
445451

446-
global_env = {
447-
"np": np,
448-
"scan_inner_func": scan_inner_func,
449-
}
450-
451452
scan_op_fn = compile_numba_function_src(
452453
scan_op_src,
453454
"scan",
454-
{**globals(), **global_env},
455+
globals() | {"np": np, "scan_inner_func": scan_inner_func},
455456
)
456457

457458
if inner_func_cache_key is None:

pytensor/link/numba/dispatch/shape.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,7 @@
44
from numba.np.unsafe import ndarray as numba_ndarray
55

66
from pytensor.link.numba.dispatch import basic as numba_basic
7-
from pytensor.link.numba.dispatch.basic import (
8-
create_arg_string,
9-
register_funcify_default_op_cache_key,
10-
)
7+
from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
118
from pytensor.link.utils import compile_function_src
129
from pytensor.tensor import NoneConst
1310
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@@ -48,7 +45,7 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
4845

4946
func = dedent(
5047
f"""
51-
def specify_shape(x, {create_arg_string(shape_input_names)}):
48+
def specify_shape(x, {", ".join(shape_input_names)}):
5249
{"; ".join(func_conditions)}
5350
return x
5451
"""

0 commit comments

Comments
 (0)