diff options
author | Stephan Hoyer <shoyer@google.com> | 2018-11-23 14:25:20 -0500 |
---|---|---|
committer | Stephan Hoyer <shoyer@google.com> | 2018-11-23 14:25:20 -0500 |
commit | c6adda6d86c19bd99ff4b1325a82db9536083c7c (patch) | |
tree | 8de9ffc910dd56b3373a35a24bc03373ffcf6e4e | |
parent | ed5f841bd3517334a173f8032e08b745d43c8033 (diff) | |
download | numpy-c6adda6d86c19bd99ff4b1325a82db9536083c7c.tar.gz |
BUG: don't override casting errors with matmul or inner
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 10 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 20 |
2 files changed, 19 insertions, 11 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 22909ae1a..909a24359 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -833,7 +833,10 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2) typenum = PyArray_ObjectType(op2, typenum); typec = PyArray_DescrFromType(typenum); if (typec == NULL) { - PyErr_SetString(PyExc_TypeError, "Cannot find a common data type."); + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, + "Cannot find a common data type."); + } goto fail; } @@ -2361,7 +2364,10 @@ array_matmul(PyObject *NPY_UNUSED(m), PyObject *args, PyObject* kwds) dtype = PyArray_DescrFromObject(in1, NULL); dtype = PyArray_DescrFromObject(in2, dtype); if (dtype == NULL) { - PyErr_SetString(PyExc_ValueError, "Cannot find a common data type."); + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, + "Cannot find a common data type."); + } return NULL; } typenum = dtype->type_num; diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 68fb6acf7..51fe6e9ef 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2693,15 +2693,6 @@ class TestMethods(object): assert_raises(TypeError, np.dot, c, A) assert_raises(TypeError, np.dot, A, c) - def test_dot_casting_fails(self): - - class A(object): - def __array__(self, *args, **kwargs): - raise NotImplementedError - - # Don't override the error from calling __array__() - assert_raises(NotImplementedError, np.dot, A(), A()) - def test_dot_out_mem_overlap(self): np.random.seed(1) @@ -2736,6 +2727,17 @@ class TestMethods(object): np.dot(a, b, out=out) np.matmul(a, b, out=out) + def test_dot_matmul_inner_array_casting_fails(self): + + class A(object): + def __array__(self, *args, **kwargs): + raise NotImplementedError + + # Don't override the error from calling __array__() + assert_raises(NotImplementedError, np.dot, A(), A()) + assert_raises(NotImplementedError, np.matmul, A(), A()) + assert_raises(NotImplementedError, np.inner, A(), A()) + def test_diagonal(self): a = np.arange(12).reshape((3, 4)) assert_equal(a.diagonal(), [0, 5, 10]) |