summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2022-02-25 16:31:04 +0100
committerSebastian Berg <sebastianb@nvidia.com>2022-12-02 00:29:32 +0100
commit39744047cffcc88480a9ee2938417aa5f9eaeb27 (patch)
tree43e5ee6172a6401e750b21df92ebf85be4fc5d6d /numpy
parent01d64079b05c9b0c57775d0bb97d1cb5d52d2512 (diff)
downloadnumpy-39744047cffcc88480a9ee2938417aa5f9eaeb27.tar.gz
ENH: Add support for inplace matrix multiplication
Diffstat (limited to 'numpy')
-rw-r--r--numpy/__init__.pyi14
-rw-r--r--numpy/core/src/multiarray/number.c12
2 files changed, 19 insertions, 7 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 69ac47a76..0c7210d94 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -1928,7 +1928,6 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
def __neg__(self: NDArray[object_]) -> Any: ...
# Binary ops
- # NOTE: `ndarray` does not implement `__imatmul__`
@overload
def __matmul__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
@overload
@@ -2515,6 +2514,19 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
@overload
def __ior__(self: NDArray[object_], other: Any) -> NDArray[object_]: ...
+ @overload
+ def __imatmul__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ...
+ @overload
+ def __imatmul__(self: NDArray[unsignedinteger[_NBit1]], other: _ArrayLikeUInt_co) -> NDArray[unsignedinteger[_NBit1]]: ...
+ @overload
+ def __imatmul__(self: NDArray[signedinteger[_NBit1]], other: _ArrayLikeInt_co) -> NDArray[signedinteger[_NBit1]]: ...
+ @overload
+ def __imatmul__(self: NDArray[floating[_NBit1]], other: _ArrayLikeFloat_co) -> NDArray[floating[_NBit1]]: ...
+ @overload
+ def __imatmul__(self: NDArray[complexfloating[_NBit1, _NBit1]], other: _ArrayLikeComplex_co) -> NDArray[complexfloating[_NBit1, _NBit1]]: ...
+ @overload
+ def __imatmul__(self: NDArray[object_], other: Any) -> NDArray[object_]: ...
+
def __dlpack__(self: NDArray[number[Any]], *, stream: None = ...) -> _PyCapsule: ...
def __dlpack_device__(self) -> tuple[int, L[0]]: ...
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index df814a796..0c8b23bdf 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -53,6 +53,8 @@ static PyObject *
array_inplace_remainder(PyArrayObject *m1, PyObject *m2);
static PyObject *
array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo));
+static PyObject *
+array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2);
/*
* Dictionary can contain any of the numeric operations, by name.
@@ -348,13 +350,11 @@ array_matrix_multiply(PyObject *m1, PyObject *m2)
}
static PyObject *
-array_inplace_matrix_multiply(
- PyArrayObject *NPY_UNUSED(m1), PyObject *NPY_UNUSED(m2))
+array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2)
{
- PyErr_SetString(PyExc_TypeError,
- "In-place matrix multiplication is not (yet) supported. "
- "Use 'a = a @ b' instead of 'a @= b'.");
- return NULL;
+ INPLACE_GIVE_UP_IF_NEEDED(m1, m2,
+ nb_inplace_matrix_multiply, array_inplace_matrix_multiply);
+ return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.matmul);
}
/*