summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/iterators.c12
-rw-r--r--numpy/core/tests/test_indexing.py5
-rw-r--r--numpy/core/tests/test_numeric.py11
-rw-r--r--numpy/lib/stride_tricks.py3
4 files changed, 17 insertions, 14 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c
index 702f9e21a..5099e3e19 100644
--- a/numpy/core/src/multiarray/iterators.c
+++ b/numpy/core/src/multiarray/iterators.c
@@ -1456,9 +1456,9 @@ PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...)
int i, ntot, err=0;
ntot = n + nadd;
- if (ntot < 2 || ntot > NPY_MAXARGS) {
+ if (ntot < 1 || ntot > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
- "Need at least 2 and at most %d "
+ "Need at least 1 and at most %d "
"array objects.", NPY_MAXARGS);
return NULL;
}
@@ -1522,9 +1522,9 @@ PyArray_MultiIterNew(int n, ...)
int i, err = 0;
- if (n < 2 || n > NPY_MAXARGS) {
+ if (n < 1 || n > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
- "Need at least 2 and at most %d "
+ "Need at least 1 and at most %d "
"array objects.", NPY_MAXARGS);
return NULL;
}
@@ -1603,12 +1603,12 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k
++n;
}
}
- if (n < 2 || n > NPY_MAXARGS) {
+ if (n < 1 || n > NPY_MAXARGS) {
if (PyErr_Occurred()) {
return NULL;
}
PyErr_Format(PyExc_ValueError,
- "Need at least 2 and at most %d "
+ "Need at least 1 and at most %d "
"array objects.", NPY_MAXARGS);
return NULL;
}
diff --git a/numpy/core/tests/test_indexing.py b/numpy/core/tests/test_indexing.py
index 38280d05e..deb2130b7 100644
--- a/numpy/core/tests/test_indexing.py
+++ b/numpy/core/tests/test_indexing.py
@@ -895,10 +895,7 @@ class TestMultiIndexingAutomated(TestCase):
+ arr.shape[ax + len(indx[1:]):]))
# Check if broadcasting works
- if len(indx[1:]) != 1:
- res = np.broadcast(*indx[1:]) # raises ValueError...
- else:
- res = indx[1]
+ res = np.broadcast(*indx[1:])
# unfortunately the indices might be out of bounds. So check
# that first, and use mode='wrap' then. However only if
# there are any indices...
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index b7e146b5a..d63118080 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2207,11 +2207,20 @@ class TestBroadcast(TestCase):
for a, ia in zip(arrs, mit.iters):
assert_(a is ia.base)
+ def test_broadcast_single_arg(self):
+ # gh-6899
+ arrs = [np.empty((5, 6, 7))]
+ mit = np.broadcast(*arrs)
+ assert_equal(mit.shape, (5, 6, 7))
+ assert_equal(mit.nd, 3)
+ assert_equal(mit.numiter, 1)
+ assert_(arrs[0] is mit.iters[0].base)
+
def test_number_of_arguments(self):
arr = np.empty((5,))
for j in range(35):
arrs = [arr] * j
- if j < 2 or j > 32:
+ if j < 1 or j > 32:
assert_raises(ValueError, np.broadcast, *arrs)
else:
mit = np.broadcast(*arrs)
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py
index f4b43a5a9..4c23ab355 100644
--- a/numpy/lib/stride_tricks.py
+++ b/numpy/lib/stride_tricks.py
@@ -121,9 +121,6 @@ def _broadcast_shape(*args):
"""
if not args:
raise ValueError('must provide at least one argument')
- if len(args) == 1:
- # a single argument does not work with np.broadcast
- return np.asarray(args[0]).shape
# use the old-iterator because np.nditer does not handle size 0 arrays
# consistently
b = np.broadcast(*args[:32])