Skip to content

Commit a782753

Browse files
committed
Numba RavelMultiIndex: Handle arbitrary indices ndim and F-order
1 parent ec57415 commit a782753

File tree

3 files changed

+75
-75
lines changed

3 files changed

+75
-75
lines changed

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 41 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -135,65 +135,49 @@ def filldiagonaloffset(a, val, offset):
135135
def numba_funcify_RavelMultiIndex(op, node, **kwargs):
136136
mode = op.mode
137137
order = op.order
138+
vec_indices = node.inputs[0].type.ndim > 0
138139

139-
if order != "C":
140-
raise NotImplementedError(
141-
"Numba does not implement `order` in `numpy.ravel_multi_index`"
142-
)
143-
144-
if mode == "raise":
145-
146-
@numba_basic.numba_njit
147-
def mode_fn(*args):
148-
raise ValueError("invalid entry in coordinates array")
149-
150-
elif mode == "wrap":
151-
152-
@numba_basic.numba_njit(inline="always")
153-
def mode_fn(new_arr, i, j, v, d):
154-
new_arr[i, j] = v % d
155-
156-
elif mode == "clip":
157-
158-
@numba_basic.numba_njit(inline="always")
159-
def mode_fn(new_arr, i, j, v, d):
160-
new_arr[i, j] = min(max(v, 0), d - 1)
161-
162-
if node.inputs[0].ndim == 0:
163-
164-
@numba_basic.numba_njit
165-
def ravelmultiindex(*inp):
166-
shape = inp[-1]
167-
arr = np.stack(inp[:-1])
168-
169-
new_arr = arr.T.astype(np.float64).copy()
170-
for i, b in enumerate(new_arr):
171-
if b < 0 or b >= shape[i]:
172-
mode_fn(new_arr, i, 0, b, shape[i])
173-
174-
a = np.ones(len(shape), dtype=np.float64)
175-
a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
176-
return np.array(a.dot(new_arr.T), dtype=np.int64)
177-
178-
else:
140+
@numba_basic.numba_njit
141+
def ravelmultiindex(*inp):
142+
shape = inp[-1]
143+
# Concatenate indices along last axis
144+
stacked_indices = np.stack(inp[:-1], axis=-1)
145+
146+
# Manage invalid indices
147+
for i, dim_limit in enumerate(shape):
148+
if mode == "wrap":
149+
stacked_indices[..., i] %= dim_limit
150+
elif mode == "clip":
151+
dim_indices = stacked_indices[..., i]
152+
stacked_indices[..., i] = np.clip(dim_indices, 0, dim_limit - 1)
153+
else: # raise
154+
dim_indices = stacked_indices[..., i]
155+
invalid_indices = (dim_indices < 0) | (dim_indices >= shape[i])
156+
# Cannot call np.any on a boolean
157+
if vec_indices:
158+
invalid_indices = invalid_indices.any()
159+
if invalid_indices:
160+
raise ValueError("invalid entry in coordinates array")
161+
162+
# Calculate Strides based on Order
163+
a = np.ones(len(shape), dtype=np.int64)
164+
if order == "C":
165+
# C-Order: Last dimension moves fastest (Strides: large -> small -> 1)
166+
# For shape (3, 4, 5): Multipliers are (20, 5, 1)
167+
if len(shape) > 1:
168+
a[:-1] = np.cumprod(shape[:0:-1])[::-1]
169+
else: # order == "F"
170+
# F-Order: First dimension moves fastest (Strides: 1 -> small -> large)
171+
# For shape (3, 4, 5): Multipliers are (1, 3, 12)
172+
if len(shape) > 1:
173+
a[1:] = np.cumprod(shape[:-1])
174+
175+
# Dot product indices with strides
176+
# (allow arbitrary left operand ndim and int dtype, which numba matmul doesn't support)
177+
return np.asarray((stacked_indices * a).sum(-1))
179178

180-
@numba_basic.numba_njit
181-
def ravelmultiindex(*inp):
182-
shape = inp[-1]
183-
arr = np.stack(inp[:-1])
184-
185-
new_arr = arr.T.astype(np.float64).copy()
186-
for i, b in enumerate(new_arr):
187-
# no strict argument to this zip because numba doesn't support it
188-
for j, (d, v) in enumerate(zip(shape, b)):
189-
if v < 0 or v >= d:
190-
mode_fn(new_arr, i, j, v, d)
191-
192-
a = np.ones(len(shape), dtype=np.float64)
193-
a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
194-
return a.dot(new_arr.T).astype(np.int64)
195-
196-
return ravelmultiindex
179+
cache_version = 1
180+
return ravelmultiindex, cache_version
197181

198182

199183
@register_funcify_default_op_cache_key(Repeat)

pytensor/tensor/extra_ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,8 +1371,7 @@ def __init__(self, mode="raise", order="C"):
13711371
self.order = order
13721372

