diff options
author | Developer-Ecosystem-Engineering <65677710+Developer-Ecosystem-Engineering@users.noreply.github.com> | 2021-11-18 14:32:20 -0800 |
---|---|---|
committer | Developer-Ecosystem-Engineering <65677710+Developer-Ecosystem-Engineering@users.noreply.github.com> | 2021-11-18 14:32:20 -0800 |
commit | 9fe353e048bc317099bb13cae4cb5a4de8c0c656 (patch) | |
tree | 4976e4328da9906d99cd2478547ca1e2829e071d /numpy/array_api/_set_functions.py | |
parent | 66ca468b97873c1c8bfe69bddac6df633780c71c (diff) | |
parent | 5e9ce0c0529e3085498ac892941a020a65c7369a (diff) | |
download | numpy-9fe353e048bc317099bb13cae4cb5a4de8c0c656.tar.gz |
Merge branch 'as_min_max' of https://github.com/Developer-Ecosystem-Engineering/numpy into as_min_max
Diffstat (limited to 'numpy/array_api/_set_functions.py')
-rw-r--r-- | numpy/array_api/_set_functions.py | 89 |
1 files changed, 75 insertions, 14 deletions
diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py index 357f238f5..05ee7e555 100644 --- a/numpy/array_api/_set_functions.py +++ b/numpy/array_api/_set_functions.py @@ -2,19 +2,82 @@ from __future__ import annotations from ._array_object import Array -from typing import Tuple, Union +from typing import NamedTuple import numpy as np +# Note: np.unique() is split into four functions in the array API: +# unique_all, unique_counts, unique_inverse, and unique_values (this is done +# to remove polymorphic return types). -def unique( - x: Array, - /, - *, - return_counts: bool = False, - return_index: bool = False, - return_inverse: bool = False, -) -> Union[Array, Tuple[Array, ...]]: +# Note: The various unique() functions are supposed to return multiple NaNs. +# This does not match the NumPy behavior, however, this is currently left as a +# TODO in this implementation as this behavior may be reverted in np.unique(). +# See https://github.com/numpy/numpy/issues/20326. + +# Note: The functions here return a namedtuple (np.unique() returns a normal +# tuple). + +class UniqueAllResult(NamedTuple): + values: Array + indices: Array + inverse_indices: Array + counts: Array + + +class UniqueCountsResult(NamedTuple): + values: Array + counts: Array + + +class UniqueInverseResult(NamedTuple): + values: Array + inverse_indices: Array + + +def unique_all(x: Array, /) -> UniqueAllResult: + """ + Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`. + + See its docstring for more information. + """ + res = np.unique( + x._array, + return_counts=True, + return_index=True, + return_inverse=True, + ) + + return UniqueAllResult(*[Array._new(i) for i in res]) + + +def unique_counts(x: Array, /) -> UniqueCountsResult: + res = np.unique( + x._array, + return_counts=True, + return_index=False, + return_inverse=False, + ) + + return UniqueCountsResult(*[Array._new(i) for i in res]) + + +def unique_inverse(x: Array, /) -> UniqueInverseResult: + """ + Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`. + + See its docstring for more information. + """ + res = np.unique( + x._array, + return_counts=False, + return_index=False, + return_inverse=True, + ) + return UniqueInverseResult(*[Array._new(i) for i in res]) + + +def unique_values(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`. @@ -22,10 +85,8 @@ def unique( """ res = np.unique( x._array, - return_counts=return_counts, - return_index=return_index, - return_inverse=return_inverse, + return_counts=False, + return_index=False, + return_inverse=False, ) - if isinstance(res, tuple): - return tuple(Array._new(i) for i in res) return Array._new(res) |