summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/number.c11
-rw-r--r--numpy/core/tests/test_multiarray.py13
2 files changed, 23 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index 3e7521582..953a84eef 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -405,6 +405,15 @@ array_matrix_multiply(PyArrayObject *m1, PyObject *m2)
0, nb_matrix_multiply);
return PyArray_GenericBinaryFunction(m1, m2, matmul);
}
+
+static PyObject *
+array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2)
+{
+ PyErr_SetString(PyExc_TypeError,
+ "In-place matrix multiplication is not (yet) supported. "
+ "Use 'a = a @ b' instead of 'a @= b'.");
+ return NULL;
+}
#endif
/* Determine if object is a scalar and if so, convert the object
@@ -1092,6 +1101,6 @@ NPY_NO_EXPORT PyNumberMethods array_as_number = {
(unaryfunc)array_index, /*nb_index */
#if PY_VERSION_HEX >= 0x03050000
(binaryfunc)array_matrix_multiply, /*nb_matrix_multiply*/
- (binaryfunc)NULL, /*nb_inplacematrix_multiply*/
+ (binaryfunc)array_inplace_matrix_multiply, /*nb_inplace_matrix_multiply*/
#endif
};
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index ac645f013..9822d7dfc 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4322,6 +4322,19 @@ if sys.version_info[:2] >= (3, 5):
assert_equal(self.matmul(a, b), "A")
assert_equal(self.matmul(b, a), "A")
+ def test_matmul_inplace():
+ # It would be nice to support in-place matmul eventually, but for now
+ # we don't have a working implementation, so better just to error out
+ # and nudge people to writing "a = a @ b".
+ a = np.eye(3)
+ b = np.eye(3)
+ assert_raises(TypeError, a.__imatmul__, b)
+ import operator
+ assert_raises(TypeError, operator.imatmul, a, b)
+ # we avoid writing the token `exec` so as not to crash python 2's
+ # parser
+ exec_ = getattr(builtins, "exec")
+ assert_raises(TypeError, exec_, "a @= b", globals(), locals())
class TestInner(TestCase):