summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorwarren <warren.weckesser@gmail.com>2021-10-01 23:11:58 -0400
committerwarren <warren.weckesser@gmail.com>2021-10-01 23:39:12 -0400
commit8196c2a46fb621580a53ad5f7b2bd08cd154e870 (patch)
tree8ad132eff44a35f2e63a4c322b1eb362e9451c36
parent6ad932c41281a6b7e525aba8a1057c09065251fa (diff)
downloadnumpy-8196c2a46fb621580a53ad5f7b2bd08cd154e870.tar.gz
ENH: core: More informative error message for broadcast(*args)
When broadcast(*args) fails because of a shape mismatch, include in the error message which arguments caused the mismatch and what their shapes are. For example, instead of >>> np.broadcast([[0, 0, 0]], [[1], [1]], [2, 2]) Traceback (most recent call last): File "<stdin>", line 1, in <module> ValueError: shape mismatch: objects cannot be broadcast to a single shape we now get >>> np.broadcast([[0, 0, 0]], [[1], [1]], [2, 2]) Traceback (most recent call last): File "<stdin>", line 1, in <module> ValueError: shape mismatch: objects cannot be broadcast to a single shape. Mismatch is between arg 0 with shape (1, 3) and arg 2 with shape (2,). This also affects broadcast_arrays() and broadcast_shapes(). Closes gh-8345.
-rw-r--r--numpy/core/src/multiarray/iterators.c36
-rw-r--r--numpy/core/tests/test_numeric.py6
2 files changed, 38 insertions, 4 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c
index 36bfaa7cf..f959162fd 100644
--- a/numpy/core/src/multiarray/iterators.c
+++ b/numpy/core/src/multiarray/iterators.c
@@ -1124,6 +1124,35 @@ NPY_NO_EXPORT PyTypeObject PyArrayIter_Type = {
/** END of Array Iterator **/
+
+static int
+set_shape_mismatch_exception(PyArrayMultiIterObject *mit, int i1, int i2)
+{
+ PyObject *shape1, *shape2, *msg;
+
+ shape1 = PyObject_GetAttrString((PyObject *) mit->iters[i1]->ao, "shape");
+ if (shape1 == NULL) {
+ return -1;
+ }
+ shape2 = PyObject_GetAttrString((PyObject *) mit->iters[i2]->ao, "shape");
+ if (shape2 == NULL) {
+ Py_DECREF(shape1);
+ return -1;
+ }
+ msg = PyUnicode_FromFormat("shape mismatch: objects cannot be broadcast "
+ "to a single shape. Mismatch is between arg %d "
+ "with shape %S and arg %d with shape %S.",
+ i1, shape1, i2, shape2);
+ Py_DECREF(shape1);
+ Py_DECREF(shape2);
+ if (msg == NULL) {
+ return -1;
+ }
+ PyErr_SetObject(PyExc_ValueError, msg);
+ Py_DECREF(msg);
+ return 0;
+}
+
/* Adjust dimensionality and strides for index object iterators
--- i.e. broadcast
*/
@@ -1132,6 +1161,7 @@ NPY_NO_EXPORT int
PyArray_Broadcast(PyArrayMultiIterObject *mit)
{
int i, nd, k, j;
+ int src_iter = -1; /* Initializing avoids a compiler warning. */
npy_intp tmp;
PyArrayIterObject *it;
@@ -1155,12 +1185,10 @@ PyArray_Broadcast(PyArrayMultiIterObject *mit)
}
if (mit->dimensions[i] == 1) {
mit->dimensions[i] = tmp;
+ src_iter = j;
}
else if (mit->dimensions[i] != tmp) {
- PyErr_SetString(PyExc_ValueError,
- "shape mismatch: objects" \
- " cannot be broadcast" \
- " to a single shape");
+ set_shape_mismatch_exception(mit, src_iter, j);
return -1;
}
}
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 4510333a1..e36f76c53 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -3511,6 +3511,12 @@ class TestBroadcast:
assert_raises(ValueError, np.broadcast, 1, **{'x': 1})
+ def test_shape_mismatch_error_message(self):
+ with pytest.raises(ValueError, match=r"arg 0 with shape \(1, 3\) and "
+ r"arg 2 with shape \(2,\)"):
+ np.broadcast([[1, 2, 3]], [[4], [5]], [6, 7])
+
+
class TestKeepdims:
class sub_array(np.ndarray):