summaryrefslogtreecommitdiff
path: root/numpy/array_api/_set_functions.py
diff options
context:
space:
mode:
authorRalf Gommers <ralf.gommers@gmail.com>2021-11-12 16:15:35 +0100
committerRalf Gommers <ralf.gommers@gmail.com>2021-11-12 16:15:35 +0100
commiteccb8dfbd9b07183e16a1144e8d5d76936671bfc (patch)
tree647a9477b4f3b8b7205f2f7f2feb99eaa482e806 /numpy/array_api/_set_functions.py
parentd0d75f39f28ac26d4cc1aa3a4cbea63a6a027929 (diff)
parentff2e2a1e7eea29d925063b13922e096d14331222 (diff)
downloadnumpy-eccb8dfbd9b07183e16a1144e8d5d76936671bfc.tar.gz
Merge branch 'main' into never_copy
Diffstat (limited to 'numpy/array_api/_set_functions.py')
-rw-r--r--numpy/array_api/_set_functions.py89
1 files changed, 75 insertions, 14 deletions
diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py
index 357f238f5..05ee7e555 100644
--- a/numpy/array_api/_set_functions.py
+++ b/numpy/array_api/_set_functions.py
@@ -2,19 +2,82 @@ from __future__ import annotations
from ._array_object import Array
-from typing import Tuple, Union
+from typing import NamedTuple
import numpy as np
+# Note: np.unique() is split into four functions in the array API:
+# unique_all, unique_counts, unique_inverse, and unique_values (this is done
+# to remove polymorphic return types).
-def unique(
- x: Array,
- /,
- *,
- return_counts: bool = False,
- return_index: bool = False,
- return_inverse: bool = False,
-) -> Union[Array, Tuple[Array, ...]]:
+# Note: The various unique() functions are supposed to return multiple NaNs.
+# This does not match the NumPy behavior, however, this is currently left as a
+# TODO in this implementation as this behavior may be reverted in np.unique().
+# See https://github.com/numpy/numpy/issues/20326.
+
+# Note: The functions here return a namedtuple (np.unique() returns a normal
+# tuple).
+
+class UniqueAllResult(NamedTuple):
+ values: Array
+ indices: Array
+ inverse_indices: Array
+ counts: Array
+
+
+class UniqueCountsResult(NamedTuple):
+ values: Array
+ counts: Array
+
+
+class UniqueInverseResult(NamedTuple):
+ values: Array
+ inverse_indices: Array
+
+
+def unique_all(x: Array, /) -> UniqueAllResult:
+ """
+ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
+
+ See its docstring for more information.
+ """
+ res = np.unique(
+ x._array,
+ return_counts=True,
+ return_index=True,
+ return_inverse=True,
+ )
+
+ return UniqueAllResult(*[Array._new(i) for i in res])
+
+
+def unique_counts(x: Array, /) -> UniqueCountsResult:
+ res = np.unique(
+ x._array,
+ return_counts=True,
+ return_index=False,
+ return_inverse=False,
+ )
+
+ return UniqueCountsResult(*[Array._new(i) for i in res])
+
+
+def unique_inverse(x: Array, /) -> UniqueInverseResult:
+ """
+ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
+
+ See its docstring for more information.
+ """
+ res = np.unique(
+ x._array,
+ return_counts=False,
+ return_index=False,
+ return_inverse=True,
+ )
+ return UniqueInverseResult(*[Array._new(i) for i in res])
+
+
+def unique_values(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
@@ -22,10 +85,8 @@ def unique(
"""
res = np.unique(
x._array,
- return_counts=return_counts,
- return_index=return_index,
- return_inverse=return_inverse,
+ return_counts=False,
+ return_index=False,
+ return_inverse=False,
)
- if isinstance(res, tuple):
- return tuple(Array._new(i) for i in res)
return Array._new(res)