diff options
author | scoder <stefan_ml@behnel.de> | 2023-05-04 09:29:53 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-04 09:29:53 +0200 |
commit | 442c8f48d3146ec32c7d5387310e171276cf10ac (patch) | |
tree | d8911d1a64e384b7955d3fc09a07edd218a9f1ee /numpy/array_api/_indexing_functions.py | |
parent | 3e4a6cba2da27bbe2a6e12c163238e503c9f6a07 (diff) | |
parent | 9163e933df91b516b6f0c7a9ba8ad1750e642f37 (diff) | |
download | numpy-442c8f48d3146ec32c7d5387310e171276cf10ac.tar.gz |
Merge branch 'main' into cython3_noexcept
Diffstat (limited to 'numpy/array_api/_indexing_functions.py')
-rw-r--r-- | numpy/array_api/_indexing_functions.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/numpy/array_api/_indexing_functions.py b/numpy/array_api/_indexing_functions.py new file mode 100644 index 000000000..ba56bcd6f --- /dev/null +++ b/numpy/array_api/_indexing_functions.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from ._array_object import Array +from ._dtypes import _integer_dtypes + +import numpy as np + +def take(x: Array, indices: Array, /, *, axis: int) -> Array: + """ + Array API compatible wrapper for :py:func:`np.take <numpy.take>`. + + See its docstring for more information. + """ + if indices.dtype not in _integer_dtypes: + raise TypeError("Only integer dtypes are allowed in indexing") + if indices.ndim != 1: + raise ValueError("Only 1-dim indices array is supported") + return Array._new(np.take(x._array, indices._array, axis=axis)) |