summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMatthew Barber <quitesimplymatt@gmail.com>2022-01-12 16:20:33 +0000
committerGitHub <noreply@github.com>2022-01-12 09:20:33 -0700
commitd7a43dfa91cc1363db64da8915db2b4b6c847b81 (patch)
treef0c9c0a151a6ffabb03af0c742546e30a2661625 /numpy
parente2d35064df262efa6eb7dfe5bfc43160c73cf685 (diff)
downloadnumpy-d7a43dfa91cc1363db64da8915db2b4b6c847b81.tar.gz
BUG: `array_api.argsort(descending=True)` respects relative sort order (#20788)
* BUG: `array_api.argsort(descending=True)` respects relative order * Regression test for stable descending `array_api.argsort()`
Diffstat (limited to 'numpy')
-rw-r--r--numpy/array_api/_sorting_functions.py17
-rw-r--r--numpy/array_api/tests/test_sorting_functions.py23
2 files changed, 37 insertions, 3 deletions
diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py
index 9cd49786c..b2a11872f 100644
--- a/numpy/array_api/_sorting_functions.py
+++ b/numpy/array_api/_sorting_functions.py
@@ -15,9 +15,20 @@ def argsort(
"""
# Note: this keyword argument is different, and the default is different.
kind = "stable" if stable else "quicksort"
- res = np.argsort(x._array, axis=axis, kind=kind)
- if descending:
- res = np.flip(res, axis=axis)
+ if not descending:
+ res = np.argsort(x._array, axis=axis, kind=kind)
+ else:
+ # As NumPy has no native descending sort, we imitate it here. Note that
+ # simply flipping the results of np.argsort(x._array, ...) would not
+ # respect the relative order like it would in native descending sorts.
+ res = np.flip(
+ np.argsort(np.flip(x._array, axis=axis), axis=axis, kind=kind),
+ axis=axis,
+ )
+ # Rely on flip()/argsort() to validate axis
+ normalised_axis = axis if axis >= 0 else x.ndim + axis
+ max_i = x.shape[normalised_axis] - 1
+ res = max_i - res
return Array._new(res)
diff --git a/numpy/array_api/tests/test_sorting_functions.py b/numpy/array_api/tests/test_sorting_functions.py
new file mode 100644
index 000000000..9848bbfeb
--- /dev/null
+++ b/numpy/array_api/tests/test_sorting_functions.py
@@ -0,0 +1,23 @@
+import pytest
+
+from numpy import array_api as xp
+
+
+@pytest.mark.parametrize(
+ "obj, axis, expected",
+ [
+ ([0, 0], -1, [0, 1]),
+ ([0, 1, 0], -1, [1, 0, 2]),
+ ([[0, 1], [1, 1]], 0, [[1, 0], [0, 1]]),
+ ([[0, 1], [1, 1]], 1, [[1, 0], [0, 1]]),
+ ],
+)
+def test_stable_desc_argsort(obj, axis, expected):
+ """
+ Indices respect relative order of a descending stable-sort
+
+ See https://github.com/numpy/numpy/issues/20778
+ """
+ x = xp.asarray(obj)
+ out = xp.argsort(x, axis=axis, stable=True, descending=True)
+ assert xp.all(out == xp.asarray(expected))