diff options
author | Matthew Barber <quitesimplymatt@gmail.com> | 2022-01-12 16:20:33 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-01-12 09:20:33 -0700 |
commit | d7a43dfa91cc1363db64da8915db2b4b6c847b81 (patch) | |
tree | f0c9c0a151a6ffabb03af0c742546e30a2661625 /numpy | |
parent | e2d35064df262efa6eb7dfe5bfc43160c73cf685 (diff) | |
download | numpy-d7a43dfa91cc1363db64da8915db2b4b6c847b81.tar.gz |
BUG: `array_api.argsort(descending=True)` respects relative sort order (#20788)
* BUG: `array_api.argsort(descending=True)` respects relative order
* Regression test for stable descending `array_api.argsort()`
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/array_api/_sorting_functions.py | 17 | ||||
-rw-r--r-- | numpy/array_api/tests/test_sorting_functions.py | 23 |
2 files changed, 37 insertions, 3 deletions
diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py index 9cd49786c..b2a11872f 100644 --- a/numpy/array_api/_sorting_functions.py +++ b/numpy/array_api/_sorting_functions.py @@ -15,9 +15,20 @@ def argsort( """ # Note: this keyword argument is different, and the default is different. kind = "stable" if stable else "quicksort" - res = np.argsort(x._array, axis=axis, kind=kind) - if descending: - res = np.flip(res, axis=axis) + if not descending: + res = np.argsort(x._array, axis=axis, kind=kind) + else: + # As NumPy has no native descending sort, we imitate it here. Note that + # simply flipping the results of np.argsort(x._array, ...) would not + # respect the relative order like it would in native descending sorts. + res = np.flip( + np.argsort(np.flip(x._array, axis=axis), axis=axis, kind=kind), + axis=axis, + ) + # Rely on flip()/argsort() to validate axis + normalised_axis = axis if axis >= 0 else x.ndim + axis + max_i = x.shape[normalised_axis] - 1 + res = max_i - res return Array._new(res) diff --git a/numpy/array_api/tests/test_sorting_functions.py b/numpy/array_api/tests/test_sorting_functions.py new file mode 100644 index 000000000..9848bbfeb --- /dev/null +++ b/numpy/array_api/tests/test_sorting_functions.py @@ -0,0 +1,23 @@ +import pytest + +from numpy import array_api as xp + + +@pytest.mark.parametrize( + "obj, axis, expected", + [ + ([0, 0], -1, [0, 1]), + ([0, 1, 0], -1, [1, 0, 2]), + ([[0, 1], [1, 1]], 0, [[1, 0], [0, 1]]), + ([[0, 1], [1, 1]], 1, [[1, 0], [0, 1]]), + ], +) +def test_stable_desc_argsort(obj, axis, expected): + """ + Indices respect relative order of a descending stable-sort + + See https://github.com/numpy/numpy/issues/20778 + """ + x = xp.asarray(obj) + out = xp.argsort(x, axis=axis, stable=True, descending=True) + assert xp.all(out == xp.asarray(expected)) |