Skip to content

Conversation

@ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Nov 27, 2025

Spinoff from #811

Holding off just because there's a failing test when comparing the the jax scan benchmarks

Also fix alias in OFG

@ricardoV94
Copy link
Member Author

ricardoV94 commented Nov 29, 2025

Also fixed mit-mot for JAX

The other issue with zeroing out shoudn't be a concern yet because AFAICT it only shows up in while Scan and those can't be transpiled to JAX

And views are not an issue in JAX as it handles inplacing by itself

@ricardoV94 ricardoV94 marked this pull request as ready for review December 4, 2025 16:34
Copilot finished reviewing on behalf of ricardoV94 December 4, 2025 17:18
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes several issues related to Numba scan operations, particularly addressing memory aliasing problems when scan outputs reference the same underlying memory buffers. The main focus is on ensuring proper handling of MIT-MOT (multiple input taps, multiple output taps) operations with different tap patterns.

Key changes:

  • Adds insert_deepcopy calls to break output aliasing in Numba scan and OpFromGraph dispatch
  • Introduces normalized tap slice properties to handle both negative and positive taps consistently
  • Adds zero-out logic for unused buffer entries in truncated gradient scenarios
  • Refactors test coverage for aliased inner outputs with improved documentation

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
tests/scan/test_basic.py Refactored test_aliased_inner_outputs with better documentation and parametrization (currently limited to static shapes only)
tests/link/numba/test_scan.py Added tests for higher-order derivatives and grad_until with truncate_sequence_taps
tests/link/numba/test_compile_ops.py Added test for OpFromGraph aliased outputs to verify deepcopy prevents memory aliasing with inplace operations
tests/link/jax/test_scan.py Added higher-order derivatives test for JAX backend
pytensor/scan/op.py Added normalized_mit_mot_in_slices and normalized_mit_mot_out_slices properties to normalize taps as offsets from the oldest tap
pytensor/link/numba/dispatch/scan.py Major refactoring: removed negative offset handling, added deepcopy insertion, improved buffer management with zero-out logic for unused entries, and better nit-sot static shape handling
pytensor/link/numba/dispatch/compile_ops.py Added deepcopy insertion for OpFromGraph to prevent output aliasing
pytensor/link/jax/dispatch/scan.py Updated to use normalized tap slices for consistent tap handling
pytensor/link/numba/dispatch/basic.py Simplified create_tuple_string and removed create_arg_string function
pytensor/link/numba/dispatch/shape.py Replaced create_arg_string usage with direct string join

@ricardoV94
Copy link
Member Author

ricardoV94 commented Dec 4, 2025

Ugh found out that while jax/numba implementation now match the python/cython, there's an issue with the third derivative in the new test case for all backends (including the old ones): #1772

I commented out the assert with a fixme note. The changes in this PR are still meaningful, in that they match the behavior of Scan as is, even if it may be nonsensical. Or the bug may come from elsewhere. Note that before this PR the second derivative in the test was also failing, and now works. And the tests related to the other behaviors are also passing now.

So I'm okay with merging this PR without blocking with #1772 as it is not directly caused by these changes.

@ricardoV94 ricardoV94 changed the title Numba scan fixes Numba and JAX Scan fixes Dec 4, 2025
@jessegrabowski
Copy link
Member

You need to tweak the tolerances in test_higher_order_derivatives for half precision.

@ricardoV94
Copy link
Member Author

You need to tweak the tolerances in test_higher_order_derivatives for half precision.

image

Unlike MIT-SOT and SIT-SOT these can be positive or negative, depending on the order of differentiation
Also simplified test. Shared variables aren't needed for the test and clobber it
@ricardoV94 ricardoV94 merged commit 4186819 into pymc-devs:main Dec 5, 2025
56 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants