diff options
author | Bas van Beek <b.f.van.beek@vu.nl> | 2022-02-25 18:05:39 +0100 |
---|---|---|
committer | Sebastian Berg <sebastianb@nvidia.com> | 2022-12-02 00:29:32 +0100 |
commit | de0521fc22e641be5e819a2fec785c6f89ebca8c (patch) | |
tree | 34fa55ac8ed33916459942ff6c05f7674178a5e2 /numpy/array_api/_array_object.py | |
parent | e0ed8ceae87bcb6ac3ad9190ec409b3ed631429a (diff) | |
download | numpy-de0521fc22e641be5e819a2fec785c6f89ebca8c.tar.gz |
MAINT: Let `ndarray.__imatmul__` handle inplace matrix multiplication in the array-api
Diffstat (limited to 'numpy/array_api/_array_object.py')
-rw-r--r-- | numpy/array_api/_array_object.py | 14 |
1 files changed, 2 insertions, 12 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index c4746fad9..592ca09df 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -850,23 +850,13 @@ class Array: """ Performs the operation __imatmul__. """ - # Note: NumPy does not implement __imatmul__. - # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") if other is NotImplemented: return other - - # __imatmul__ can only be allowed when it would not change the shape - # of self. - other_shape = other.shape - if self.shape == () or other_shape == (): - raise ValueError("@= requires at least one dimension") - if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]: - raise ValueError("@= cannot change the shape of the input array") - self._array[:] = self._array.__matmul__(other._array) - return self + res = self._array.__imatmul__(other._array) + return self.__class__._new(res) def __rmatmul__(self: Array, other: Array, /) -> Array: """ |