summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaime <jaime.frio@gmail.com>2016-01-06 22:30:24 +0100
committerJaime <jaime.frio@gmail.com>2016-01-06 22:30:24 +0100
commit55b5972174a93012d6c198e25080c9b8d4c01f8f (patch)
tree8a77e2b31c92acd480599ccc4d7458a2ec74b843
parent318c243ef77cc8fb0c9c714e68a09bc3e05a47b9 (diff)
parent816cd4983b0c6cddf3c2e51331d822188ddc7aa0 (diff)
downloadnumpy-55b5972174a93012d6c198e25080c9b8d4c01f8f.tar.gz
Merge pull request #6905 from kohr-h/issue-6899__broadcast_with_one_arg
ENH: allow single input argument in numpy.broadcast
-rw-r--r--doc/release/1.11.0-notes.rst11
-rw-r--r--numpy/core/src/multiarray/iterators.c12
-rw-r--r--numpy/core/tests/test_indexing.py5
-rw-r--r--numpy/core/tests/test_numeric.py11
-rw-r--r--numpy/lib/stride_tricks.py3
5 files changed, 28 insertions, 14 deletions
diff --git a/doc/release/1.11.0-notes.rst b/doc/release/1.11.0-notes.rst
index c15936cc3..e66e680d3 100644
--- a/doc/release/1.11.0-notes.rst
+++ b/doc/release/1.11.0-notes.rst
@@ -133,6 +133,17 @@ diskspace on filesystems that support it.
Changes
=======
+*np.broadcast* can now be called with a single argument
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+The resulting object in that case will simply mimic iteration over
+a single array. This change obsoletes distinctions like
+
+ if len(x) == 1:
+ shape = x[0].shape
+ else:
+ shape = np.broadcast(*x).shape
+
+Instead, ``np.broadcast`` can be used in all cases.
Deprecations
============
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c
index 702f9e21a..5099e3e19 100644
--- a/numpy/core/src/multiarray/iterators.c
+++ b/numpy/core/src/multiarray/iterators.c
@@ -1456,9 +1456,9 @@ PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...)
int i, ntot, err=0;
ntot = n + nadd;
- if (ntot < 2 || ntot > NPY_MAXARGS) {
+ if (ntot < 1 || ntot > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
- "Need at least 2 and at most %d "
+ "Need at least 1 and at most %d "
"array objects.", NPY_MAXARGS);
return NULL;
}
@@ -1522,9 +1522,9 @@ PyArray_MultiIterNew(int n, ...)
int i, err = 0;
- if (n < 2 || n > NPY_MAXARGS) {
+ if (n < 1 || n > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
- "Need at least 2 and at most %d "
+ "Need at least 1 and at most %d "
"array objects.", NPY_MAXARGS);
return NULL;
}
@@ -1603,12 +1603,12 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k
++n;
}
}
- if (n < 2 || n > NPY_MAXARGS) {
+ if (n < 1 || n > NPY_MAXARGS) {
if (PyErr_Occurred()) {
return NULL;
}
PyErr_Format(PyExc_ValueError,
- "Need at least 2 and at most %d "
+ "Need at least 1 and at most %d "
"array objects.", NPY_MAXARGS);
return NULL;
}
diff --git a/numpy/core/tests/test_indexing.py b/numpy/core/tests/test_indexing.py
index 38280d05e..deb2130b7 100644
--- a/numpy/core/tests/test_indexing.py
+++ b/numpy/core/tests/test_indexing.py
@@ -895,10 +895,7 @@ class TestMultiIndexingAutomated(TestCase):
+ arr.shape[ax + len(indx[1:]):]))
# Check if broadcasting works
- if len(indx[1:]) != 1:
- res = np.broadcast(*indx[1:]) # raises ValueError...
- else:
- res = indx[1]
+ res = np.broadcast(*indx[1:])
# unfortunately the indices might be out of bounds. So check
# that first, and use mode='wrap' then. However only if
# there are any indices...
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index b7e146b5a..d63118080 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2207,11 +2207,20 @@ class TestBroadcast(TestCase):
for a, ia in zip(arrs, mit.iters):
assert_(a is ia.base)
+ def test_broadcast_single_arg(self):
+ # gh-6899
+ arrs = [np.empty((5, 6, 7))]
+ mit = np.broadcast(*arrs)
+ assert_equal(mit.shape, (5, 6, 7))
+ assert_equal(mit.nd, 3)
+ assert_equal(mit.numiter, 1)
+ assert_(arrs[0] is mit.iters[0].base)
+
def test_number_of_arguments(self):
arr = np.empty((5,))
for j in range(35):
arrs = [arr] * j
- if j < 2 or j > 32:
+ if j < 1 or j > 32:
assert_raises(ValueError, np.broadcast, *arrs)
else:
mit = np.broadcast(*arrs)
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py
index f4b43a5a9..4c23ab355 100644
--- a/numpy/lib/stride_tricks.py
+++ b/numpy/lib/stride_tricks.py
@@ -121,9 +121,6 @@ def _broadcast_shape(*args):
"""
if not args:
raise ValueError('must provide at least one argument')
- if len(args) == 1:
- # a single argument does not work with np.broadcast
- return np.asarray(args[0]).shape
# use the old-iterator because np.nditer does not handle size 0 arrays
# consistently
b = np.broadcast(*args[:32])