summaryrefslogtreecommitdiff
path: root/numpy/array_api/_array_object.py
diff options
context:
space:
mode:
authorMatthew <quitesimplymatt@gmail.com>2022-01-07 11:53:09 +0000
committerMatthew <quitesimplymatt@gmail.com>2022-01-07 12:27:33 +0000
commite3406ed3ef83f7de0f3419361e85bd8634d0fd2b (patch)
treef3817044a033b0fcbc84273161af9aff26c8d5e5 /numpy/array_api/_array_object.py
parentf4a3e07877eb258f69a7d65d8a3b7e8d96e1aacd (diff)
downloadnumpy-e3406ed3ef83f7de0f3419361e85bd8634d0fd2b.tar.gz
BUG: Allow integer inputs for pow-related functions in `array_api`
Updates `xp.power()`, `x.__pow__()`, `x.__ipow()__` and `x.__rpow()__`
Diffstat (limited to 'numpy/array_api/_array_object.py')
-rw-r--r--numpy/array_api/_array_object.py14
1 files changed, 6 insertions, 8 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index 04d9f8a24..c86cfb3a6 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -656,15 +656,13 @@ class Array:
res = self._array.__pos__()
return self.__class__._new(res)
- # PEP 484 requires int to be a subtype of float, but __pow__ should not
- # accept int.
- def __pow__(self: Array, other: Union[float, Array], /) -> Array:
+ def __pow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __pow__.
"""
from ._elementwise_functions import pow
- other = self._check_allowed_dtypes(other, "floating-point", "__pow__")
+ other = self._check_allowed_dtypes(other, "numeric", "__pow__")
if other is NotImplemented:
return other
# Note: NumPy's __pow__ does not follow type promotion rules for 0-d
@@ -914,23 +912,23 @@ class Array:
res = self._array.__ror__(other._array)
return self.__class__._new(res)
- def __ipow__(self: Array, other: Union[float, Array], /) -> Array:
+ def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ipow__.
"""
- other = self._check_allowed_dtypes(other, "floating-point", "__ipow__")
+ other = self._check_allowed_dtypes(other, "numeric", "__ipow__")
if other is NotImplemented:
return other
self._array.__ipow__(other._array)
return self
- def __rpow__(self: Array, other: Union[float, Array], /) -> Array:
+ def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __rpow__.
"""
from ._elementwise_functions import pow
- other = self._check_allowed_dtypes(other, "floating-point", "__rpow__")
+ other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
if other is NotImplemented:
return other
# Note: NumPy's __pow__ does not follow the spec type promotion rules