summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@google.com>2018-11-23 14:25:20 -0500
committerStephan Hoyer <shoyer@google.com>2018-11-23 14:25:20 -0500
commitc6adda6d86c19bd99ff4b1325a82db9536083c7c (patch)
tree8de9ffc910dd56b3373a35a24bc03373ffcf6e4e
parented5f841bd3517334a173f8032e08b745d43c8033 (diff)
downloadnumpy-c6adda6d86c19bd99ff4b1325a82db9536083c7c.tar.gz
BUG: don't override casting errors with matmul or inner
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c10
-rw-r--r--numpy/core/tests/test_multiarray.py20
2 files changed, 19 insertions, 11 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 22909ae1a..909a24359 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -833,7 +833,10 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2)
typenum = PyArray_ObjectType(op2, typenum);
typec = PyArray_DescrFromType(typenum);
if (typec == NULL) {
- PyErr_SetString(PyExc_TypeError, "Cannot find a common data type.");
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_TypeError,
+ "Cannot find a common data type.");
+ }
goto fail;
}
@@ -2361,7 +2364,10 @@ array_matmul(PyObject *NPY_UNUSED(m), PyObject *args, PyObject* kwds)
dtype = PyArray_DescrFromObject(in1, NULL);
dtype = PyArray_DescrFromObject(in2, dtype);
if (dtype == NULL) {
- PyErr_SetString(PyExc_ValueError, "Cannot find a common data type.");
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_ValueError,
+ "Cannot find a common data type.");
+ }
return NULL;
}
typenum = dtype->type_num;
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 68fb6acf7..51fe6e9ef 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2693,15 +2693,6 @@ class TestMethods(object):
assert_raises(TypeError, np.dot, c, A)
assert_raises(TypeError, np.dot, A, c)
- def test_dot_casting_fails(self):
-
- class A(object):
- def __array__(self, *args, **kwargs):
- raise NotImplementedError
-
- # Don't override the error from calling __array__()
- assert_raises(NotImplementedError, np.dot, A(), A())
-
def test_dot_out_mem_overlap(self):
np.random.seed(1)
@@ -2736,6 +2727,17 @@ class TestMethods(object):
np.dot(a, b, out=out)
np.matmul(a, b, out=out)
+ def test_dot_matmul_inner_array_casting_fails(self):
+
+ class A(object):
+ def __array__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ # Don't override the error from calling __array__()
+ assert_raises(NotImplementedError, np.dot, A(), A())
+ assert_raises(NotImplementedError, np.matmul, A(), A())
+ assert_raises(NotImplementedError, np.inner, A(), A())
+
def test_diagonal(self):
a = np.arange(12).reshape((3, 4))
assert_equal(a.diagonal(), [0, 5, 10])