diff options
| author | Bas van Beek <b.f.van.beek@vu.nl> | 2022-02-25 16:31:04 +0100 |
|---|---|---|
| committer | Sebastian Berg <sebastianb@nvidia.com> | 2022-12-02 00:29:32 +0100 |
| commit | 39744047cffcc88480a9ee2938417aa5f9eaeb27 (patch) | |
| tree | 43e5ee6172a6401e750b21df92ebf85be4fc5d6d /numpy | |
| parent | 01d64079b05c9b0c57775d0bb97d1cb5d52d2512 (diff) | |
| download | numpy-39744047cffcc88480a9ee2938417aa5f9eaeb27.tar.gz | |
ENH: Add support for inplace matrix multiplication
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/__init__.pyi | 14 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/number.c | 12 |
2 files changed, 19 insertions, 7 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index 69ac47a76..0c7210d94 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -1928,7 +1928,6 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]): def __neg__(self: NDArray[object_]) -> Any: ... # Binary ops - # NOTE: `ndarray` does not implement `__imatmul__` @overload def __matmul__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc] @overload @@ -2515,6 +2514,19 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]): @overload def __ior__(self: NDArray[object_], other: Any) -> NDArray[object_]: ... + @overload + def __imatmul__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... + @overload + def __imatmul__(self: NDArray[unsignedinteger[_NBit1]], other: _ArrayLikeUInt_co) -> NDArray[unsignedinteger[_NBit1]]: ... + @overload + def __imatmul__(self: NDArray[signedinteger[_NBit1]], other: _ArrayLikeInt_co) -> NDArray[signedinteger[_NBit1]]: ... + @overload + def __imatmul__(self: NDArray[floating[_NBit1]], other: _ArrayLikeFloat_co) -> NDArray[floating[_NBit1]]: ... + @overload + def __imatmul__(self: NDArray[complexfloating[_NBit1, _NBit1]], other: _ArrayLikeComplex_co) -> NDArray[complexfloating[_NBit1, _NBit1]]: ... + @overload + def __imatmul__(self: NDArray[object_], other: Any) -> NDArray[object_]: ... + def __dlpack__(self: NDArray[number[Any]], *, stream: None = ...) -> _PyCapsule: ... def __dlpack_device__(self) -> tuple[int, L[0]]: ... diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index df814a796..0c8b23bdf 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -53,6 +53,8 @@ static PyObject * array_inplace_remainder(PyArrayObject *m1, PyObject *m2); static PyObject * array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo)); +static PyObject * +array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2); /* * Dictionary can contain any of the numeric operations, by name. @@ -348,13 +350,11 @@ array_matrix_multiply(PyObject *m1, PyObject *m2) } static PyObject * -array_inplace_matrix_multiply( - PyArrayObject *NPY_UNUSED(m1), PyObject *NPY_UNUSED(m2)) +array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2) { - PyErr_SetString(PyExc_TypeError, - "In-place matrix multiplication is not (yet) supported. " - "Use 'a = a @ b' instead of 'a @= b'."); - return NULL; + INPLACE_GIVE_UP_IF_NEEDED(m1, m2, + nb_inplace_matrix_multiply, array_inplace_matrix_multiply); + return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.matmul); } /* |
