Skip to content

Commit 3cfd6e4

Browse files
committed
Numba overloads: Boolean is not Number
1 parent cc674a1 commit 3cfd6e4

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

pytensor/link/numba/dispatch/compile_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def numba_deepcopy(x):
2525

2626
@numba.extending.overload(numba_deepcopy)
2727
def numba_deepcopy_tensor(x):
28-
if isinstance(x, numba.types.Number):
28+
if isinstance(x, numba.types.Number | numba.types.Boolean):
2929

3030
def number_deepcopy(x):
3131
return x

pytensor/link/numba/dispatch/typed_list.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numba
22
import numpy as np
3+
from numba.types import Array, Boolean, List, Number
34

45
import pytensor.link.numba.dispatch.basic as numba_basic
56
from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
@@ -37,7 +38,7 @@ def numba_all_equal(x, y):
3738
def list_all_equal(x, y):
3839
all_equal = None
3940

40-
if isinstance(x, numba.types.List) and isinstance(y, numba.types.List):
41+
if isinstance(x, List) and isinstance(y, List):
4142

4243
def all_equal(x, y):
4344
if len(x) != len(y):
@@ -47,12 +48,12 @@ def all_equal(x, y):
4748
return False
4849
return True
4950

50-
if isinstance(x, numba.types.Array) and isinstance(y, numba.types.Array):
51+
if isinstance(x, Array) and isinstance(y, Array):
5152

5253
def all_equal(x, y):
5354
return (x == y).all()
5455

55-
if isinstance(x, numba.types.Number) and isinstance(y.numba.types.Number):
56+
if isinstance(x, Number | Boolean) and isinstance(y, Number | Boolean):
5657

5758
def all_equal(x, y):
5859
return x == y
@@ -62,7 +63,7 @@ def all_equal(x, y):
6263

6364
@numba.extending.overload(numba_deepcopy)
6465
def numba_deepcopy_list(x):
65-
if isinstance(x, numba.types.List):
66+
if isinstance(x, List):
6667

6768
def deepcopy_list(x):
6869
return [numba_deepcopy(xi) for xi in x]

0 commit comments

Comments
 (0)