summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2019-05-12 18:35:52 -0700
committerEric Wieser <wieser.eric@gmail.com>2019-05-12 19:15:36 -0700
commita3a19daff9330f0196aba90582450d022fc8798c (patch)
tree8e06f3ee5c05f8f77415d2e62df261f18fe89897
parent0f19dae081e6678902826b195e0d3857c5b4c2b3 (diff)
downloadnumpy-a3a19daff9330f0196aba90582450d022fc8798c.tar.gz
ENH: Allow broadcast to be called with zero arguments
Follows on from gh-6905 which reduced the limit from 2 to 1. Let's go all the way to zero. Just as for `broadcast(broadcast(a), b)` is interpreted as `broadcast(a, b)` , this change interprets `broadcast(broadcast(), a)` as `broadcast(a)`.
-rw-r--r--numpy/core/src/multiarray/iterators.c31
-rw-r--r--numpy/core/tests/test_numeric.py4
-rw-r--r--numpy/lib/stride_tricks.py2
3 files changed, 22 insertions, 15 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c
index 9fcdc91b2..75095f6f0 100644
--- a/numpy/core/src/multiarray/iterators.c
+++ b/numpy/core/src/multiarray/iterators.c
@@ -1262,10 +1262,14 @@ PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...)
int i, ntot, err=0;
ntot = n + nadd;
- if (ntot < 1 || ntot > NPY_MAXARGS) {
+ if (ntot < 0) {
PyErr_Format(PyExc_ValueError,
- "Need at least 1 and at most %d "
- "array objects.", NPY_MAXARGS);
+ "n and nadd arguments must be non-negative", NPY_MAXARGS);
+ return NULL;
+ }
+ if (ntot > NPY_MAXARGS) {
+ PyErr_Format(PyExc_ValueError,
+ "At most %d array objects are supported.", NPY_MAXARGS);
return NULL;
}
multi = PyArray_malloc(sizeof(PyArrayMultiIterObject));
@@ -1328,10 +1332,14 @@ PyArray_MultiIterNew(int n, ...)
int i, err = 0;
- if (n < 1 || n > NPY_MAXARGS) {
+ if (n < 0) {
+ PyErr_Format(PyExc_ValueError,
+ "n argument must be non-negative", NPY_MAXARGS);
+ return NULL;
+ }
+ if (n > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
- "Need at least 1 and at most %d "
- "array objects.", NPY_MAXARGS);
+ "At most %d array objects are supported.", NPY_MAXARGS);
return NULL;
}
@@ -1409,13 +1417,12 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k
++n;
}
}
- if (n < 1 || n > NPY_MAXARGS) {
- if (PyErr_Occurred()) {
- return NULL;
- }
+ if (PyErr_Occurred()) {
+ return NULL;
+ }
+ if (n > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
- "Need at least 1 and at most %d "
- "array objects.", NPY_MAXARGS);
+ "At most %d array objects are supported.", NPY_MAXARGS);
return NULL;
}
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 1c53f9372..c799aaf6c 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2736,6 +2736,8 @@ class TestBroadcast(object):
arrs = [np.empty((6, 7)), np.empty((5, 6, 1)), np.empty((7,)),
np.empty((5, 1, 7))]
mits = [np.broadcast(*arrs),
+ np.broadcast(np.broadcast(*arrs[:0]), np.broadcast(*arrs[0:])),
+ np.broadcast(np.broadcast(*arrs[:1]), np.broadcast(*arrs[1:])),
np.broadcast(np.broadcast(*arrs[:2]), np.broadcast(*arrs[2:])),
np.broadcast(arrs[0], np.broadcast(*arrs[1:-1]), arrs[-1])]
for mit in mits:
@@ -2760,7 +2762,7 @@ class TestBroadcast(object):
arr = np.empty((5,))
for j in range(35):
arrs = [arr] * j
- if j < 1 or j > 32:
+ if j > 32:
assert_raises(ValueError, np.broadcast, *arrs)
else:
mit = np.broadcast(*arrs)
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py
index 0dc36e41c..fd401c57c 100644
--- a/numpy/lib/stride_tricks.py
+++ b/numpy/lib/stride_tricks.py
@@ -186,8 +186,6 @@ def _broadcast_shape(*args):
"""Returns the shape of the arrays that would result from broadcasting the
supplied arrays against each other.
"""
- if not args:
- return ()
# use the old-iterator because np.nditer does not handle size 0 arrays
# consistently
b = np.broadcast(*args[:32])