diff options
-rw-r--r-- | numpy/core/src/multiarray/iterators.c | 61 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 28 |
2 files changed, 75 insertions, 14 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c index e56237573..829994b1e 100644 --- a/numpy/core/src/multiarray/iterators.c +++ b/numpy/core/src/multiarray/iterators.c @@ -1577,7 +1577,8 @@ static PyObject * arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *kwds) { - Py_ssize_t n, i; + Py_ssize_t n = 0; + Py_ssize_t i, j, k; PyArrayMultiIterObject *multi; PyObject *arr; @@ -1587,13 +1588,27 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k return NULL; } - n = PyTuple_Size(args); + for (j = 0; j < PyTuple_Size(args); ++j) { + PyObject *obj = PyTuple_GET_ITEM(args, j); + + if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) { + /* + * If obj is a multi-iterator, all its arrays will be added + * to the new multi-iterator. + */ + n += ((PyArrayMultiIterObject *)obj)->numiter; + } + else { + /* If not, will try to convert it to a single array */ + ++n; + } + } if (n < 2 || n > NPY_MAXARGS) { if (PyErr_Occurred()) { return NULL; } PyErr_Format(PyExc_ValueError, - "Need at least two and fewer than (%d) " \ + "Need at least two and fewer than (%d) " "array objects.", NPY_MAXARGS); return NULL; } @@ -1606,20 +1621,38 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k multi->numiter = n; multi->index = 0; - for (i = 0; i < n; i++) { - multi->iters[i] = NULL; - } - for (i = 0; i < n; i++) { - arr = PyArray_FromAny(PyTuple_GET_ITEM(args, i), NULL, 0, 0, 0, NULL); - if (arr == NULL) { - goto fail; + i = 0; + for (j = 0; j < PyTuple_GET_SIZE(args); ++j) { + PyObject *obj = PyTuple_GET_ITEM(args, j); + PyArrayIterObject *it; + + if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) { + PyArrayMultiIterObject *mit = (PyArrayMultiIterObject *)obj; + + for (k = 0; k < mit->numiter; ++k) { + arr = (PyObject *)mit->iters[k]->ao; + assert (arr != NULL); + it = (PyArrayIterObject *)PyArray_IterNew(arr); + if (it == NULL) { + goto fail; + } + multi->iters[i++] = it; + } } - if ((multi->iters[i] = (PyArrayIterObject *)PyArray_IterNew(arr)) - == NULL) { - goto fail; + else { + arr = PyArray_FromAny(obj, NULL, 0, 0, 0, NULL); + if (arr == NULL) { + goto fail; + } + it = (PyArrayIterObject *)PyArray_IterNew(arr); + if (it == NULL) { + goto fail; + } + multi->iters[i++] = it; + Py_DECREF(arr); } - Py_DECREF(arr); } + assert (i == n); if (PyArray_Broadcast(multi) < 0) { goto fail; } diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index ee304a7af..7400366ac 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -2226,6 +2226,7 @@ class TestCross(TestCase): for axisc in range(-2, 2): assert_equal(np.cross(u, u, axisc=axisc).shape, (3, 4)) + def test_outer_out_param(): arr1 = np.ones((5,)) arr2 = np.ones((2,)) @@ -2236,6 +2237,7 @@ def test_outer_out_param(): assert_equal(res1, out1) assert_equal(np.outer(arr2, arr3, out2), out2) + class TestRequire(object): flag_names = ['C', 'C_CONTIGUOUS', 'CONTIGUOUS', 'F', 'F_CONTIGUOUS', 'FORTRAN', @@ -2310,5 +2312,31 @@ class TestRequire(object): yield self.set_and_check_flag, flag, None, a +class TestBroadcast(TestCase): + def test_broadcast_in_args(self): + # gh-5881 + arrs = [np.empty((6, 7)), np.empty((5, 6, 1)), np.empty((7,)), + np.empty((5, 1, 7))] + mits = [np.broadcast(*arrs), + np.broadcast(np.broadcast(*arrs[:2]), np.broadcast(*arrs[2:])), + np.broadcast(arrs[0], np.broadcast(*arrs[1:-1]), arrs[-1])] + for mit in mits: + assert_equal(mit.shape, (5, 6, 7)) + assert_equal(mit.nd, 3) + assert_equal(mit.numiter, 4) + for a, ia in zip(arrs, mit.iters): + assert_(a is ia.base) + + def test_number_of_arguments(self): + arr = np.empty((5,)) + for j in range(35): + arrs = [arr] * j + if j < 2 or j > 32: + assert_raises(ValueError, np.broadcast, *arrs) + else: + mit = np.broadcast(*arrs) + assert_equal(mit.numiter, j) + + if __name__ == "__main__": run_module_suite() |