diff options
author | Simon Gibbons <simongibbons@gmail.com> | 2016-04-08 14:58:44 +0100 |
---|---|---|
committer | Simon Gibbons <simongibbons@gmail.com> | 2016-04-08 15:07:57 +0100 |
commit | c2ec8187b4d53d125d00b96b9891ff3c9da7e823 (patch) | |
tree | e1c8da373941edee201dde719537bff3516dcf96 | |
parent | 71575f1e0167a5fc07a761336439807544d8fc5a (diff) | |
download | numpy-c2ec8187b4d53d125d00b96b9891ff3c9da7e823.tar.gz |
BUG: Floating exception with invalid axis in np.lexsort
When an invalid axis was passed into PyArray_LexSort it
would attempt to create a set of iterators to ignore that
axis before checking to see if the axis was valid. This
would cause a floating exception as the dimension of
the invalid axis would on occasion return zero.
This fixes that by moving the axis to before the iterator
creation.
Fixes #7528
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 13 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 3 |
2 files changed, 11 insertions, 5 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 80dc8201f..dcd3322c4 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -1441,11 +1441,6 @@ PyArray_LexSort(PyObject *sort_keys, int axis) && PyDataType_FLAGCHK(PyArray_DESCR(mps[i]), NPY_NEEDS_PYAPI)) { object = 1; } - its[i] = (PyArrayIterObject *)PyArray_IterAllButAxis( - (PyObject *)mps[i], &axis); - if (its[i] == NULL) { - goto fail; - } } /* Now we can check the axis */ @@ -1472,6 +1467,14 @@ PyArray_LexSort(PyObject *sort_keys, int axis) goto fail; } + for (i = 0; i < n; i++) { + its[i] = (PyArrayIterObject *)PyArray_IterAllButAxis( + (PyObject *)mps[i], &axis); + if (its[i] == NULL) { + goto fail; + } + } + /* Now do the sorting */ ret = (PyArrayObject *)PyArray_New(&PyArray_Type, PyArray_NDIM(mps[0]), PyArray_DIMS(mps[0]), NPY_INTP, diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 8cd28f88f..4a2a232af 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -3614,6 +3614,9 @@ class TestLexsort(TestCase): u, v = np.array(u, dtype='object'), np.array(v, dtype='object') assert_array_equal(idx, np.lexsort((u, v))) + def test_invalid_axis(self): # gh-7528 + x = np.linspace(0., 1., 42*3).reshape(42, 3) + assert_raises(ValueError, np.lexsort, x, axis=2) class TestIO(object): """Test tofile, fromfile, tobytes, and fromstring""" |