From 5febef530e055572fd5eac18807675ee451c81b0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 9 Jul 2021 16:24:45 -0600 Subject: Only allow floating-point dtypes in the array API __pow__ and __truediv__ See https://github.com/data-apis/array-api/pull/221. --- numpy/_array_api/_array_object.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'numpy/_array_api') diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index 2377bffe3..797f9ea4f 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -503,6 +503,8 @@ class Array: if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __pow__') # Note: NumPy's __pow__ does not follow type promotion rules for 0-d # arrays, so we use pow() here instead. return pow(self, other) @@ -548,6 +550,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __truediv__') self, other = self._normalize_two_args(self, other) res = self._array.__truediv__(other._array) return self.__class__._new(res) @@ -744,6 +748,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __pow__') self._array.__ipow__(other._array) return self @@ -756,6 +762,8 @@ class Array: if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __pow__') # Note: NumPy's __pow__ does not follow the spec type promotion rules # for 0-d arrays, so we use pow() here instead. return pow(other, self) @@ -810,6 +818,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __truediv__') self._array.__itruediv__(other._array) return self @@ -820,6 +830,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __truediv__') self, other = self._normalize_two_args(self, other) res = self._array.__rtruediv__(other._array) return self.__class__._new(res) -- cgit v1.2.1