summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/iterators.c61
-rw-r--r--numpy/core/tests/test_numeric.py28
2 files changed, 75 insertions, 14 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c
index e56237573..829994b1e 100644
--- a/numpy/core/src/multiarray/iterators.c
+++ b/numpy/core/src/multiarray/iterators.c
@@ -1577,7 +1577,8 @@ static PyObject *
arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *kwds)
{
- Py_ssize_t n, i;
+ Py_ssize_t n = 0;
+ Py_ssize_t i, j, k;
PyArrayMultiIterObject *multi;
PyObject *arr;
@@ -1587,13 +1588,27 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k
return NULL;
}
- n = PyTuple_Size(args);
+ for (j = 0; j < PyTuple_Size(args); ++j) {
+ PyObject *obj = PyTuple_GET_ITEM(args, j);
+
+ if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) {
+ /*
+ * If obj is a multi-iterator, all its arrays will be added
+ * to the new multi-iterator.
+ */
+ n += ((PyArrayMultiIterObject *)obj)->numiter;
+ }
+ else {
+ /* If not, will try to convert it to a single array */
+ ++n;
+ }
+ }
if (n < 2 || n > NPY_MAXARGS) {
if (PyErr_Occurred()) {
return NULL;
}
PyErr_Format(PyExc_ValueError,
- "Need at least two and fewer than (%d) " \
+ "Need at least two and fewer than (%d) "
"array objects.", NPY_MAXARGS);
return NULL;
}
@@ -1606,20 +1621,38 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k
multi->numiter = n;
multi->index = 0;
- for (i = 0; i < n; i++) {
- multi->iters[i] = NULL;
- }
- for (i = 0; i < n; i++) {
- arr = PyArray_FromAny(PyTuple_GET_ITEM(args, i), NULL, 0, 0, 0, NULL);
- if (arr == NULL) {
- goto fail;
+ i = 0;
+ for (j = 0; j < PyTuple_GET_SIZE(args); ++j) {
+ PyObject *obj = PyTuple_GET_ITEM(args, j);
+ PyArrayIterObject *it;
+
+ if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) {
+ PyArrayMultiIterObject *mit = (PyArrayMultiIterObject *)obj;
+
+ for (k = 0; k < mit->numiter; ++k) {
+ arr = (PyObject *)mit->iters[k]->ao;
+ assert (arr != NULL);
+ it = (PyArrayIterObject *)PyArray_IterNew(arr);
+ if (it == NULL) {
+ goto fail;
+ }
+ multi->iters[i++] = it;
+ }
}
- if ((multi->iters[i] = (PyArrayIterObject *)PyArray_IterNew(arr))
- == NULL) {
- goto fail;
+ else {
+ arr = PyArray_FromAny(obj, NULL, 0, 0, 0, NULL);
+ if (arr == NULL) {
+ goto fail;
+ }
+ it = (PyArrayIterObject *)PyArray_IterNew(arr);
+ if (it == NULL) {
+ goto fail;
+ }
+ multi->iters[i++] = it;
+ Py_DECREF(arr);
}
- Py_DECREF(arr);
}
+ assert (i == n);
if (PyArray_Broadcast(multi) < 0) {
goto fail;
}
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index ee304a7af..7400366ac 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2226,6 +2226,7 @@ class TestCross(TestCase):
for axisc in range(-2, 2):
assert_equal(np.cross(u, u, axisc=axisc).shape, (3, 4))
+
def test_outer_out_param():
arr1 = np.ones((5,))
arr2 = np.ones((2,))
@@ -2236,6 +2237,7 @@ def test_outer_out_param():
assert_equal(res1, out1)
assert_equal(np.outer(arr2, arr3, out2), out2)
+
class TestRequire(object):
flag_names = ['C', 'C_CONTIGUOUS', 'CONTIGUOUS',
'F', 'F_CONTIGUOUS', 'FORTRAN',
@@ -2310,5 +2312,31 @@ class TestRequire(object):
yield self.set_and_check_flag, flag, None, a
+class TestBroadcast(TestCase):
+ def test_broadcast_in_args(self):
+ # gh-5881
+ arrs = [np.empty((6, 7)), np.empty((5, 6, 1)), np.empty((7,)),
+ np.empty((5, 1, 7))]
+ mits = [np.broadcast(*arrs),
+ np.broadcast(np.broadcast(*arrs[:2]), np.broadcast(*arrs[2:])),
+ np.broadcast(arrs[0], np.broadcast(*arrs[1:-1]), arrs[-1])]
+ for mit in mits:
+ assert_equal(mit.shape, (5, 6, 7))
+ assert_equal(mit.nd, 3)
+ assert_equal(mit.numiter, 4)
+ for a, ia in zip(arrs, mit.iters):
+ assert_(a is ia.base)
+
+ def test_number_of_arguments(self):
+ arr = np.empty((5,))
+ for j in range(35):
+ arrs = [arr] * j
+ if j < 2 or j > 32:
+ assert_raises(ValueError, np.broadcast, *arrs)
+ else:
+ mit = np.broadcast(*arrs)
+ assert_equal(mit.numiter, j)
+
+
if __name__ == "__main__":
run_module_suite()