diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 13 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 5 |
2 files changed, 15 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index f0ef8ba3b..45c019f49 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -1564,6 +1564,16 @@ PyArray_LexSort(PyObject *sort_keys, int axis) /* Now we can check the axis */ nd = PyArray_NDIM(mps[0]); + /* + * Special case letting axis={-1,0} slip through for scalars, + * for backwards compatibility reasons. + */ + if (nd == 0 && (axis == 0 || axis == -1)) { + /* TODO: can we deprecate this? */ + } + else if (check_and_adjust_axis(&axis, nd) < 0) { + goto fail; + } if ((nd == 0) || (PyArray_SIZE(mps[0]) <= 1)) { /* empty/single element case */ ret = (PyArrayObject *)PyArray_NewFromDescr( @@ -1579,9 +1589,6 @@ PyArray_LexSort(PyObject *sort_keys, int axis) } goto finish; } - if (check_and_adjust_axis(&axis, nd) < 0) { - goto fail; - } for (i = 0; i < n; i++) { its[i] = (PyArrayIterObject *)PyArray_IterAllButAxis( diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index fb969e5f8..96a6d810f 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -452,6 +452,11 @@ class TestRegression: xs.strides = (16, 16) assert np.lexsort((xs,), axis=0).shape[0] == 2 + def test_lexsort_invalid_axis(self): + assert_raises(np.AxisError, np.lexsort, (np.arange(1),), axis=2) + assert_raises(np.AxisError, np.lexsort, (np.array([]),), axis=1) + assert_raises(np.AxisError, np.lexsort, (np.array(1),), axis=10) + def test_lexsort_zerolen_element(self): dt = np.dtype([]) # a void dtype with no fields xs = np.empty(4, dt) |