diff options
author | Pauli Virtanen <pav@iki.fi> | 2009-11-07 21:06:31 +0000 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2009-11-07 21:06:31 +0000 |
commit | faea0c8820fc773a100a4875afb42fe8f1944b3d (patch) | |
tree | dac91c2ccf23bb9ad27abb11d7998415ab9743b4 | |
parent | ca35f530266c73ed32b7d7e35720151c49426554 (diff) | |
download | numpy-faea0c8820fc773a100a4875afb42fe8f1944b3d.tar.gz |
Allow only axis=0 and axis=None for 0-d arrays, and disallow axis>MAX_DIMS (addresses #1286)
These changes should catch errors earlier by raising exceptions, instead
of resulting to unexpected behavior.
-rw-r--r-- | numpy/core/src/multiarray/ctors.c | 9 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 13 |
2 files changed, 20 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c index a07998a8c..80a9d7988 100644 --- a/numpy/core/src/multiarray/ctors.c +++ b/numpy/core/src/multiarray/ctors.c @@ -2424,6 +2424,9 @@ PyArray_CopyInto(PyArrayObject *dest, PyArrayObject *src) /*NUMPY_API PyArray_CheckAxis + + check that axis is valid + convert 0-d arrays to 1-d arrays */ NPY_NO_EXPORT PyObject * PyArray_CheckAxis(PyArrayObject *arr, int *axis, int flags) @@ -2431,14 +2434,16 @@ PyArray_CheckAxis(PyArrayObject *arr, int *axis, int flags) PyObject *temp1, *temp2; int n = arr->nd; - if ((*axis >= MAX_DIMS) || (n==0)) { + if (*axis == MAX_DIMS || n == 0) { if (n != 1) { temp1 = PyArray_Ravel(arr,0); if (temp1 == NULL) { *axis = 0; return NULL; } - *axis = PyArray_NDIM(temp1)-1; + if (*axis == MAX_DIMS) { + *axis = PyArray_NDIM(temp1)-1; + } } else { temp1 = (PyObject *)arr; diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 8c426bc5f..f40e5e9ca 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -644,6 +644,19 @@ class TestArgmax(TestCase): axes.remove(i) assert all(amax == aargmax.choose(*a.transpose(i,*axes))) +class TestMinMax(TestCase): + def test_scalar(self): + assert_raises(ValueError, np.amax, 1, 1) + assert_raises(ValueError, np.amin, 1, 1) + + assert_equal(np.amax(1, axis=0), 1) + assert_equal(np.amin(1, axis=0), 1) + assert_equal(np.amax(1, axis=None), 1) + assert_equal(np.amin(1, axis=None), 1) + + def test_axis(self): + assert_raises(ValueError, np.amax, [1,2,3], 1000) + assert_equal(np.amax([[1,2,3]], axis=1), 3) class TestNewaxis(TestCase): def test_basic(self): |