diff options
| -rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 17 | ||||
| -rw-r--r-- | numpy/core/tests/test_multiarray.py | 5 |
2 files changed, 9 insertions, 13 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 486eb43ce..208b96687 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -1818,26 +1818,17 @@ PyArray_Diagonal(PyArrayObject *self, int offset, int axis1, int axis2) } /* Handle negative axes with standard Python indexing rules */ - if (axis1 < 0) { - axis1 += ndim; + if (check_and_adjust_axis_cmsg(&axis1, ndim, "axis1") < 0) { + return NULL; } - if (axis2 < 0) { - axis2 += ndim; + if (check_and_adjust_axis_cmsg(&axis2, ndim, "axis2") < 0) { + return NULL; } - - /* Error check the two axes */ if (axis1 == axis2) { PyErr_SetString(PyExc_ValueError, "axis1 and axis2 cannot be the same"); return NULL; } - else if (axis1 < 0 || axis1 >= ndim || axis2 < 0 || axis2 >= ndim) { - PyErr_Format(PyExc_ValueError, - "axis1(=%d) and axis2(=%d) " - "must be within range (ndim=%d)", - axis1, axis2, ndim); - return NULL; - } /* Get the shape and strides of the two axes */ shape = PyArray_SHAPE(self); diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 3ab1b971e..0fbaacb03 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2635,6 +2635,10 @@ class TestMethods(object): assert_equal(a.diagonal(0), [0, 5, 10]) assert_equal(a.diagonal(1), [1, 6, 11]) assert_equal(a.diagonal(-1), [4, 9]) + assert_raises(np.AxisError, a.diagonal, axis1=0, axis2=5) + assert_raises(np.AxisError, a.diagonal, axis1=5, axis2=0) + assert_raises(np.AxisError, a.diagonal, axis1=5, axis2=5) + assert_raises(ValueError, a.diagonal, axis1=1, axis2=1) b = np.arange(8).reshape((2, 2, 2)) assert_equal(b.diagonal(), [[0, 6], [1, 7]]) @@ -2648,6 +2652,7 @@ class TestMethods(object): # Order of axis argument doesn't matter: assert_equal(b.diagonal(0, 2, 1), [[0, 3], [4, 7]]) + def test_diagonal_view_notwriteable(self): # this test is only for 1.9, the diagonal view will be # writeable in 1.10. |
