diff options
author | Jaime Fernandez <jaime.frio@gmail.com> | 2015-05-17 18:57:14 -0700 |
---|---|---|
committer | Jaime Fernandez <jaime.frio@gmail.com> | 2015-05-17 18:57:14 -0700 |
commit | 7debda6fc2e66bc7fc74ade526a795cc381473a8 (patch) | |
tree | 078b86b2effb4d04e589b6f90a6ace4bf95a2ce1 | |
parent | 0c00f6910db141b6d514ded9a98857464a075838 (diff) | |
download | numpy-7debda6fc2e66bc7fc74ade526a795cc381473a8.tar.gz |
BUG: np.broadcast should accept itself as an input
Fixes #5881
-rw-r--r-- | numpy/core/src/multiarray/iterators.c | 61 |
1 files changed, 47 insertions, 14 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c index e56237573..829994b1e 100644 --- a/numpy/core/src/multiarray/iterators.c +++ b/numpy/core/src/multiarray/iterators.c @@ -1577,7 +1577,8 @@ static PyObject * arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *kwds) { - Py_ssize_t n, i; + Py_ssize_t n = 0; + Py_ssize_t i, j, k; PyArrayMultiIterObject *multi; PyObject *arr; @@ -1587,13 +1588,27 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k return NULL; } - n = PyTuple_Size(args); + for (j = 0; j < PyTuple_Size(args); ++j) { + PyObject *obj = PyTuple_GET_ITEM(args, j); + + if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) { + /* + * If obj is a multi-iterator, all its arrays will be added + * to the new multi-iterator. + */ + n += ((PyArrayMultiIterObject *)obj)->numiter; + } + else { + /* If not, will try to convert it to a single array */ + ++n; + } + } if (n < 2 || n > NPY_MAXARGS) { if (PyErr_Occurred()) { return NULL; } PyErr_Format(PyExc_ValueError, - "Need at least two and fewer than (%d) " \ + "Need at least two and fewer than (%d) " "array objects.", NPY_MAXARGS); return NULL; } @@ -1606,20 +1621,38 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k multi->numiter = n; multi->index = 0; - for (i = 0; i < n; i++) { - multi->iters[i] = NULL; - } - for (i = 0; i < n; i++) { - arr = PyArray_FromAny(PyTuple_GET_ITEM(args, i), NULL, 0, 0, 0, NULL); - if (arr == NULL) { - goto fail; + i = 0; + for (j = 0; j < PyTuple_GET_SIZE(args); ++j) { + PyObject *obj = PyTuple_GET_ITEM(args, j); + PyArrayIterObject *it; + + if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) { + PyArrayMultiIterObject *mit = (PyArrayMultiIterObject *)obj; + + for (k = 0; k < mit->numiter; ++k) { + arr = (PyObject *)mit->iters[k]->ao; + assert (arr != NULL); + it = (PyArrayIterObject *)PyArray_IterNew(arr); + if (it == NULL) { + goto fail; + } + multi->iters[i++] = it; + } } - if ((multi->iters[i] = (PyArrayIterObject *)PyArray_IterNew(arr)) - == NULL) { - goto fail; + else { + arr = PyArray_FromAny(obj, NULL, 0, 0, 0, NULL); + if (arr == NULL) { + goto fail; + } + it = (PyArrayIterObject *)PyArray_IterNew(arr); + if (it == NULL) { + goto fail; + } + multi->iters[i++] = it; + Py_DECREF(arr); } - Py_DECREF(arr); } + assert (i == n); if (PyArray_Broadcast(multi) < 0) { goto fail; } |