From 13796236295b344ee83e79c8a33ad6205c0095db Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 8 Jul 2021 16:10:27 -0600 Subject: Fix the __imatmul__ method in the array API namespace --- numpy/_array_api/_array_object.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'numpy/_array_api/_array_object.py') diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index a3de25478..8f7252160 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -648,12 +648,21 @@ class ndarray: """ Performs the operation __imatmul__. """ + # Note: NumPy does not implement __imatmul__. + if isinstance(other, (int, float, bool)): # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._promote_scalar(other) - res = self._array.__imatmul__(other._array) - return self.__class__._new(res) + # __imatmul__ can only be allowed when it would not change the shape + # of self. + other_shape = other.shape + if self.shape == () or other_shape == (): + raise ValueError("@= requires at least one dimension") + if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]: + raise ValueError("@= cannot change the shape of the input array") + self._array[:] = self._array.__matmul__(other._array) + return self def __rmatmul__(self: array, other: array, /) -> array: """ -- cgit v1.2.1