summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2009-11-07 21:06:31 +0000
committerPauli Virtanen <pav@iki.fi>2009-11-07 21:06:31 +0000
commitfaea0c8820fc773a100a4875afb42fe8f1944b3d (patch)
treedac91c2ccf23bb9ad27abb11d7998415ab9743b4
parentca35f530266c73ed32b7d7e35720151c49426554 (diff)
downloadnumpy-faea0c8820fc773a100a4875afb42fe8f1944b3d.tar.gz
Allow only axis=0 and axis=None for 0-d arrays, and disallow axis>MAX_DIMS (addresses #1286)
These changes should catch errors earlier by raising exceptions, instead of resulting to unexpected behavior.
-rw-r--r--numpy/core/src/multiarray/ctors.c9
-rw-r--r--numpy/core/tests/test_multiarray.py13
2 files changed, 20 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c
index a07998a8c..80a9d7988 100644
--- a/numpy/core/src/multiarray/ctors.c
+++ b/numpy/core/src/multiarray/ctors.c
@@ -2424,6 +2424,9 @@ PyArray_CopyInto(PyArrayObject *dest, PyArrayObject *src)
/*NUMPY_API
PyArray_CheckAxis
+
+ check that axis is valid
+ convert 0-d arrays to 1-d arrays
*/
NPY_NO_EXPORT PyObject *
PyArray_CheckAxis(PyArrayObject *arr, int *axis, int flags)
@@ -2431,14 +2434,16 @@ PyArray_CheckAxis(PyArrayObject *arr, int *axis, int flags)
PyObject *temp1, *temp2;
int n = arr->nd;
- if ((*axis >= MAX_DIMS) || (n==0)) {
+ if (*axis == MAX_DIMS || n == 0) {
if (n != 1) {
temp1 = PyArray_Ravel(arr,0);
if (temp1 == NULL) {
*axis = 0;
return NULL;
}
- *axis = PyArray_NDIM(temp1)-1;
+ if (*axis == MAX_DIMS) {
+ *axis = PyArray_NDIM(temp1)-1;
+ }
}
else {
temp1 = (PyObject *)arr;
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 8c426bc5f..f40e5e9ca 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -644,6 +644,19 @@ class TestArgmax(TestCase):
axes.remove(i)
assert all(amax == aargmax.choose(*a.transpose(i,*axes)))
+class TestMinMax(TestCase):
+ def test_scalar(self):
+ assert_raises(ValueError, np.amax, 1, 1)
+ assert_raises(ValueError, np.amin, 1, 1)
+
+ assert_equal(np.amax(1, axis=0), 1)
+ assert_equal(np.amin(1, axis=0), 1)
+ assert_equal(np.amax(1, axis=None), 1)
+ assert_equal(np.amin(1, axis=None), 1)
+
+ def test_axis(self):
+ assert_raises(ValueError, np.amax, [1,2,3], 1000)
+ assert_equal(np.amax([[1,2,3]], axis=1), 3)
class TestNewaxis(TestCase):
def test_basic(self):