summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjaimefrio <jaime.frio@gmail.com>2015-04-03 17:49:25 -0700
committerJaime Fernandez <jaime.frio@gmail.com>2015-04-03 20:46:17 -0700
commit77e433a26e3d02d234dd3fb5155e3467aad639d0 (patch)
tree887fdb56d0b84b0529c253ff570e1c2a384a268b
parent147c60f83f401037ff29593826d2c5729a73c2c5 (diff)
downloadnumpy-77e433a26e3d02d234dd3fb5155e3467aad639d0.tar.gz
BUG: np.repeat does not properly broadcast size 1 repeat arrays
Closes #5743
-rw-r--r--numpy/core/src/multiarray/item_selection.c44
-rw-r--r--numpy/core/tests/test_regression.py7
2 files changed, 29 insertions, 22 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index a29c47555..c59a125e0 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -546,9 +546,9 @@ NPY_NO_EXPORT PyObject *
PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
{
npy_intp *counts;
- npy_intp n, n_outer, i, j, k, chunk, total;
- npy_intp tmp;
- int nd;
+ npy_intp n, n_outer, i, j, k, chunk;
+ npy_intp total = 0;
+ npy_bool broadcast = NPY_FALSE;
PyArrayObject *repeats = NULL;
PyObject *ap = NULL;
PyArrayObject *ret = NULL;
@@ -558,34 +558,35 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
if (repeats == NULL) {
return NULL;
}
- nd = PyArray_NDIM(repeats);
+
+ /*
+ * Scalar and size 1 'repeat' arrays broadcast to any shape, for all
+ * other inputs the dimension must match exactly.
+ */
+ if (PyArray_NDIM(repeats) == 0 || PyArray_SIZE(repeats) == 1) {
+ broadcast = NPY_TRUE;
+ }
+
counts = (npy_intp *)PyArray_DATA(repeats);
- if ((ap=PyArray_CheckAxis(aop, &axis, NPY_ARRAY_CARRAY))==NULL) {
+ if ((ap = PyArray_CheckAxis(aop, &axis, NPY_ARRAY_CARRAY)) == NULL) {
Py_DECREF(repeats);
return NULL;
}
aop = (PyArrayObject *)ap;
- if (nd == 1) {
- n = PyArray_DIMS(repeats)[0];
- }
- else {
- /* nd == 0 */
- n = PyArray_DIMS(aop)[axis];
- }
- if (PyArray_DIMS(aop)[axis] != n) {
- PyErr_SetString(PyExc_ValueError,
- "a.shape[axis] != len(repeats)");
+ n = PyArray_DIM(aop, axis);
+
+ if (!broadcast && PyArray_SIZE(repeats) != n) {
+ PyErr_Format(PyExc_ValueError,
+ "operands could not be broadcast together "
+ "with shape (%zd,) (%zd,)", n, PyArray_DIM(repeats, 0));
goto fail;
}
-
- if (nd == 0) {
- total = counts[0]*n;
+ if (broadcast) {
+ total = counts[0] * n;
}
else {
-
- total = 0;
for (j = 0; j < n; j++) {
if (counts[j] < 0) {
PyErr_SetString(PyExc_ValueError, "count < 0");
@@ -595,7 +596,6 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
}
}
-
/* Construct new array */
PyArray_DIMS(aop)[axis] = total;
Py_INCREF(PyArray_DESCR(aop));
@@ -623,7 +623,7 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
}
for (i = 0; i < n_outer; i++) {
for (j = 0; j < n; j++) {
- tmp = nd ? counts[j] : counts[0];
+ npy_intp tmp = broadcast ? counts[0] : counts[j];
for (k = 0; k < tmp; k++) {
memcpy(new_data, old_data, chunk);
new_data += chunk;
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index 19c8d4457..fa2f52a23 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -10,6 +10,7 @@ import warnings
import tempfile
from os import path
from io import BytesIO
+from itertools import chain
import numpy as np
from numpy.testing import (
@@ -2118,6 +2119,12 @@ class TestRegression(TestCase):
assert_raises(ValueError, np.frompyfunc, passer, 32, 1)
+ def test_repeat_broadcasting(self):
+ # gh-5743
+ a = np.arange(60).reshape(3, 4, 5)
+ for axis in chain(range(-a.ndim, a.ndim), [None]):
+ assert_equal(a.repeat(2, axis=axis), a.repeat([2], axis=axis))
+
if __name__ == "__main__":
run_module_suite()