summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/iterators.c31
-rw-r--r--numpy/core/tests/test_numeric.py4
2 files changed, 22 insertions, 13 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c
index 9fcdc91b2..75095f6f0 100644
--- a/numpy/core/src/multiarray/iterators.c
+++ b/numpy/core/src/multiarray/iterators.c
@@ -1262,10 +1262,14 @@ PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...)
int i, ntot, err=0;
ntot = n + nadd;
- if (ntot < 1 || ntot > NPY_MAXARGS) {
+ if (ntot < 0) {
PyErr_Format(PyExc_ValueError,
- "Need at least 1 and at most %d "
- "array objects.", NPY_MAXARGS);
+ "n and nadd arguments must be non-negative", NPY_MAXARGS);
+ return NULL;
+ }
+ if (ntot > NPY_MAXARGS) {
+ PyErr_Format(PyExc_ValueError,
+ "At most %d array objects are supported.", NPY_MAXARGS);
return NULL;
}
multi = PyArray_malloc(sizeof(PyArrayMultiIterObject));
@@ -1328,10 +1332,14 @@ PyArray_MultiIterNew(int n, ...)
int i, err = 0;
- if (n < 1 || n > NPY_MAXARGS) {
+ if (n < 0) {
+ PyErr_Format(PyExc_ValueError,
+ "n argument must be non-negative", NPY_MAXARGS);
+ return NULL;
+ }
+ if (n > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
- "Need at least 1 and at most %d "
- "array objects.", NPY_MAXARGS);
+ "At most %d array objects are supported.", NPY_MAXARGS);
return NULL;
}
@@ -1409,13 +1417,12 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k
++n;
}
}
- if (n < 1 || n > NPY_MAXARGS) {
- if (PyErr_Occurred()) {
- return NULL;
- }
+ if (PyErr_Occurred()) {
+ return NULL;
+ }
+ if (n > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
- "Need at least 1 and at most %d "
- "array objects.", NPY_MAXARGS);
+ "At most %d array objects are supported.", NPY_MAXARGS);
return NULL;
}
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 1c53f9372..c799aaf6c 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2736,6 +2736,8 @@ class TestBroadcast(object):
arrs = [np.empty((6, 7)), np.empty((5, 6, 1)), np.empty((7,)),
np.empty((5, 1, 7))]
mits = [np.broadcast(*arrs),
+ np.broadcast(np.broadcast(*arrs[:0]), np.broadcast(*arrs[0:])),
+ np.broadcast(np.broadcast(*arrs[:1]), np.broadcast(*arrs[1:])),
np.broadcast(np.broadcast(*arrs[:2]), np.broadcast(*arrs[2:])),
np.broadcast(arrs[0], np.broadcast(*arrs[1:-1]), arrs[-1])]
for mit in mits:
@@ -2760,7 +2762,7 @@ class TestBroadcast(object):
arr = np.empty((5,))
for j in range(35):
arrs = [arr] * j
- if j < 1 or j > 32:
+ if j > 32:
assert_raises(ValueError, np.broadcast, *arrs)
else:
mit = np.broadcast(*arrs)