diff options
author | Matti Picus <matti.picus@gmail.com> | 2023-02-27 15:05:01 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-27 15:05:01 +0200 |
commit | 68ddb26260946f6f73c48a781ca0f8ce4191898e (patch) | |
tree | bce8b7d24634bfdbd71bb49bdac8afd8cd3971d6 | |
parent | e49b744d2c4e78cf1c3170e5d53de07da0cfb71c (diff) | |
parent | 786bd366b22675b1e4067653d4729031c258f35f (diff) | |
download | numpy-68ddb26260946f6f73c48a781ca0f8ce4191898e.tar.gz |
Merge pull request #23284 from arogozhnikov/add-xp-take
ENH: add support for xp.take
-rw-r--r-- | numpy/array_api/__init__.py | 4 | ||||
-rw-r--r-- | numpy/array_api/_indexing_functions.py | 18 | ||||
-rw-r--r-- | numpy/array_api/tests/test_indexing_functions.py | 24 |
3 files changed, 46 insertions, 0 deletions
diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 5e58ee0a8..e154b9952 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -333,6 +333,10 @@ __all__ += [ "trunc", ] +from ._indexing_functions import take + +__all__ += ["take"] + # linalg is an extension in the array API spec, which is a sub-namespace. Only # a subset of functions in it are imported into the top-level namespace. from . import linalg diff --git a/numpy/array_api/_indexing_functions.py b/numpy/array_api/_indexing_functions.py new file mode 100644 index 000000000..ba56bcd6f --- /dev/null +++ b/numpy/array_api/_indexing_functions.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from ._array_object import Array +from ._dtypes import _integer_dtypes + +import numpy as np + +def take(x: Array, indices: Array, /, *, axis: int) -> Array: + """ + Array API compatible wrapper for :py:func:`np.take <numpy.take>`. + + See its docstring for more information. + """ + if indices.dtype not in _integer_dtypes: + raise TypeError("Only integer dtypes are allowed in indexing") + if indices.ndim != 1: + raise ValueError("Only 1-dim indices array is supported") + return Array._new(np.take(x._array, indices._array, axis=axis)) diff --git a/numpy/array_api/tests/test_indexing_functions.py b/numpy/array_api/tests/test_indexing_functions.py new file mode 100644 index 000000000..9e05c6386 --- /dev/null +++ b/numpy/array_api/tests/test_indexing_functions.py @@ -0,0 +1,24 @@ +import pytest + +from numpy import array_api as xp + + +@pytest.mark.parametrize( + "x, indices, axis, expected", + [ + ([2, 3], [1, 1, 0], 0, [3, 3, 2]), + ([2, 3], [1, 1, 0], -1, [3, 3, 2]), + ([[2, 3]], [1], -1, [[3]]), + ([[2, 3]], [0, 0], 0, [[2, 3], [2, 3]]), + ], +) +def test_take_function(x, indices, axis, expected): + """ + Indices respect relative order of a descending stable-sort + + See https://github.com/numpy/numpy/issues/20778 + """ + x = xp.asarray(x) + indices = xp.asarray(indices) + out = xp.take(x, indices, axis=axis) + assert xp.all(out == xp.asarray(expected)) |