summaryrefslogtreecommitdiff
path: root/numpy/array_api/_set_functions.py
diff options
context:
space:
mode:
authorDeveloper-Ecosystem-Engineering <65677710+Developer-Ecosystem-Engineering@users.noreply.github.com>2021-11-18 14:32:20 -0800
committerDeveloper-Ecosystem-Engineering <65677710+Developer-Ecosystem-Engineering@users.noreply.github.com>2021-11-18 14:32:20 -0800
commit9fe353e048bc317099bb13cae4cb5a4de8c0c656 (patch)
tree4976e4328da9906d99cd2478547ca1e2829e071d /numpy/array_api/_set_functions.py
parent66ca468b97873c1c8bfe69bddac6df633780c71c (diff)
parent5e9ce0c0529e3085498ac892941a020a65c7369a (diff)
downloadnumpy-9fe353e048bc317099bb13cae4cb5a4de8c0c656.tar.gz
Merge branch 'as_min_max' of https://github.com/Developer-Ecosystem-Engineering/numpy into as_min_max
Diffstat (limited to 'numpy/array_api/_set_functions.py')
-rw-r--r--numpy/array_api/_set_functions.py89
1 files changed, 75 insertions, 14 deletions
diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py
index 357f238f5..05ee7e555 100644
--- a/numpy/array_api/_set_functions.py
+++ b/numpy/array_api/_set_functions.py
@@ -2,19 +2,82 @@ from __future__ import annotations
from ._array_object import Array
-from typing import Tuple, Union
+from typing import NamedTuple
import numpy as np
+# Note: np.unique() is split into four functions in the array API:
+# unique_all, unique_counts, unique_inverse, and unique_values (this is done
+# to remove polymorphic return types).
-def unique(
- x: Array,
- /,
- *,
- return_counts: bool = False,
- return_index: bool = False,
- return_inverse: bool = False,
-) -> Union[Array, Tuple[Array, ...]]:
+# Note: The various unique() functions are supposed to return multiple NaNs.
+# This does not match the NumPy behavior, however, this is currently left as a
+# TODO in this implementation as this behavior may be reverted in np.unique().
+# See https://github.com/numpy/numpy/issues/20326.
+
+# Note: The functions here return a namedtuple (np.unique() returns a normal
+# tuple).
+
+class UniqueAllResult(NamedTuple):
+ values: Array
+ indices: Array
+ inverse_indices: Array
+ counts: Array
+
+
+class UniqueCountsResult(NamedTuple):
+ values: Array
+ counts: Array
+
+
+class UniqueInverseResult(NamedTuple):
+ values: Array
+ inverse_indices: Array
+
+
+def unique_all(x: Array, /) -> UniqueAllResult:
+ """
+ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
+
+ See its docstring for more information.
+ """
+ res = np.unique(
+ x._array,
+ return_counts=True,
+ return_index=True,
+ return_inverse=True,
+ )
+
+ return UniqueAllResult(*[Array._new(i) for i in res])
+
+
+def unique_counts(x: Array, /) -> UniqueCountsResult:
+ res = np.unique(
+ x._array,
+ return_counts=True,
+ return_index=False,
+ return_inverse=False,
+ )
+
+ return UniqueCountsResult(*[Array._new(i) for i in res])
+
+
+def unique_inverse(x: Array, /) -> UniqueInverseResult:
+ """
+ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
+
+ See its docstring for more information.
+ """
+ res = np.unique(
+ x._array,
+ return_counts=False,
+ return_index=False,
+ return_inverse=True,
+ )
+ return UniqueInverseResult(*[Array._new(i) for i in res])
+
+
+def unique_values(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
@@ -22,10 +85,8 @@ def unique(
"""
res = np.unique(
x._array,
- return_counts=return_counts,
- return_index=return_index,
- return_inverse=return_inverse,
+ return_counts=False,
+ return_index=False,
+ return_inverse=False,
)
- if isinstance(res, tuple):
- return tuple(Array._new(i) for i in res)
return Array._new(res)