@@ -3194,38 +3194,66 @@ def onestep(x, x_tm4):
         f = function([seq], results[1])
         assert np.all(exp_out == f(inp))
 
-    def test_shared_borrow(self):
+    @pytest.mark.parametrize("static_shape", (True, False)[:1])
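+    # NOTE: the (True, False)[:1] slice means only the static_shape=True case runs for now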
+    def test_aliased_inner_outputs(self, static_shape):
         """
-        This tests two things. The first is a bug occurring when scan wrongly
-        used the borrow flag. The second thing it that Scan's infer_shape()
-        method will be able to remove the Scan node from the graph in this
-        case.
+        This tests two things. The first is a bug occurring when scan wrongly
+        used the borrow flag. The second thing is that Scan's infer_shape()
+        method will be able to remove the Scan node from the graph in this
+        case.
+
+        Here is a pure Python equivalent of the problem we want to avoid:
+        ```python
+        def scan(seq, initval):
+            # Due to memory optimization we overwrite values of mitsot as we iterate
+            # That's why mitsot has shape (4, 1) and not (14, 1)
+            mitsot = np.zeros((4, 1))
+            mitsot[:4] = initval
+            nitsot = np.zeros((10, 1))
+            for i, s in enumerate(seq):
+                # Incorrect results
+                mitsot[(i + 4) % 4], nitsot[i] = s, mitsot[i % 4]
+                # Correct results
+                # mitsot[(i + 4) % 4], nitsot[i] = s, mitsot[i % 4].copy()
+
+            return mitsot[(i + 4) % 4 : (i + 4 + 1) % 4], nitsot
+
+        scan(np.arange(10), np.zeros((4, 1)))
+        ```
         """
 
-        inp = np.arange(10).reshape(-1, 1).astype(config.floatX)
-        exp_out = np.zeros((10, 1)).astype(config.floatX)
-        exp_out[4:] = inp[:-4]
-
-        def onestep(x, x_tm4):
-            return x, x_tm4
-
-        seq = matrix()
-        initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
-        outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
-        results = scan(
-            fn=onestep, sequences=seq, outputs_info=outputs_info, return_updates=False
+        def onestep(seq, seq_tm4):
+            # The recurring output is just each value of seq,
+            # and we further map the tap -4 as a new output
+            return seq, seq_tm4
+
+        # Outer tensors must be at least matrices, so that we have vectors in the inner loop
+        # Otherwise we would be working with scalars and memory aliasing wouldn't be a concern
+        seq = matrix(shape=(10, 1) if static_shape else (None, None), name="seq")
+        init = matrix(shape=(4, 1) if static_shape else (None, None), name="init")
+        outputs_info = [{"initial": init, "taps": [-4]}, None]
+        [out_seq, out_seq_tm4] = scan(
+            fn=onestep,
+            sequences=seq,
+            outputs_info=outputs_info,
+            return_updates=False,
         )
-        sharedvar = shared(np.zeros((1, 1), dtype=config.floatX))
-        updates = {sharedvar: results[0][-1:]}
 
-        f = function([seq], results[1], updates=updates)
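+        # Return the last value of the recurring output and the full tap(-4) output, flattened to vectors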
+        f = function([seq, init], [out_seq[-1].ravel(), out_seq_tm4.ravel()])
 
-        # This fails if scan uses wrongly the borrow flag
-        assert np.all(exp_out == f(inp))
+        seq_test_val = np.arange(10, dtype=config.floatX)[:, None]
+        init_test_val = np.zeros((4, 1), dtype=config.floatX)
+
+        res0, res1 = f(seq_test_val, init_test_val)
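+        # The tap -4 output trails seq by 4 steps, so its first 4 entries come from the zero init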
+        expected_res0 = np.array([9], dtype=config.floatX)
+        expected_res1 = np.zeros(10, dtype=config.floatX)
+        expected_res1[4:] = np.arange(6)
+        np.testing.assert_array_equal(res0, expected_res0)
+        np.testing.assert_array_equal(res1, expected_res1)
 
         # This fails if Scan's infer_shape() is unable to remove the Scan
         # node from the graph.
-        f_infershape = function([seq], results[1].shape, mode="FAST_RUN")
+        f_infershape = function([seq, init], out_seq_tm4[1].shape)
         scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
         assert len(scan_nodes_infershape) == 0
 