summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaime Fernandez <jaime.frio@gmail.com>2015-05-17 18:57:14 -0700
committerJaime Fernandez <jaime.frio@gmail.com>2015-05-17 18:57:14 -0700
commit7debda6fc2e66bc7fc74ade526a795cc381473a8 (patch)
tree078b86b2effb4d04e589b6f90a6ace4bf95a2ce1
parent0c00f6910db141b6d514ded9a98857464a075838 (diff)
downloadnumpy-7debda6fc2e66bc7fc74ade526a795cc381473a8.tar.gz
BUG: np.broadcast should accept itself as an input
Fixes #5881
-rw-r--r--numpy/core/src/multiarray/iterators.c61
1 files changed, 47 insertions, 14 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c
index e56237573..829994b1e 100644
--- a/numpy/core/src/multiarray/iterators.c
+++ b/numpy/core/src/multiarray/iterators.c
@@ -1577,7 +1577,8 @@ static PyObject *
arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *kwds)
{
- Py_ssize_t n, i;
+ Py_ssize_t n = 0;
+ Py_ssize_t i, j, k;
PyArrayMultiIterObject *multi;
PyObject *arr;
@@ -1587,13 +1588,27 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k
return NULL;
}
- n = PyTuple_Size(args);
+ for (j = 0; j < PyTuple_Size(args); ++j) {
+ PyObject *obj = PyTuple_GET_ITEM(args, j);
+
+ if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) {
+ /*
+ * If obj is a multi-iterator, all its arrays will be added
+ * to the new multi-iterator.
+ */
+ n += ((PyArrayMultiIterObject *)obj)->numiter;
+ }
+ else {
+ /* If not, will try to convert it to a single array */
+ ++n;
+ }
+ }
if (n < 2 || n > NPY_MAXARGS) {
if (PyErr_Occurred()) {
return NULL;
}
PyErr_Format(PyExc_ValueError,
- "Need at least two and fewer than (%d) " \
+ "Need at least two and fewer than (%d) "
"array objects.", NPY_MAXARGS);
return NULL;
}
@@ -1606,20 +1621,38 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k
multi->numiter = n;
multi->index = 0;
- for (i = 0; i < n; i++) {
- multi->iters[i] = NULL;
- }
- for (i = 0; i < n; i++) {
- arr = PyArray_FromAny(PyTuple_GET_ITEM(args, i), NULL, 0, 0, 0, NULL);
- if (arr == NULL) {
- goto fail;
+ i = 0;
+ for (j = 0; j < PyTuple_GET_SIZE(args); ++j) {
+ PyObject *obj = PyTuple_GET_ITEM(args, j);
+ PyArrayIterObject *it;
+
+ if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) {
+ PyArrayMultiIterObject *mit = (PyArrayMultiIterObject *)obj;
+
+ for (k = 0; k < mit->numiter; ++k) {
+ arr = (PyObject *)mit->iters[k]->ao;
+ assert (arr != NULL);
+ it = (PyArrayIterObject *)PyArray_IterNew(arr);
+ if (it == NULL) {
+ goto fail;
+ }
+ multi->iters[i++] = it;
+ }
}
- if ((multi->iters[i] = (PyArrayIterObject *)PyArray_IterNew(arr))
- == NULL) {
- goto fail;
+ else {
+ arr = PyArray_FromAny(obj, NULL, 0, 0, 0, NULL);
+ if (arr == NULL) {
+ goto fail;
+ }
+ it = (PyArrayIterObject *)PyArray_IterNew(arr);
+ if (it == NULL) {
+ goto fail;
+ }
+ multi->iters[i++] = it;
+ Py_DECREF(arr);
}
- Py_DECREF(arr);
}
+ assert (i == n);
if (PyArray_Broadcast(multi) < 0) {
goto fail;
}