summaryrefslogtreecommitdiff
path: root/numpy/array_api/_array_object.py
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2023-03-26 14:32:25 +0300
committerGitHub <noreply@github.com>2023-03-26 14:32:25 +0300
commita37978a106073eaec5cb9e0cb54785fafb639650 (patch)
treeac2ef9d4bbcf603f500051237f9c9ed318df9ff7 /numpy/array_api/_array_object.py
parent1e292ff08d3b797fa69b889c4a3fed99970308c8 (diff)
parent0bbe7dbf267cf835efcd514283815292bd94403f (diff)
downloadnumpy-a37978a106073eaec5cb9e0cb54785fafb639650.tar.gz
Merge pull request #21120 from BvB93/matmul
ENH: Add support for inplace matrix multiplication
Diffstat (limited to 'numpy/array_api/_array_object.py')
-rw-r--r--numpy/array_api/_array_object.py14
1 files changed, 2 insertions, 12 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index eee117be6..a949b5977 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -850,23 +850,13 @@ class Array:
"""
Performs the operation __imatmul__.
"""
- # Note: NumPy does not implement __imatmul__.
-
# matmul is not defined for scalars, but without this, we may get
# the wrong error message from asarray.
other = self._check_allowed_dtypes(other, "numeric", "__imatmul__")
if other is NotImplemented:
return other
-
- # __imatmul__ can only be allowed when it would not change the shape
- # of self.
- other_shape = other.shape
- if self.shape == () or other_shape == ():
- raise ValueError("@= requires at least one dimension")
- if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]:
- raise ValueError("@= cannot change the shape of the input array")
- self._array[:] = self._array.__matmul__(other._array)
- return self
+ res = self._array.__imatmul__(other._array)
+ return self.__class__._new(res)
def __rmatmul__(self: Array, other: Array, /) -> Array:
"""