diff options
author | warren <warren.weckesser@gmail.com> | 2021-10-01 23:11:58 -0400 |
---|---|---|
committer | warren <warren.weckesser@gmail.com> | 2021-10-01 23:39:12 -0400 |
commit | 8196c2a46fb621580a53ad5f7b2bd08cd154e870 (patch) | |
tree | 8ad132eff44a35f2e63a4c322b1eb362e9451c36 | |
parent | 6ad932c41281a6b7e525aba8a1057c09065251fa (diff) | |
download | numpy-8196c2a46fb621580a53ad5f7b2bd08cd154e870.tar.gz |
ENH: core: More informative error message for broadcast(*args)
When broadcast(*args) fails because of a shape mismatch, include
in the error message which arguments caused the mismatch and what
their shapes are.
For example, instead of
>>> np.broadcast([[0, 0, 0]], [[1], [1]], [2, 2])
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: shape mismatch: objects cannot be broadcast to a
single shape
we now get
>>> np.broadcast([[0, 0, 0]], [[1], [1]], [2, 2])
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: shape mismatch: objects cannot be broadcast to a
single shape. Mismatch is between arg 0 with shape (1, 3) and
arg 2 with shape (2,).
This also affects broadcast_arrays() and broadcast_shapes().
Closes gh-8345.
-rw-r--r-- | numpy/core/src/multiarray/iterators.c | 36 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 6 |
2 files changed, 38 insertions, 4 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c index 36bfaa7cf..f959162fd 100644 --- a/numpy/core/src/multiarray/iterators.c +++ b/numpy/core/src/multiarray/iterators.c @@ -1124,6 +1124,35 @@ NPY_NO_EXPORT PyTypeObject PyArrayIter_Type = { /** END of Array Iterator **/ + +static int +set_shape_mismatch_exception(PyArrayMultiIterObject *mit, int i1, int i2) +{ + PyObject *shape1, *shape2, *msg; + + shape1 = PyObject_GetAttrString((PyObject *) mit->iters[i1]->ao, "shape"); + if (shape1 == NULL) { + return -1; + } + shape2 = PyObject_GetAttrString((PyObject *) mit->iters[i2]->ao, "shape"); + if (shape2 == NULL) { + Py_DECREF(shape1); + return -1; + } + msg = PyUnicode_FromFormat("shape mismatch: objects cannot be broadcast " + "to a single shape. Mismatch is between arg %d " + "with shape %S and arg %d with shape %S.", + i1, shape1, i2, shape2); + Py_DECREF(shape1); + Py_DECREF(shape2); + if (msg == NULL) { + return -1; + } + PyErr_SetObject(PyExc_ValueError, msg); + Py_DECREF(msg); + return 0; +} + /* Adjust dimensionality and strides for index object iterators --- i.e. broadcast */ @@ -1132,6 +1161,7 @@ NPY_NO_EXPORT int PyArray_Broadcast(PyArrayMultiIterObject *mit) { int i, nd, k, j; + int src_iter = -1; /* Initializing avoids a compiler warning. */ npy_intp tmp; PyArrayIterObject *it; @@ -1155,12 +1185,10 @@ PyArray_Broadcast(PyArrayMultiIterObject *mit) } if (mit->dimensions[i] == 1) { mit->dimensions[i] = tmp; + src_iter = j; } else if (mit->dimensions[i] != tmp) { - PyErr_SetString(PyExc_ValueError, - "shape mismatch: objects" \ - " cannot be broadcast" \ - " to a single shape"); + set_shape_mismatch_exception(mit, src_iter, j); return -1; } } diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 4510333a1..e36f76c53 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -3511,6 +3511,12 @@ class TestBroadcast: assert_raises(ValueError, np.broadcast, 1, **{'x': 1}) + def test_shape_mismatch_error_message(self): + with pytest.raises(ValueError, match=r"arg 0 with shape \(1, 3\) and " + r"arg 2 with shape \(2,\)"): + np.broadcast([[1, 2, 3]], [[4], [5]], [6, 7]) + + class TestKeepdims: class sub_array(np.ndarray): |