Skip to content

Commit 8d7c528

Browse files
committed
Numba typify: Use more robust ListType
1 parent 04b0137 commit 8d7c528

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def get_numba_type(
134134
elif isinstance(pytensor_type, RandomGeneratorType):
135135
return numba.types.NumPyRandomGeneratorType("NumPyRandomGeneratorType")
136136
elif isinstance(pytensor_type, TypedListType):
137-
return numba.types.List(get_numba_type(pytensor_type.ttype))
137+
return numba.types.ListType(get_numba_type(pytensor_type.ttype))
138138
else:
139139
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
140140

pytensor/link/numba/dispatch/typed_list.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numba
22
import numpy as np
3-
from numba.types import Array, Boolean, List, Number
3+
from numba.types import Array, Boolean, List, ListType, Number
44

55
import pytensor.link.numba.dispatch.basic as numba_basic
66
from pytensor.link.numba.dispatch.basic import (
@@ -42,7 +42,7 @@ def numba_all_equal(x, y):
4242
def list_all_equal(x, y):
4343
all_equal = None
4444

45-
if isinstance(x, List) and isinstance(y, List):
45+
if isinstance(x, List | ListType) and isinstance(y, List | ListType):
4646

4747
def all_equal(x, y):
4848
if len(x) != len(y):
@@ -69,7 +69,7 @@ def all_equal(x, y):
6969

7070
@numba.extending.overload(numba_deepcopy)
7171
def numba_deepcopy_list(x):
72-
if isinstance(x, List):
72+
if isinstance(x, List | ListType):
7373

7474
def deepcopy_list(x):
7575
return [numba_deepcopy(xi) for xi in x]

0 commit comments

Comments
 (0)