13731373
def make_node(self, *inp):
1374-
multi_index = [ptb.as_tensor_variable(i) for i in inp[:-1]]
1375-
dims = ptb.as_tensor_variable(inp[-1])
1374+
*multi_index, dims = map(ptb.as_tensor_variable, inp)
13761375

13771376
for i in multi_index:
13781377
if i.dtype not in int_dtypes:
@@ -1382,19 +1381,20 @@ def make_node(self, *inp):
13821381
if dims.ndim != 1:
13831382
raise TypeError("dims must be a 1D array")
13841383

1384+
out_type = multi_index[0].type.clone(dtype="int64")
13851385
return Apply(
13861386
self,
13871387
[*multi_index, dims],
1388-
[TensorType(dtype="int64", shape=(None,) * multi_index[0].type.ndim)()],
1388+
[out_type()],
13891389
)
13901390

13911391
def infer_shape(self, fgraph, node, input_shapes):
13921392
return [input_shapes[0]]
13931393

13941394
def perform(self, node, inp, out):
1395-
multi_index, dims = inp[:-1], inp[-1]
1395+
*multi_index, dims = inp
13961396
res = np.ravel_multi_index(multi_index, dims, mode=self.mode, order=self.order)
1397-
out[0][0] = np.asarray(res, node.outputs[0].dtype)
1397+
out[0][0] = np.asarray(res, "int64")
13981398

13991399

14001400
def ravel_multi_index(multi_index, dims, mode="raise", order="C"):

tests/link/numba/test_extra_ops.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytensor.tensor as pt
88
from pytensor import config
99
from pytensor.tensor import extra_ops
10+
from pytensor.tensor.extra_ops import RavelMultiIndex
1011
from tests.link.numba.test_basic import compare_numba_and_py
1112

1213

@@ -133,43 +134,41 @@ def test_FillDiagonalOffset(a, val, offset):
133134

134135

135136
@pytest.mark.parametrize(
136-
"arr, shape, mode, order, exc",
137+
"arr, shape, mode, exc",
137138
[
138139
(
139140
tuple((pt.lscalar(), v) for v in np.array([0])),
140141
(pt.lvector(), np.array([2])),
141142
"raise",
142-
"C",
143143
None,
144144
),
145145
(
146146
tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])),
147147
(pt.lvector(), np.array([2, 3, 4])),
148148
"raise",
149-
"C",
150149
None,
151150
),
152151
(
153152
tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])),
154153
(pt.lvector(), np.array([2, 3, 4])),
155154
"raise",
156-
"C",
157155
None,
158156
),
159157
(
160-
tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])),
158+
tuple(
159+
(pt.lmatrix(), np.broadcast_to(v, (3, 2)).copy())
160+
for v in np.array([[0, 1], [2, 0], [1, 3]])
161+
),
161162
(pt.lvector(), np.array([2, 3, 4])),
162163
"raise",
163-
"F",
164-
NotImplementedError,
164+
None,
165165
),
166166
(
167167
tuple(
168168
(pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
169169
),
170170
(pt.lvector(), np.array([2, 3, 4])),
171171
"raise",
172-
"C",
173172
ValueError,
174173
),
175174
(
@@ -178,7 +177,15 @@ def test_FillDiagonalOffset(a, val, offset):
178177
),
179178
(pt.lvector(), np.array([2, 3, 4])),
180179
"wrap",
181-
"C",
180+
None,
181+
),
182+
(
183+
tuple(
184+
(pt.ltensor3(), np.broadcast_to(v, (2, 2, 3)).copy())
185+
for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
186+
),
187+
(pt.lvector(), np.array([2, 3, 4])),
188+
"wrap",
182189
None,
183190
),
184191
(
@@ -187,21 +194,30 @@ def test_FillDiagonalOffset(a, val, offset):
187194
),
188195
(pt.lvector(), np.array([2, 3, 4])),
189196
"clip",
190-
"C",
197+
None,
198+
),
199+
(
200+
tuple(
201+
(pt.lmatrix(), np.broadcast_to(v, (2, 3)).copy())
202+
for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
203+
),
204+
(pt.lvector(), np.array([2, 3, 4])),
205+
"clip",
191206
None,
192207
),
193208
],
194209
)
195-
def test_RavelMultiIndex(arr, shape, mode, order, exc):
210+
def test_RavelMultiIndex(arr, shape, mode, exc):
196211
arr, test_arr = zip(*arr, strict=True)
197212
shape, test_shape = shape
198-
g = extra_ops.RavelMultiIndex(mode, order)(*arr, shape)
213+
g_c = RavelMultiIndex(mode, order="C")(*arr, shape)
214+
g_f = RavelMultiIndex(mode, order="F")(*arr, shape)
199215

200216
cm = contextlib.suppress() if exc is None else pytest.raises(exc)
201217
with cm:
202218
compare_numba_and_py(
203219
[*arr, shape],
204-
g,
220+
[g_c, g_f],
205221
[*test_arr, test_shape],
206222
)
207223

0 commit comments

Comments (0)