@@ -151,7 +151,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
151151 # Inner-inputs are ordered as follows:
152152 # sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
153153 # untraced-sit-sot-inputs + non-sequences.
154- temp_scalar_storage_alloc_stmts : list [str ] = []
154+ temp_0d_storage_alloc_stmts : list [str ] = []
155155 inner_in_exprs_scalar : list [str ] = []
156156 inner_in_exprs : list [str ] = []
157157
@@ -169,7 +169,7 @@ def add_inner_in_expr(
169169 )
170170 temp_storage = f"{ storage_name } _temp_scalar_{ tap_offset } "
171171 storage_dtype = outer_in_var .type .numpy_dtype .name
172- temp_scalar_storage_alloc_stmts .append (
172+ temp_0d_storage_alloc_stmts .append (
173173 f"{ temp_storage } = np.empty((), dtype=np.{ storage_dtype } )"
174174 )
175175 inner_in_exprs_scalar .append (
@@ -181,7 +181,7 @@ def add_inner_in_expr(
181181 storage_name
182182 if tap_offset is None
183183 else idx_to_str (
184- storage_name , tap_offset , size = storage_size_var , allow_scalar = False
184+ storage_name , tap_offset , size = storage_size_var , allow_scalar = True
185185 )
186186 )
187187 inner_in_exprs .append (indexed_inner_in_str )
@@ -366,23 +366,27 @@ def add_output_storage_post_proc_stmt(
366366 curr_nit_sot_position = outer_in_nit_sot_names .index (outer_in_name )
367367 curr_nit_sot = op .inner_nitsot_outs (op .inner_outputs )[curr_nit_sot_position ]
368368
369- storage_shape = create_tuple_string (
370- [storage_size_name ] + ["0" ] * curr_nit_sot .ndim
371- )
369+ known_static_shape = all (dim is not None for dim in curr_nit_sot .type .shape )
370+ if known_static_shape :
371+ storage_shape = create_tuple_string (
372+ (storage_size_name , * (map (str , curr_nit_sot .type .shape )))
373+ )
374+ else :
375+ storage_shape = create_tuple_string (
376+ (storage_size_name , * (["0" ] * curr_nit_sot .ndim ))
377+ )
372378 storage_dtype = curr_nit_sot .type .numpy_dtype .name
373379
374380 storage_alloc_stmts .append (
375381 dedent (
376382 f"""
377- { storage_size_name } = ( { outer_in_name } ) .item()
383+ { storage_size_name } = { outer_in_name } .item()
378384 { storage_name } = np.empty({ storage_shape } , dtype=np.{ storage_dtype } )
379385 """
380386 ).strip ()
381387 )
382388
383- if curr_nit_sot .type .ndim > 0 :
384- storage_alloc_stmts .append (f"{ outer_in_name } _ready = False" )
385-
389+ if not known_static_shape :
386390 # In this case, we don't know the shape of the output storage
387391 # array until we get some output from the inner-function.
388392 # With the following we add delayed output storage initialization:
@@ -392,9 +396,8 @@ def add_output_storage_post_proc_stmt(
392396 inner_out_post_processing_stmts .append (
393397 dedent (
394398 f"""
395- if not { outer_in_name } _ready :
399+ if i == 0 :
396400 { storage_name } = np.empty(({ storage_size_name } ,) + np.shape({ inner_out_name } ), dtype=np.{ storage_dtype } )
397- { outer_in_name } _ready = True
398401 """
399402 ).strip ()
400403 )
@@ -409,10 +412,11 @@ def add_output_storage_post_proc_stmt(
409412 assert len (inner_in_exprs ) == len (op .fgraph .inputs )
410413
411414 inner_scalar_in_args_to_temp_storage = "\n " .join (inner_in_exprs_scalar )
412- inner_in_args = create_arg_string (inner_in_exprs )
415+ # Break inputs in new lines, just for readability of the source code
416+ inner_in_args = f",\n { ' ' * 12 } " .join (inner_in_exprs )
413417 inner_outputs = create_tuple_string (inner_output_names )
414418 input_storage_block = "\n " .join (storage_alloc_stmts )
415- input_temp_scalar_storage_block = "\n " .join (temp_scalar_storage_alloc_stmts )
419+ input_temp_0d_storage_block = "\n " .join (temp_0d_storage_alloc_stmts )
416420 output_storage_post_processing_block = "\n " .join (output_storage_post_proc_stmts )
417421 inner_out_post_processing_block = "\n " .join (inner_out_post_processing_stmts )
418422
@@ -426,32 +430,29 @@ def scan({", ".join(outer_in_names)}):
426430
427431{ indent (input_storage_block , " " * 4 )}
428432
429- { indent (input_temp_scalar_storage_block , " " * 4 )}
433+ { indent (input_temp_0d_storage_block , " " * 4 )}
430434
431435 i = 0
432436 cond = np.array(False)
433437 while i < n_steps and not cond.item():
434438{ indent (inner_scalar_in_args_to_temp_storage , " " * 8 )}
435439
436- { inner_outputs } = scan_inner_func({ inner_in_args } )
440+ { inner_outputs } = scan_inner_func(
441+ { inner_in_args }
442+ )
437443{ indent (inner_out_post_processing_block , " " * 8 )}
438444{ indent (inner_out_to_outer_out_stmts , " " * 8 )}
439445 i += 1
440446
441447{ indent (output_storage_post_processing_block , " " * 4 )}
442448
443- return { create_arg_string (outer_output_names )}
449+ return { ", " . join (outer_output_names )}
444450 """
445451
446- global_env = {
447- "np" : np ,
448- "scan_inner_func" : scan_inner_func ,
449- }
450-
451452 scan_op_fn = compile_numba_function_src (
452453 scan_op_src ,
453454 "scan" ,
454- { ** globals (), ** global_env },
455+ globals () | { "np" : np , "scan_inner_func" : scan_inner_func },
455456 )
456457
457458 if inner_func_cache_key is None :
0 commit comments