Numba and JAX Scan fixes #1754
Also fixed MIT-MOT for JAX. The other issue with zeroing out shouldn't be a concern yet because, AFAICT, it only shows up in while Scans, and those can't be transpiled to JAX. Views are not an issue in JAX either, as it handles inplacing by itself.
Force-pushed from c0db447 to 6de81d6.
Pull request overview
This PR fixes several issues related to Numba scan operations, particularly addressing memory aliasing problems when scan outputs reference the same underlying memory buffers. The main focus is on ensuring proper handling of MIT-MOT (multiple input taps, multiple output taps) operations with different tap patterns.
Key changes:
- Adds `insert_deepcopy` calls to break output aliasing in Numba scan and OpFromGraph dispatch
- Introduces normalized tap slice properties to handle both negative and positive taps consistently
- Adds zero-out logic for unused buffer entries in truncated gradient scenarios
- Refactors test coverage for aliased inner outputs with improved documentation
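The aliasing problem named in the first bullet can be illustrated without PyTensor or Numba. The sketch below is a minimal pure-Python analogy, not the PR's implementation: `inner_fn_aliased`, `inner_fn_copied`, and `add_one_inplace` are hypothetical names, and plain lists stand in for array buffers. It shows how an in-place update on one output corrupts a second output that shares the same storage, and how a deep copy breaks that link.

```python
import copy


def inner_fn_aliased(x):
    # Hypothetical inner function: both outputs are the very same object.
    out = [v * 2 for v in x]
    return out, out


def inner_fn_copied(x):
    # Same computation, but the duplicated output gets its own storage.
    out = [v * 2 for v in x]
    return out, copy.deepcopy(out)


def add_one_inplace(buf):
    # Stand-in for an in-place operation applied to a single output.
    for i in range(len(buf)):
        buf[i] += 1


# Aliased outputs: mutating `a` silently changes `b` as well.
a, b = inner_fn_aliased([1, 2, 3])
add_one_inplace(a)

# Copied outputs: mutating `c` leaves `d` untouched.
c, d = inner_fn_copied([1, 2, 3])
add_one_inplace(c)
```

The cost of the copy is an extra allocation per aliased output, which is the usual trade-off for being able to run in-place operations on each output safely.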
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/scan/test_basic.py | Refactored test_aliased_inner_outputs with better documentation and parametrization (currently limited to static shapes only) |
| tests/link/numba/test_scan.py | Added tests for higher-order derivatives and grad_until with truncate_sequence_taps |
| tests/link/numba/test_compile_ops.py | Added test for OpFromGraph aliased outputs to verify deepcopy prevents memory aliasing with inplace operations |
| tests/link/jax/test_scan.py | Added higher-order derivatives test for JAX backend |
| pytensor/scan/op.py | Added normalized_mit_mot_in_slices and normalized_mit_mot_out_slices properties to normalize taps as offsets from the oldest tap |
| pytensor/link/numba/dispatch/scan.py | Major refactoring: removed negative offset handling, added deepcopy insertion, improved buffer management with zero-out logic for unused entries, and better nit-sot static shape handling |
| pytensor/link/numba/dispatch/compile_ops.py | Added deepcopy insertion for OpFromGraph to prevent output aliasing |
| pytensor/link/jax/dispatch/scan.py | Updated to use normalized tap slices for consistent tap handling |
| pytensor/link/numba/dispatch/basic.py | Simplified create_tuple_string and removed create_arg_string function |
| pytensor/link/numba/dispatch/shape.py | Replaced create_arg_string usage with direct string join |
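The idea of normalizing taps "as offsets from the oldest tap" can be sketched in a few lines. This is only an illustration of the arithmetic under stated assumptions: `normalize_taps` is a hypothetical helper, not the code behind `normalized_mit_mot_in_slices` or `normalized_mit_mot_out_slices`, and the summary above does not say whether the real properties take the oldest tap per slice or across the input and output taps of a state.

```python
def normalize_taps(taps):
    """Re-express taps as non-negative offsets from the oldest (smallest) tap.

    Hypothetical helper for illustration only; assumes the oldest tap is
    taken over the single group of taps passed in.
    """
    oldest = min(taps)
    return [tap - oldest for tap in taps]


# Purely negative taps, purely positive taps, and a mix all map to
# offsets starting at 0, so they can index a buffer the same way.
negative = normalize_taps([-2, -1])
positive = normalize_taps([1, 2])
mixed = normalize_taps([-1, 0, 1])
```

With offsets like these, a backend no longer needs a separate code path for negative taps: every tap is a plain non-negative index relative to the start of the state's buffer window.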
Force-pushed from 6de81d6 to f9fceb0.
Ugh, I found out that while the JAX/Numba implementations now match the Python/Cython ones, there's an issue with the third derivative in the new test case for all backends (including the old ones): #1772. I commented out the assert with a FIXME note. The changes in this PR are still meaningful in that they match the behavior of Scan as is, even if it may be nonsensical. Or the bug may come from elsewhere. Note that before this PR the second derivative in the test was also failing, and now it works. The tests related to the other behaviors are also passing now. So I'm okay with merging this PR without blocking on #1772, as it is not directly caused by these changes.
You need to tweak the tolerances in
Unlike MIT-SOT and SIT-SOT these can be positive or negative, depending on the order of differentiation
Also simplified test. Shared variables aren't needed for the test and clobber it
Force-pushed from f9fceb0 to 59a67f4.

Spinoff from #811
Holding off just because there's a failing test when comparing the JAX scan benchmarks.
Also fix alias in OFG