diff options
Diffstat (limited to 'numpy/array_api/tests')
-rw-r--r-- | numpy/array_api/tests/test_array_object.py | 2 | ||||
-rw-r--r-- | numpy/array_api/tests/test_elementwise_functions.py | 2 | ||||
-rw-r--r-- | numpy/array_api/tests/test_sorting_functions.py | 23 |
3 files changed, 25 insertions, 2 deletions
diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index b980bacca..1fe1dfddf 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -98,7 +98,7 @@ def test_operators(): "__mul__": "numeric", "__ne__": "all", "__or__": "integer_or_boolean", - "__pow__": "floating", + "__pow__": "numeric", "__rshift__": "integer", "__sub__": "numeric", "__truediv__": "floating", diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py index a9274aec9..b2fb44e76 100644 --- a/numpy/array_api/tests/test_elementwise_functions.py +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -66,7 +66,7 @@ def test_function_types(): "negative": "numeric", "not_equal": "all", "positive": "numeric", - "pow": "floating-point", + "pow": "numeric", "remainder": "numeric", "round": "numeric", "sign": "numeric", diff --git a/numpy/array_api/tests/test_sorting_functions.py b/numpy/array_api/tests/test_sorting_functions.py new file mode 100644 index 000000000..9848bbfeb --- /dev/null +++ b/numpy/array_api/tests/test_sorting_functions.py @@ -0,0 +1,23 @@ +import pytest + +from numpy import array_api as xp + + +@pytest.mark.parametrize( + "obj, axis, expected", + [ + ([0, 0], -1, [0, 1]), + ([0, 1, 0], -1, [1, 0, 2]), + ([[0, 1], [1, 1]], 0, [[1, 0], [0, 1]]), + ([[0, 1], [1, 1]], 1, [[1, 0], [0, 1]]), + ], +) +def test_stable_desc_argsort(obj, axis, expected): + """ + Indices respect relative order of a descending stable-sort + + See https://github.com/numpy/numpy/issues/20778 + """ + x = xp.asarray(obj) + out = xp.argsort(x, axis=axis, stable=True, descending=True) + assert xp.all(out == xp.asarray(expected)) |