diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/iterators.c | 31 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 4 |
2 files changed, 22 insertions, 13 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c index 9fcdc91b2..75095f6f0 100644 --- a/numpy/core/src/multiarray/iterators.c +++ b/numpy/core/src/multiarray/iterators.c @@ -1262,10 +1262,14 @@ PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...) int i, ntot, err=0; ntot = n + nadd; - if (ntot < 1 || ntot > NPY_MAXARGS) { + if (ntot < 0) { PyErr_Format(PyExc_ValueError, - "Need at least 1 and at most %d " - "array objects.", NPY_MAXARGS); + "n and nadd arguments must be non-negative", NPY_MAXARGS); + return NULL; + } + if (ntot > NPY_MAXARGS) { + PyErr_Format(PyExc_ValueError, + "At most %d array objects are supported.", NPY_MAXARGS); return NULL; } multi = PyArray_malloc(sizeof(PyArrayMultiIterObject)); @@ -1328,10 +1332,14 @@ PyArray_MultiIterNew(int n, ...) int i, err = 0; - if (n < 1 || n > NPY_MAXARGS) { + if (n < 0) { + PyErr_Format(PyExc_ValueError, + "n argument must be non-negative", NPY_MAXARGS); + return NULL; + } + if (n > NPY_MAXARGS) { PyErr_Format(PyExc_ValueError, - "Need at least 1 and at most %d " - "array objects.", NPY_MAXARGS); + "At most %d array objects are supported.", NPY_MAXARGS); return NULL; } @@ -1409,13 +1417,12 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k ++n; } } - if (n < 1 || n > NPY_MAXARGS) { - if (PyErr_Occurred()) { - return NULL; - } + if (PyErr_Occurred()) { + return NULL; + } + if (n > NPY_MAXARGS) { PyErr_Format(PyExc_ValueError, - "Need at least 1 and at most %d " - "array objects.", NPY_MAXARGS); + "At most %d array objects are supported.", NPY_MAXARGS); return NULL; } diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 1c53f9372..c799aaf6c 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -2736,6 +2736,8 @@ class TestBroadcast(object): arrs = [np.empty((6, 7)), np.empty((5, 6, 1)), np.empty((7,)), np.empty((5, 1, 7))] mits = [np.broadcast(*arrs), + np.broadcast(np.broadcast(*arrs[:0]), np.broadcast(*arrs[0:])), + np.broadcast(np.broadcast(*arrs[:1]), np.broadcast(*arrs[1:])), np.broadcast(np.broadcast(*arrs[:2]), np.broadcast(*arrs[2:])), np.broadcast(arrs[0], np.broadcast(*arrs[1:-1]), arrs[-1])] for mit in mits: @@ -2760,7 +2762,7 @@ class TestBroadcast(object): arr = np.empty((5,)) for j in range(35): arrs = [arr] * j - if j < 1 or j > 32: + if j > 32: assert_raises(ValueError, np.broadcast, *arrs) else: mit = np.broadcast(*arrs) |