diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2019-05-12 18:35:52 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2019-05-12 19:15:36 -0700 |
commit | a3a19daff9330f0196aba90582450d022fc8798c (patch) | |
tree | 8e06f3ee5c05f8f77415d2e62df261f18fe89897 | |
parent | 0f19dae081e6678902826b195e0d3857c5b4c2b3 (diff) | |
download | numpy-a3a19daff9330f0196aba90582450d022fc8798c.tar.gz |
ENH: Allow broadcast to be called with zero arguments
Follows on from gh-6905 which reduced the limit from 2 to 1. Let's go all the way to zero.
Just as for `broadcast(broadcast(a), b)` is interpreted as `broadcast(a, b)` , this change interprets
`broadcast(broadcast(), a)` as `broadcast(a)`.
-rw-r--r-- | numpy/core/src/multiarray/iterators.c | 31 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 4 | ||||
-rw-r--r-- | numpy/lib/stride_tricks.py | 2 |
3 files changed, 22 insertions, 15 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c index 9fcdc91b2..75095f6f0 100644 --- a/numpy/core/src/multiarray/iterators.c +++ b/numpy/core/src/multiarray/iterators.c @@ -1262,10 +1262,14 @@ PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...) int i, ntot, err=0; ntot = n + nadd; - if (ntot < 1 || ntot > NPY_MAXARGS) { + if (ntot < 0) { PyErr_Format(PyExc_ValueError, - "Need at least 1 and at most %d " - "array objects.", NPY_MAXARGS); + "n and nadd arguments must be non-negative", NPY_MAXARGS); + return NULL; + } + if (ntot > NPY_MAXARGS) { + PyErr_Format(PyExc_ValueError, + "At most %d array objects are supported.", NPY_MAXARGS); return NULL; } multi = PyArray_malloc(sizeof(PyArrayMultiIterObject)); @@ -1328,10 +1332,14 @@ PyArray_MultiIterNew(int n, ...) int i, err = 0; - if (n < 1 || n > NPY_MAXARGS) { + if (n < 0) { + PyErr_Format(PyExc_ValueError, + "n argument must be non-negative", NPY_MAXARGS); + return NULL; + } + if (n > NPY_MAXARGS) { PyErr_Format(PyExc_ValueError, - "Need at least 1 and at most %d " - "array objects.", NPY_MAXARGS); + "At most %d array objects are supported.", NPY_MAXARGS); return NULL; } @@ -1409,13 +1417,12 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k ++n; } } - if (n < 1 || n > NPY_MAXARGS) { - if (PyErr_Occurred()) { - return NULL; - } + if (PyErr_Occurred()) { + return NULL; + } + if (n > NPY_MAXARGS) { PyErr_Format(PyExc_ValueError, - "Need at least 1 and at most %d " - "array objects.", NPY_MAXARGS); + "At most %d array objects are supported.", NPY_MAXARGS); return NULL; } diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 1c53f9372..c799aaf6c 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -2736,6 +2736,8 @@ class TestBroadcast(object): arrs = [np.empty((6, 7)), np.empty((5, 6, 1)), np.empty((7,)), np.empty((5, 1, 7))] mits = [np.broadcast(*arrs), + np.broadcast(np.broadcast(*arrs[:0]), np.broadcast(*arrs[0:])), + np.broadcast(np.broadcast(*arrs[:1]), np.broadcast(*arrs[1:])), np.broadcast(np.broadcast(*arrs[:2]), np.broadcast(*arrs[2:])), np.broadcast(arrs[0], np.broadcast(*arrs[1:-1]), arrs[-1])] for mit in mits: @@ -2760,7 +2762,7 @@ class TestBroadcast(object): arr = np.empty((5,)) for j in range(35): arrs = [arr] * j - if j < 1 or j > 32: + if j > 32: assert_raises(ValueError, np.broadcast, *arrs) else: mit = np.broadcast(*arrs) diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py index 0dc36e41c..fd401c57c 100644 --- a/numpy/lib/stride_tricks.py +++ b/numpy/lib/stride_tricks.py @@ -186,8 +186,6 @@ def _broadcast_shape(*args): """Returns the shape of the arrays that would result from broadcasting the supplied arrays against each other. """ - if not args: - return () # use the old-iterator because np.nditer does not handle size 0 arrays # consistently b = np.broadcast(*args[:32]) |