summaryrefslogtreecommitdiff
path: root/numpy/array_api/_set_functions.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2022-06-03 15:21:19 -0600
committerAaron Meurer <asmeurer@gmail.com>2022-06-03 15:21:19 -0600
commit12f83eb7337b840e3cd9026779b99b1af8033bf3 (patch)
treee35198721835ccab8b0cbe03751d0eb1f2f43798 /numpy/array_api/_set_functions.py
parent2cc66cf0965c48ce6f804d3b0dc7be3d87afe45e (diff)
downloadnumpy-12f83eb7337b840e3cd9026779b99b1af8033bf3.tar.gz
Fix the array API unique_*() functions to not compare nans as equal
The spec requires this, but it is only now possible to implement with the new equal_nan flag in np.unique().
Diffstat (limited to 'numpy/array_api/_set_functions.py')
-rw-r--r--numpy/array_api/_set_functions.py4
1 files changed, 4 insertions, 0 deletions
diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py
index db9370f84..0b4132cf8 100644
--- a/numpy/array_api/_set_functions.py
+++ b/numpy/array_api/_set_functions.py
@@ -46,6 +46,7 @@ def unique_all(x: Array, /) -> UniqueAllResult:
return_counts=True,
return_index=True,
return_inverse=True,
+ equal_nan=False,
)
# np.unique() flattens inverse indices, but they need to share x's shape
# See https://github.com/numpy/numpy/issues/20638
@@ -64,6 +65,7 @@ def unique_counts(x: Array, /) -> UniqueCountsResult:
return_counts=True,
return_index=False,
return_inverse=False,
+ equal_nan=False,
)
return UniqueCountsResult(*[Array._new(i) for i in res])
@@ -80,6 +82,7 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult:
return_counts=False,
return_index=False,
return_inverse=True,
+ equal_nan=False,
)
# np.unique() flattens inverse indices, but they need to share x's shape
# See https://github.com/numpy/numpy/issues/20638
@@ -98,5 +101,6 @@ def unique_values(x: Array, /) -> Array:
return_counts=False,
return_index=False,
return_inverse=False,
+ equal_nan=False,
)
return Array._new(res)