diff options
author | jaimefrio <jaime.frio@gmail.com> | 2015-04-03 17:49:25 -0700 |
---|---|---|
committer | Jaime Fernandez <jaime.frio@gmail.com> | 2015-04-03 20:46:17 -0700 |
commit | 77e433a26e3d02d234dd3fb5155e3467aad639d0 (patch) | |
tree | 887fdb56d0b84b0529c253ff570e1c2a384a268b | |
parent | 147c60f83f401037ff29593826d2c5729a73c2c5 (diff) | |
download | numpy-77e433a26e3d02d234dd3fb5155e3467aad639d0.tar.gz |
BUG: np.repeat does not properly broadcast size 1 repeat arrays
Closes #5743
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 44 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 7 |
2 files changed, 29 insertions, 22 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index a29c47555..c59a125e0 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -546,9 +546,9 @@ NPY_NO_EXPORT PyObject * PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis) { npy_intp *counts; - npy_intp n, n_outer, i, j, k, chunk, total; - npy_intp tmp; - int nd; + npy_intp n, n_outer, i, j, k, chunk; + npy_intp total = 0; + npy_bool broadcast = NPY_FALSE; PyArrayObject *repeats = NULL; PyObject *ap = NULL; PyArrayObject *ret = NULL; @@ -558,34 +558,35 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis) if (repeats == NULL) { return NULL; } - nd = PyArray_NDIM(repeats); + + /* + * Scalar and size 1 'repeat' arrays broadcast to any shape, for all + * other inputs the dimension must match exactly. + */ + if (PyArray_NDIM(repeats) == 0 || PyArray_SIZE(repeats) == 1) { + broadcast = NPY_TRUE; + } + counts = (npy_intp *)PyArray_DATA(repeats); - if ((ap=PyArray_CheckAxis(aop, &axis, NPY_ARRAY_CARRAY))==NULL) { + if ((ap = PyArray_CheckAxis(aop, &axis, NPY_ARRAY_CARRAY)) == NULL) { Py_DECREF(repeats); return NULL; } aop = (PyArrayObject *)ap; - if (nd == 1) { - n = PyArray_DIMS(repeats)[0]; - } - else { - /* nd == 0 */ - n = PyArray_DIMS(aop)[axis]; - } - if (PyArray_DIMS(aop)[axis] != n) { - PyErr_SetString(PyExc_ValueError, - "a.shape[axis] != len(repeats)"); + n = PyArray_DIM(aop, axis); + + if (!broadcast && PyArray_SIZE(repeats) != n) { + PyErr_Format(PyExc_ValueError, + "operands could not be broadcast together " + "with shape (%zd,) (%zd,)", n, PyArray_DIM(repeats, 0)); goto fail; } - - if (nd == 0) { - total = counts[0]*n; + if (broadcast) { + total = counts[0] * n; } else { - - total = 0; for (j = 0; j < n; j++) { if (counts[j] < 0) { PyErr_SetString(PyExc_ValueError, "count < 0"); @@ -595,7 +596,6 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis) } } - /* Construct new array */ PyArray_DIMS(aop)[axis] = total; Py_INCREF(PyArray_DESCR(aop)); @@ -623,7 +623,7 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis) } for (i = 0; i < n_outer; i++) { for (j = 0; j < n; j++) { - tmp = nd ? counts[j] : counts[0]; + npy_intp tmp = broadcast ? counts[0] : counts[j]; for (k = 0; k < tmp; k++) { memcpy(new_data, old_data, chunk); new_data += chunk; diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index 19c8d4457..fa2f52a23 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -10,6 +10,7 @@ import warnings import tempfile from os import path from io import BytesIO +from itertools import chain import numpy as np from numpy.testing import ( @@ -2118,6 +2119,12 @@ class TestRegression(TestCase): assert_raises(ValueError, np.frompyfunc, passer, 32, 1) + def test_repeat_broadcasting(self): + # gh-5743 + a = np.arange(60).reshape(3, 4, 5) + for axis in chain(range(-a.ndim, a.ndim), [None]): + assert_equal(a.repeat(2, axis=axis), a.repeat([2], axis=axis)) + if __name__ == "__main__": run_module_suite() |