Skip to content

Commit ec57415

Browse files
committed
Numba UnravelIndex: Handle arbitrary indices ndim and F-order
1 parent 4186819 commit ec57415

File tree

3 files changed

+71
-42
lines changed

3 files changed

+71
-42
lines changed

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -261,41 +261,60 @@ def unique(x):
261261

262262
@register_funcify_and_cache_key(UnravelIndex)
263263
def numba_funcify_UnravelIndex(op, node, **kwargs):
264-
order = op.order
265-
266-
if order != "C":
267-
raise NotImplementedError(
268-
"Numba does not support the `order` argument in `numpy.unravel_index`"
269-
)
264+
out_ndim = node.outputs[0].type.ndim
270265

271-
if len(node.outputs) == 1:
272-
273-
@numba_basic.numba_njit(inline="always")
274-
def maybe_expand_dim(arr):
275-
return arr
276-
277-
else:
266+
if out_ndim == 0:
267+
# Creating a tuple of 0d arrays in numba is basically impossible without codegen, so just go to obj_mode
268+
return generate_fallback_impl(op, node=node), None
278269

279-
@numba_basic.numba_njit(inline="always")
280-
def maybe_expand_dim(arr):
281-
return np.expand_dims(arr, 1)
270+
c_order = op.order == "C"
271+
inp_ndim = node.inputs[0].type.ndim
272+
transpose_axes = (inp_ndim, *range(inp_ndim))
282273

283274
@numba_basic.numba_njit
284-
def unravelindex(arr, shape):
275+
def unravelindex(indices, shape):
285276
a = np.ones(len(shape), dtype=np.int64)
286-
a[1:] = shape[:0:-1]
287-
a = np.cumprod(a)[::-1]
277+
if c_order:
278+
# C-Order: Reverse shape (ignore dim0), cumulative product, then reverse back
279+
# Strides: [dim1*dim2, dim2, 1]
280+
a[1:] = shape[:0:-1]
281+
a = np.cumprod(a)[::-1]
282+
else:
283+
# F-Order: Standard shape, cumulative product
284+
# Strides: [1, dim0, dim0*dim1]
285+
a[1:] = shape[:-1]
286+
a = np.cumprod(a)
287+
288+
# Broadcast with a and shape on the last axis
289+
unraveled_coords = (indices[..., None] // a) % shape
288290

289-
# PyTensor actually returns a `tuple` of these values, instead of an
290-
# `ndarray`; however, this `ndarray` result should be able to be
291-
# unpacked into a `tuple`, so this discrepancy shouldn't really matter
292-
return ((maybe_expand_dim(arr) // a) % shape).T
291+
# Then transpose it to the front
292+
# Numba doesn't have moveaxis (why would it), so we use transpose
293+
# res = np.moveaxis(res, -1, 0)
294+
unraveled_coords = unraveled_coords.transpose(transpose_axes)
293295

296+
# This should be a tuple, but the array can be unpacked
297+
# into multiple variables with the same effect by the outer function
298+
# (special case for single entry is handled with an outer function below)
299+
return unraveled_coords
300+
301+
cache_version = 1
294302
cache_key = sha256(
295-
str((type(op), op.order, len(node.outputs))).encode()
303+
str((type(op), op.order, len(node.outputs), cache_version)).encode()
296304
).hexdigest()
297305

298-
return unravelindex, cache_key
306+
if len(node.outputs) == 1:
307+
308+
@numba_basic.numba_njit
309+
def unravel_index_single_item(arr, shape):
310+
# Unpack single entry
311+
(res,) = unravelindex(arr, shape)
312+
return res
313+
314+
return unravel_index_single_item, cache_key
315+
316+
else:
317+
return unravelindex, cache_key
299318

300319

301320
@register_funcify_default_op_cache_key(SearchsortedOp)

pytensor/tensor/extra_ops.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,13 +1304,11 @@ def make_node(self, indices, dims):
13041304
if dims.ndim != 1:
13051305
raise TypeError("dims must be a 1D array")
13061306

1307+
out_type = indices.type.clone(dtype="int64")
13071308
return Apply(
13081309
self,
13091310
[indices, dims],
1310-
[
1311-
TensorType(dtype="int64", shape=(None,) * indices.type.ndim)()
1312-
for i in range(ptb.get_vector_length(dims))
1313-
],
1311+
[out_type() for _i in range(ptb.get_vector_length(dims))],
13141312
)
13151313

13161314
def infer_shape(self, fgraph, node, input_shapes):

tests/link/numba/test_extra_ops.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
from contextlib import nullcontext
23

34
import numpy as np
45
import pytest
@@ -295,37 +296,48 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
295296

296297

297298
@pytest.mark.parametrize(
298-
"arr, shape, order, exc",
299+
"arr, shape, requires_obj_mode",
299300
[
301+
(
302+
(pt.lscalar(), np.array(9, dtype="int64")),
303+
pt.as_tensor([2, 3, 4]),
304+
True,
305+
),
300306
(
301307
(pt.lvector(), np.array([9, 15, 1], dtype="int64")),
302308
pt.as_tensor([2, 3, 4]),
303-
"C",
304-
None,
309+
False,
305310
),
306311
(
307312
(pt.lvector(), np.array([1, 0], dtype="int64")),
308313
pt.as_tensor([2]),
309-
"C",
310-
None,
314+
False,
311315
),
312316
(
313-
(pt.lvector(), np.array([9, 15, 1], dtype="int64")),
317+
(pt.lmatrix(), np.array([[9, 15, 1], [1, 9, 15]], dtype="int64")),
314318
pt.as_tensor([2, 3, 4]),
315-
"F",
316-
NotImplementedError,
319+
False,
317320
),
318321
],
319322
)
320-
def test_UnravelIndex(arr, shape, order, exc):
323+
def test_UnravelIndex(arr, shape, requires_obj_mode):
321324
arr, test_arr = arr
322-
g = extra_ops.UnravelIndex(order)(arr, shape)
323-
324-
cm = contextlib.suppress() if exc is None else pytest.raises(exc)
325+
g_c = extra_ops.UnravelIndex("C")(arr, shape)
326+
g_f = extra_ops.UnravelIndex("F")(arr, shape)
327+
if shape.type.shape == (1,):
328+
outputs = [g_c, g_f]
329+
else:
330+
outputs = [*g_c, *g_f]
331+
332+
cm = (
333+
pytest.warns(UserWarning, match="object mode")
334+
if requires_obj_mode
335+
else nullcontext()
336+
)
325337
with cm:
326338
compare_numba_and_py(
327339
[arr],
328-
g,
340+
outputs,
329341
[test_arr],
330342
)
331343

0 commit comments

Comments
 (0)