diff options
-rw-r--r-- | numpy/core/src/multiarray/number.c | 11 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 13 |
2 files changed, 23 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index 3e7521582..953a84eef 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -405,6 +405,15 @@ array_matrix_multiply(PyArrayObject *m1, PyObject *m2) 0, nb_matrix_multiply); return PyArray_GenericBinaryFunction(m1, m2, matmul); } + +static PyObject * +array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2) +{ + PyErr_SetString(PyExc_TypeError, + "In-place matrix multiplication is not (yet) supported. " + "Use 'a = a @ b' instead of 'a @= b'."); + return NULL; +} #endif /* Determine if object is a scalar and if so, convert the object @@ -1092,6 +1101,6 @@ NPY_NO_EXPORT PyNumberMethods array_as_number = { (unaryfunc)array_index, /*nb_index */ #if PY_VERSION_HEX >= 0x03050000 (binaryfunc)array_matrix_multiply, /*nb_matrix_multiply*/ - (binaryfunc)NULL, /*nb_inplacematrix_multiply*/ + (binaryfunc)array_inplace_matrix_multiply, /*nb_inplace_matrix_multiply*/ #endif }; diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index ac645f013..9822d7dfc 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -4322,6 +4322,19 @@ if sys.version_info[:2] >= (3, 5): assert_equal(self.matmul(a, b), "A") assert_equal(self.matmul(b, a), "A") + def test_matmul_inplace(): + # It would be nice to support in-place matmul eventually, but for now + # we don't have a working implementation, so better just to error out + # and nudge people to writing "a = a @ b". + a = np.eye(3) + b = np.eye(3) + assert_raises(TypeError, a.__imatmul__, b) + import operator + assert_raises(TypeError, operator.imatmul, a, b) + # we avoid writing the token `exec` so as not to crash python 2's + # parser + exec_ = getattr(builtins, "exec") + assert_raises(TypeError, exec_, "a @= b", globals(), locals()) class TestInner(TestCase): |