-
Notifications
You must be signed in to change notification settings - Fork 149
Fix non-inplace IfElse on numba mode #1765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fix non-inplace IfElse on numba mode #1765
Conversation
ricardoV94
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice trick with the list, I was thinking we would need to use codegen due to numba limitations.
Left some comments.
We also need to work on the existing tests so they would have failed before the fix and pass now. We have to test both single and multi output and inplace or not
| # Return a tuple of copies | ||
| out = [None] * n_outs | ||
| for i in range(n_outs): | ||
| out[i] = selected[i].copy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should only make a copy if not op.inplace
| def ifelse(cond, *args): | ||
| if cond: | ||
| res = args[:n_outs] | ||
| arr = args[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This case can be simplified to have signature ifelse(cond, if_true, if_false), without need for indexing internally.
It's unrelated to the copy change
| # Return a copy | ||
| return arr.copy() | ||
|
|
||
| return ifelse |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| return ifelse | |
| cache_version = 1 | |
| return ifelse, cache_version |
We need to tell PyTensor the implementation of ifelse changed to invalidate old caches
|
The list trick didn't work. I suspect that would happen. We need to use codegen for the non inplace version of ifelse with multiple outputs. Other cases will work fine. You can see some cases of codegen for the numba funcify |
Hey @ricardoV94, I have added codegen. A lot of tests are failing in |
| if n_outs == 1: | ||
|
|
||
| @numba_basic.numba_njit | ||
| def ifelse(cond, *args): | ||
| if cond: | ||
| res = args[:n_outs] | ||
| else: | ||
| res = args[n_outs:] | ||
| def ifelse(cond, x_true, x_false): | ||
| arr = x_true if cond else x_false | ||
| return arr if as_view else arr.copy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On second thought I guess we can get rid of the special case and stay with the codegen for every case now
| ifelse_numba = numba_basic.numba_njit(ifelse_py) | ||
|
|
||
| return res[0] | ||
| cache_version = 3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You probably needed to bump a few times, but for the PR we should only bump once. You can erase your previous cache with pytensor-cache purge for local testing
| cache_version = 3 | |
| cache_version = 1 |
tests/link/numba/test_compile_ops.py
Outdated
| assert r2 is not b | ||
|
|
||
|
|
||
| def test_ifelse_false_branch(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can merge this test with the previous ones. Just eval the function twice, in a way that triggers the different branches.
tests/link/numba/test_compile_ops.py
Outdated
| y = pt.vector("y") | ||
| out1, out2 = ifelse(x.sum() > 0, (x, y), (y, x)) | ||
|
|
||
| fn = function([x, y], [out1, out2], mode=Mode("numba", optimizer=None)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Parametrize this and the single output test to test inplace and not inplace. You can create IfElse inplace manually like IfElse(as_view=True|False, n_outs=2), and pass accept_inplace=Truetofunction`.
We want to make sure that r1 is a, r2 is b in that case. Right now we are never testing the inplace mode
Explanation for failing
|
|
Ah you need to set the borrow flags to allow pytensor to pass back inputs unalterated: import numpy as np
import pytensor
import pytensor.tensor as pt
x = pt.vector("x")
fn = pytensor.function([pytensor.In(x, borrow=True)], pytensor.Out(x, borrow=True), mode="NUMBA")
x_test = np.zeros(5)
fn(x_test) is x_testYou can check that without that the final function would have a deepcopy Op, using |
Description
This PR fixes an issue in the
IfElsenumba, where outputs were returned as direct references to the input arrays instead of copies. This violated the semantics ofifelse, which guarantees that the returned value is a distinct object, even when both branches reference the same input.The fix ensures that each selected output is explicitly copied. This matches the behavior of the Python linker and prevents unexpected mutations when the NumPy arrays are assumed to be non-shared.
Related Issue
Checklist
Type of change