@emekaokoli19
Contributor

Description

This PR fixes an issue in the Numba implementation of IfElse, where outputs were returned as direct references to the input arrays instead of copies. This violated the semantics of non-inplace ifelse, which guarantee that the returned value is a distinct object, even when both branches reference the same input.

The fix ensures that each selected output is explicitly copied. This matches the behavior of the Python linker and prevents unexpected mutations when callers assume that the returned NumPy arrays do not share memory with the inputs.
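
To make the aliasing problem concrete, here is a minimal sketch in plain Python. It is not the PyTensor or Numba implementation: the function names `ifelse_by_reference` and `ifelse_by_copy` are hypothetical, and plain lists stand in for NumPy arrays (both expose a `.copy()` method).

```python
def ifelse_by_reference(cond, if_true, if_false):
    # Returns the selected input itself, so the result aliases that input.
    return if_true if cond else if_false


def ifelse_by_copy(cond, if_true, if_false):
    # Returns a fresh copy, so the result never shares storage with an input.
    selected = if_true if cond else if_false
    return selected.copy()


# Plain lists stand in for NumPy arrays in this sketch.
x = [1.0, 2.0, 3.0]
y = [4.0, 5.0, 6.0]

aliased = ifelse_by_reference(True, x, y)
aliased[0] = -1.0  # also overwrites x[0], because aliased is x
x_was_mutated = x[0] == -1.0

copied = ifelse_by_copy(False, x, y)
copied[0] = -1.0  # y is left untouched
y_was_mutated = y[0] == -1.0
```

Writing into the by-reference result silently changes the caller's input, which is the kind of unexpected mutation the copy is meant to prevent.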

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

Member

@ricardoV94 ricardoV94 left a comment

Nice trick with the list, I was thinking we would need to use codegen due to numba limitations.

Left some comments.

We also need to rework the existing tests so that they would have failed before the fix and pass now. We have to test both single and multiple outputs, with and without inplace.

# Return a tuple of copies
out = [None] * n_outs
for i in range(n_outs):
    out[i] = selected[i].copy()
Member

We should only make a copy if not op.inplace

  def ifelse(cond, *args):
      if cond:
-         res = args[:n_outs]
+         arr = args[0]
Member

@ricardoV94 ricardoV94 Dec 3, 2025

This case can be simplified to have signature ifelse(cond, if_true, if_false), without need for indexing internally.

It's unrelated to the copy change

    # Return a copy
    return arr.copy()

return ifelse
Member

Suggested change
- return ifelse
+ cache_version = 1
+ return ifelse, cache_version

We need to tell PyTensor the implementation of ifelse changed to invalidate old caches
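
Why bumping the version invalidates stale entries can be sketched with a toy cache keyed on the version. This is not PyTensor's actual cache: `_compiled_cache`, `get_compiled`, and the use of `None` for an unversioned entry are all hypothetical and only illustrate the idea.

```python
# Hypothetical stand-in for a persistent cache of compiled functions.
_compiled_cache = {}


def get_compiled(op_name, cache_version, build):
    # The version is part of the key, so a new version never hits an old entry.
    key = (op_name, cache_version)
    if key not in _compiled_cache:
        _compiled_cache[key] = build()
    return _compiled_cache[key]


# First compilation, stored under the old (unversioned) key.
old = get_compiled("IfElse", None, lambda: "old implementation")

# The implementation changed but the key did not: the stale entry is reused.
stale = get_compiled("IfElse", None, lambda: "new implementation")

# Bumping the version changes the key and forces a rebuild.
fresh = get_compiled("IfElse", 1, lambda: "new implementation")
```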

@ricardoV94
Member

The list trick didn't work. I suspected that would happen.

We need to use codegen for the non-inplace version of ifelse with multiple outputs. Other cases will work fine.

You can look at the codegen used in the numba funcify implementations of Alloc or SpecifyShape for some ideas. This tends to happen every time we would want (*args) in the signature of a numba function.
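
As a rough illustration of the codegen approach, the sketch below builds the source of a fixed-arity function as a string and executes it, so the signature never needs `*args`. The helper name `make_ifelse` and the argument names `t0`, `f0`, and so on are hypothetical, and the sketch always returns a tuple, which may differ from what the PR does for a single output. The Numba compilation step is left out so the example runs with the standard library alone; plain lists stand in for arrays.

```python
def make_ifelse(n_outs, as_view=False):
    """Build ifelse(cond, t0, ..., f0, ...) with an explicit, fixed signature."""
    true_names = [f"t{i}" for i in range(n_outs)]
    false_names = [f"f{i}" for i in range(n_outs)]
    # Copy each selected input unless the op is allowed to return views.
    suffix = "" if as_view else ".copy()"

    def returned(names):
        # The trailing comma keeps a one-element result a tuple.
        return "(" + "".join(f"{name}{suffix}, " for name in names) + ")"

    source = (
        f"def ifelse(cond, {', '.join(true_names + false_names)}):\n"
        f"    if cond:\n"
        f"        return {returned(true_names)}\n"
        f"    else:\n"
        f"        return {returned(false_names)}\n"
    )
    namespace = {}
    exec(source, namespace)  # define the generated function
    return namespace["ifelse"], source


# Plain lists stand in for NumPy arrays in this sketch.
ifelse2, generated_source = make_ifelse(2)
a, b = [1.0], [2.0]
r1, r2 = ifelse2(True, a, b, b, a)   # true branch: copies of (a, b)
s1, s2 = ifelse2(False, a, b, b, a)  # false branch: copies of (b, a)
```

In the PR, the generated Python function would then be passed to the Numba JIT wrapper instead of being called directly.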

@ricardoV94 ricardoV94 changed the title from "fix-make copies of inputs in numba" to "Fix non-inplace IfElse on numba mode" Dec 3, 2025
@emekaokoli19
Contributor Author

Hey @ricardoV94, I have added codegen. A lot of tests are failing in tests/link/numba, so I wrote some new tests to test the ifelse changes. Are the failing tests in tests/link/numba a result of something else?

Comment on lines 109 to 114
  if n_outs == 1:

      @numba_basic.numba_njit
-     def ifelse(cond, *args):
-         if cond:
-             res = args[:n_outs]
-         else:
-             res = args[n_outs:]
+     def ifelse(cond, x_true, x_false):
+         arr = x_true if cond else x_false
+         return arr if as_view else arr.copy()
Member

On second thought, I guess we can get rid of the special case and use the codegen for every case now.

  ifelse_numba = numba_basic.numba_njit(ifelse_py)

- return res[0]
+ cache_version = 3
Member

You probably needed to bump a few times, but for the PR we should only bump once. You can erase your previous cache with pytensor-cache purge for local testing

Suggested change
- cache_version = 3
+ cache_version = 1

    assert r2 is not b


def test_ifelse_false_branch():
Member

You can merge this test with the previous ones. Just eval the function twice, in a way that triggers the different branches.

y = pt.vector("y")
out1, out2 = ifelse(x.sum() > 0, (x, y), (y, x))

fn = function([x, y], [out1, out2], mode=Mode("numba", optimizer=None))
Member

Parametrize this and the single-output test to cover both inplace and non-inplace. You can create an inplace IfElse manually, like IfElse(as_view=True|False, n_outs=2), and pass accept_inplace=True to function.

We want to make sure that r1 is a and r2 is b in that case. Right now we are never testing the inplace mode.

@emekaokoli19
Contributor Author

@ricardoV94

Explanation for the failing "is" checks in the Numba mode tests

The tests use assert res is a to check whether the output is the same Python object as the input. This fails under Numba, and I think the reason is:

  1. IfElse with as_view=True normally returns the same object in Python mode.
  2. In Numba mode, the compiled function creates new arrays for the outputs, even when as_view=True.
  3. Therefore res is a fails because the object identity differs, even though the array values are identical.

Should we replace the is checks with np.array_equal, which tests element-wise equality instead of object identity? However, this would mean that as_view is redundant in Numba mode.

@ricardoV94
Member

ricardoV94 commented Dec 4, 2025

Ah, you need to set the borrow flags to allow PyTensor to pass back inputs unaltered:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
fn = pytensor.function([pytensor.In(x, borrow=True)], pytensor.Out(x, borrow=True), mode="NUMBA")
x_test = np.zeros(5)
fn(x_test) is x_test

You can check with fn.dprint() that, without those flags, the final function would have a deepcopy Op.

Development

Successfully merging this pull request may close these issues.

Numba ifelse does not copy inputs when not inplace

2 participants