summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/array_api/_array_object.py14
-rw-r--r--numpy/array_api/_elementwise_functions.py4
-rw-r--r--numpy/array_api/tests/test_array_object.py2
-rw-r--r--numpy/array_api/tests/test_elementwise_functions.py2
4 files changed, 10 insertions, 12 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index 04d9f8a24..c86cfb3a6 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -656,15 +656,13 @@ class Array:
res = self._array.__pos__()
return self.__class__._new(res)
- # PEP 484 requires int to be a subtype of float, but __pow__ should not
- # accept int.
- def __pow__(self: Array, other: Union[float, Array], /) -> Array:
+ def __pow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __pow__.
"""
from ._elementwise_functions import pow
- other = self._check_allowed_dtypes(other, "floating-point", "__pow__")
+ other = self._check_allowed_dtypes(other, "numeric", "__pow__")
if other is NotImplemented:
return other
# Note: NumPy's __pow__ does not follow type promotion rules for 0-d
@@ -914,23 +912,23 @@ class Array:
res = self._array.__ror__(other._array)
return self.__class__._new(res)
- def __ipow__(self: Array, other: Union[float, Array], /) -> Array:
+ def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ipow__.
"""
- other = self._check_allowed_dtypes(other, "floating-point", "__ipow__")
+ other = self._check_allowed_dtypes(other, "numeric", "__ipow__")
if other is NotImplemented:
return other
self._array.__ipow__(other._array)
return self
- def __rpow__(self: Array, other: Union[float, Array], /) -> Array:
+ def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __rpow__.
"""
from ._elementwise_functions import pow
- other = self._check_allowed_dtypes(other, "floating-point", "__rpow__")
+ other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
if other is NotImplemented:
return other
# Note: NumPy's __pow__ does not follow the spec type promotion rules
diff --git a/numpy/array_api/_elementwise_functions.py b/numpy/array_api/_elementwise_functions.py
index 4408fe833..c758a0944 100644
--- a/numpy/array_api/_elementwise_functions.py
+++ b/numpy/array_api/_elementwise_functions.py
@@ -591,8 +591,8 @@ def pow(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
- if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
- raise TypeError("Only floating-point dtypes are allowed in pow")
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in pow")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py
index b980bacca..1fe1dfddf 100644
--- a/numpy/array_api/tests/test_array_object.py
+++ b/numpy/array_api/tests/test_array_object.py
@@ -98,7 +98,7 @@ def test_operators():
"__mul__": "numeric",
"__ne__": "all",
"__or__": "integer_or_boolean",
- "__pow__": "floating",
+ "__pow__": "numeric",
"__rshift__": "integer",
"__sub__": "numeric",
"__truediv__": "floating",
diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py
index a9274aec9..b2fb44e76 100644
--- a/numpy/array_api/tests/test_elementwise_functions.py
+++ b/numpy/array_api/tests/test_elementwise_functions.py
@@ -66,7 +66,7 @@ def test_function_types():
"negative": "numeric",
"not_equal": "all",
"positive": "numeric",
- "pow": "floating-point",
+ "pow": "numeric",
"remainder": "numeric",
"round": "numeric",
"sign": "numeric",