diff options
Diffstat (limited to 'numpy/_array_api')
| -rw-r--r-- | numpy/_array_api/_array_object.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index 2377bffe3..797f9ea4f 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -503,6 +503,8 @@ class Array: if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __pow__') # Note: NumPy's __pow__ does not follow type promotion rules for 0-d # arrays, so we use pow() here instead. return pow(self, other) @@ -548,6 +550,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __truediv__') self, other = self._normalize_two_args(self, other) res = self._array.__truediv__(other._array) return self.__class__._new(res) @@ -744,6 +748,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __pow__') self._array.__ipow__(other._array) return self @@ -756,6 +762,8 @@ class Array: if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __pow__') # Note: NumPy's __pow__ does not follow the spec type promotion rules # for 0-d arrays, so we use pow() here instead. return pow(other, self) @@ -810,6 +818,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __truediv__') self._array.__itruediv__(other._array) return self @@ -820,6 +830,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __truediv__') self, other = self._normalize_two_args(self, other) res = self._array.__rtruediv__(other._array) return self.__class__._new(res) |
