summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/item_selection.c17
-rw-r--r--numpy/core/tests/test_multiarray.py5
2 files changed, 9 insertions, 13 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 486eb43ce..208b96687 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -1818,26 +1818,17 @@ PyArray_Diagonal(PyArrayObject *self, int offset, int axis1, int axis2)
}
/* Handle negative axes with standard Python indexing rules */
- if (axis1 < 0) {
- axis1 += ndim;
+ if (check_and_adjust_axis_cmsg(&axis1, ndim, "axis1") < 0) {
+ return NULL;
}
- if (axis2 < 0) {
- axis2 += ndim;
+ if (check_and_adjust_axis_cmsg(&axis2, ndim, "axis2") < 0) {
+ return NULL;
}
-
- /* Error check the two axes */
if (axis1 == axis2) {
PyErr_SetString(PyExc_ValueError,
"axis1 and axis2 cannot be the same");
return NULL;
}
- else if (axis1 < 0 || axis1 >= ndim || axis2 < 0 || axis2 >= ndim) {
- PyErr_Format(PyExc_ValueError,
- "axis1(=%d) and axis2(=%d) "
- "must be within range (ndim=%d)",
- axis1, axis2, ndim);
- return NULL;
- }
/* Get the shape and strides of the two axes */
shape = PyArray_SHAPE(self);
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 3ab1b971e..0fbaacb03 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2635,6 +2635,10 @@ class TestMethods(object):
assert_equal(a.diagonal(0), [0, 5, 10])
assert_equal(a.diagonal(1), [1, 6, 11])
assert_equal(a.diagonal(-1), [4, 9])
+ assert_raises(np.AxisError, a.diagonal, axis1=0, axis2=5)
+ assert_raises(np.AxisError, a.diagonal, axis1=5, axis2=0)
+ assert_raises(np.AxisError, a.diagonal, axis1=5, axis2=5)
+ assert_raises(ValueError, a.diagonal, axis1=1, axis2=1)
b = np.arange(8).reshape((2, 2, 2))
assert_equal(b.diagonal(), [[0, 6], [1, 7]])
@@ -2648,6 +2652,7 @@ class TestMethods(object):
# Order of axis argument doesn't matter:
assert_equal(b.diagonal(0, 2, 1), [[0, 3], [4, 7]])
+
def test_diagonal_view_notwriteable(self):
# this test is only for 1.9, the diagonal view will be
# writeable in 1.10.