summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-08-06 18:22:00 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-08-06 18:23:04 -0600
commit8f7d00ed447174d9398af3365709222b529c1cad (patch)
tree9de0a3a757a8c8a7393787ee1449e087c284d6e1 /numpy/array_api
parent21923a5fa71bfadf7dee0bb5b110cc2a5719eaac (diff)
downloadnumpy-8f7d00ed447174d9398af3365709222b529c1cad.tar.gz
Run (selective) black on the array_api submodule
I've omitted a few changes from black that messed up the readability of some complicated if statements that were organized logically line-by-line, and some changes that use unnecessary operator spacing.
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/__init__.py247
-rw-r--r--numpy/array_api/_array_object.py205
-rw-r--r--numpy/array_api/_creation_functions.py158
-rw-r--r--numpy/array_api/_data_type_functions.py16
-rw-r--r--numpy/array_api/_dtypes.py75
-rw-r--r--numpy/array_api/_elementwise_functions.py194
-rw-r--r--numpy/array_api/_linear_algebra_functions.py16
-rw-r--r--numpy/array_api/_manipulation_functions.py18
-rw-r--r--numpy/array_api/_searching_functions.py4
-rw-r--r--numpy/array_api/_set_functions.py18
-rw-r--r--numpy/array_api/_sorting_functions.py14
-rw-r--r--numpy/array_api/_statistical_functions.py65
-rw-r--r--numpy/array_api/_typing.py30
-rw-r--r--numpy/array_api/_utility_functions.py18
-rw-r--r--numpy/array_api/setup.py10
-rw-r--r--numpy/array_api/tests/test_array_object.py101
-rw-r--r--numpy/array_api/tests/test_creation_functions.py122
-rw-r--r--numpy/array_api/tests/test_elementwise_functions.py133
18 files changed, 1054 insertions, 390 deletions
diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py
index 4dc931732..53c1f3850 100644
--- a/numpy/array_api/__init__.py
+++ b/numpy/array_api/__init__.py
@@ -119,36 +119,221 @@ Still TODO in this module are:
"""
import sys
+
# numpy.array_api is 3.8+ because it makes extensive use of positional-only
# arguments.
if sys.version_info < (3, 8):
raise ImportError("The numpy.array_api submodule requires Python 3.8 or greater.")
import warnings
-warnings.warn("The numpy.array_api submodule is still experimental. See NEP 47.",
- stacklevel=2)
+
+warnings.warn(
+ "The numpy.array_api submodule is still experimental. See NEP 47.", stacklevel=2
+)
__all__ = []
from ._constants import e, inf, nan, pi
-__all__ += ['e', 'inf', 'nan', 'pi']
-
-from ._creation_functions import asarray, arange, empty, empty_like, eye, from_dlpack, full, full_like, linspace, meshgrid, ones, ones_like, zeros, zeros_like
-
-__all__ += ['asarray', 'arange', 'empty', 'empty_like', 'eye', 'from_dlpack', 'full', 'full_like', 'linspace', 'meshgrid', 'ones', 'ones_like', 'zeros', 'zeros_like']
-
-from ._data_type_functions import broadcast_arrays, broadcast_to, can_cast, finfo, iinfo, result_type
-
-__all__ += ['broadcast_arrays', 'broadcast_to', 'can_cast', 'finfo', 'iinfo', 'result_type']
-
-from ._dtypes import int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, bool
-
-__all__ += ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', 'float32', 'float64', 'bool']
-
-from ._elementwise_functions import abs, acos, acosh, add, asin, asinh, atan, atan2, atanh, bitwise_and, bitwise_left_shift, bitwise_invert, bitwise_or, bitwise_right_shift, bitwise_xor, ceil, cos, cosh, divide, equal, exp, expm1, floor, floor_divide, greater, greater_equal, isfinite, isinf, isnan, less, less_equal, log, log1p, log2, log10, logaddexp, logical_and, logical_not, logical_or, logical_xor, multiply, negative, not_equal, positive, pow, remainder, round, sign, sin, sinh, square, sqrt, subtract, tan, tanh, trunc
-
-__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logaddexp', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc']
+__all__ += ["e", "inf", "nan", "pi"]
+
+from ._creation_functions import (
+ asarray,
+ arange,
+ empty,
+ empty_like,
+ eye,
+ from_dlpack,
+ full,
+ full_like,
+ linspace,
+ meshgrid,
+ ones,
+ ones_like,
+ zeros,
+ zeros_like,
+)
+
+__all__ += [
+ "asarray",
+ "arange",
+ "empty",
+ "empty_like",
+ "eye",
+ "from_dlpack",
+ "full",
+ "full_like",
+ "linspace",
+ "meshgrid",
+ "ones",
+ "ones_like",
+ "zeros",
+ "zeros_like",
+]
+
+from ._data_type_functions import (
+ broadcast_arrays,
+ broadcast_to,
+ can_cast,
+ finfo,
+ iinfo,
+ result_type,
+)
+
+__all__ += [
+ "broadcast_arrays",
+ "broadcast_to",
+ "can_cast",
+ "finfo",
+ "iinfo",
+ "result_type",
+]
+
+from ._dtypes import (
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+ bool,
+)
+
+__all__ += [
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "uint8",
+ "uint16",
+ "uint32",
+ "uint64",
+ "float32",
+ "float64",
+ "bool",
+]
+
+from ._elementwise_functions import (
+ abs,
+ acos,
+ acosh,
+ add,
+ asin,
+ asinh,
+ atan,
+ atan2,
+ atanh,
+ bitwise_and,
+ bitwise_left_shift,
+ bitwise_invert,
+ bitwise_or,
+ bitwise_right_shift,
+ bitwise_xor,
+ ceil,
+ cos,
+ cosh,
+ divide,
+ equal,
+ exp,
+ expm1,
+ floor,
+ floor_divide,
+ greater,
+ greater_equal,
+ isfinite,
+ isinf,
+ isnan,
+ less,
+ less_equal,
+ log,
+ log1p,
+ log2,
+ log10,
+ logaddexp,
+ logical_and,
+ logical_not,
+ logical_or,
+ logical_xor,
+ multiply,
+ negative,
+ not_equal,
+ positive,
+ pow,
+ remainder,
+ round,
+ sign,
+ sin,
+ sinh,
+ square,
+ sqrt,
+ subtract,
+ tan,
+ tanh,
+ trunc,
+)
+
+__all__ += [
+ "abs",
+ "acos",
+ "acosh",
+ "add",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "bitwise_and",
+ "bitwise_left_shift",
+ "bitwise_invert",
+ "bitwise_or",
+ "bitwise_right_shift",
+ "bitwise_xor",
+ "ceil",
+ "cos",
+ "cosh",
+ "divide",
+ "equal",
+ "exp",
+ "expm1",
+ "floor",
+ "floor_divide",
+ "greater",
+ "greater_equal",
+ "isfinite",
+ "isinf",
+ "isnan",
+ "less",
+ "less_equal",
+ "log",
+ "log1p",
+ "log2",
+ "log10",
+ "logaddexp",
+ "logical_and",
+ "logical_not",
+ "logical_or",
+ "logical_xor",
+ "multiply",
+ "negative",
+ "not_equal",
+ "positive",
+ "pow",
+ "remainder",
+ "round",
+ "sign",
+ "sin",
+ "sinh",
+ "square",
+ "sqrt",
+ "subtract",
+ "tan",
+ "tanh",
+ "trunc",
+]
# einsum is not yet implemented in the array API spec.
@@ -157,28 +342,36 @@ __all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'at
from ._linear_algebra_functions import matmul, tensordot, transpose, vecdot
-__all__ += ['matmul', 'tensordot', 'transpose', 'vecdot']
+__all__ += ["matmul", "tensordot", "transpose", "vecdot"]
-from ._manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack
+from ._manipulation_functions import (
+ concat,
+ expand_dims,
+ flip,
+ reshape,
+ roll,
+ squeeze,
+ stack,
+)
-__all__ += ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack']
+__all__ += ["concat", "expand_dims", "flip", "reshape", "roll", "squeeze", "stack"]
from ._searching_functions import argmax, argmin, nonzero, where
-__all__ += ['argmax', 'argmin', 'nonzero', 'where']
+__all__ += ["argmax", "argmin", "nonzero", "where"]
from ._set_functions import unique
-__all__ += ['unique']
+__all__ += ["unique"]
from ._sorting_functions import argsort, sort
-__all__ += ['argsort', 'sort']
+__all__ += ["argsort", "sort"]
from ._statistical_functions import max, mean, min, prod, std, sum, var
-__all__ += ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var']
+__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]
from ._utility_functions import all, any
-__all__ += ['all', 'any']
+__all__ += ["all", "any"]
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index 00f50eade..0f511a577 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -18,11 +18,19 @@ from __future__ import annotations
import operator
from enum import IntEnum
from ._creation_functions import asarray
-from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes,
- _integer_or_boolean_dtypes, _floating_dtypes,
- _numeric_dtypes, _result_type, _dtype_categories)
+from ._dtypes import (
+ _all_dtypes,
+ _boolean_dtypes,
+ _integer_dtypes,
+ _integer_or_boolean_dtypes,
+ _floating_dtypes,
+ _numeric_dtypes,
+ _result_type,
+ _dtype_categories,
+)
from typing import TYPE_CHECKING, Optional, Tuple, Union
+
if TYPE_CHECKING:
from ._typing import PyCapsule, Device, Dtype
@@ -30,6 +38,7 @@ import numpy as np
from numpy import array_api
+
class Array:
"""
n-d array object for the array API namespace.
@@ -45,6 +54,7 @@ class Array:
functions, such as asarray().
"""
+
# Use a custom constructor instead of __init__, as manually initializing
# this class is not supported API.
@classmethod
@@ -64,13 +74,17 @@ class Array:
# Convert the array scalar to a 0-D array
x = np.asarray(x)
if x.dtype not in _all_dtypes:
- raise TypeError(f"The array_api namespace does not support the dtype '{x.dtype}'")
+ raise TypeError(
+ f"The array_api namespace does not support the dtype '{x.dtype}'"
+ )
obj._array = x
return obj
# Prevent Array() from working
def __new__(cls, *args, **kwargs):
- raise TypeError("The array_api Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead.")
+ raise TypeError(
+ "The array_api Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead."
+ )
# These functions are not required by the spec, but are implemented for
# the sake of usability.
@@ -79,7 +93,7 @@ class Array:
"""
Performs the operation __str__.
"""
- return self._array.__str__().replace('array', 'Array')
+ return self._array.__str__().replace("array", "Array")
def __repr__(self: Array, /) -> str:
"""
@@ -103,12 +117,12 @@ class Array:
"""
if self.dtype not in _dtype_categories[dtype_category]:
- raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
+ raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
elif isinstance(other, Array):
if other.dtype not in _dtype_categories[dtype_category]:
- raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
+ raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
else:
return NotImplemented
@@ -116,7 +130,7 @@ class Array:
# to promote in the spec (even if the NumPy array operator would
# promote them).
res_dtype = _result_type(self.dtype, other.dtype)
- if op.startswith('__i'):
+ if op.startswith("__i"):
# Note: NumPy will allow in-place operators in some cases where
# the type promoted operator does not match the left-hand side
# operand. For example,
@@ -126,7 +140,9 @@ class Array:
# The spec explicitly disallows this.
if res_dtype != self.dtype:
- raise TypeError(f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}")
+ raise TypeError(
+ f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}"
+ )
return other
@@ -142,13 +158,19 @@ class Array:
"""
if isinstance(scalar, bool):
if self.dtype not in _boolean_dtypes:
- raise TypeError("Python bool scalars can only be promoted with bool arrays")
+ raise TypeError(
+ "Python bool scalars can only be promoted with bool arrays"
+ )
elif isinstance(scalar, int):
if self.dtype in _boolean_dtypes:
- raise TypeError("Python int scalars cannot be promoted with bool arrays")
+ raise TypeError(
+ "Python int scalars cannot be promoted with bool arrays"
+ )
elif isinstance(scalar, float):
if self.dtype not in _floating_dtypes:
- raise TypeError("Python float scalars can only be promoted with floating-point arrays.")
+ raise TypeError(
+ "Python float scalars can only be promoted with floating-point arrays."
+ )
else:
raise TypeError("'scalar' must be a Python scalar")
@@ -253,7 +275,9 @@ class Array:
except TypeError:
return key
if not (-size <= key.start <= max(0, size - 1)):
- raise IndexError("Slices with out-of-bounds start are not allowed in the array API namespace")
+ raise IndexError(
+ "Slices with out-of-bounds start are not allowed in the array API namespace"
+ )
if key.stop is not None:
try:
operator.index(key.stop)
@@ -269,12 +293,20 @@ class Array:
key = tuple(Array._validate_index(idx, None) for idx in key)
for idx in key:
- if isinstance(idx, np.ndarray) and idx.dtype in _boolean_dtypes or isinstance(idx, (bool, np.bool_)):
+ if (
+ isinstance(idx, np.ndarray)
+ and idx.dtype in _boolean_dtypes
+ or isinstance(idx, (bool, np.bool_))
+ ):
if len(key) == 1:
return key
- raise IndexError("Boolean array indices combined with other indices are not allowed in the array API namespace")
+ raise IndexError(
+ "Boolean array indices combined with other indices are not allowed in the array API namespace"
+ )
if isinstance(idx, tuple):
- raise IndexError("Nested tuple indices are not allowed in the array API namespace")
+ raise IndexError(
+ "Nested tuple indices are not allowed in the array API namespace"
+ )
if shape is None:
return key
@@ -283,7 +315,9 @@ class Array:
return key
ellipsis_i = key.index(...) if n_ellipsis else len(key)
- for idx, size in list(zip(key[:ellipsis_i], shape)) + list(zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])):
+ for idx, size in list(zip(key[:ellipsis_i], shape)) + list(
+ zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])
+ ):
Array._validate_index(idx, (size,))
return key
elif isinstance(key, bool):
@@ -291,18 +325,24 @@ class Array:
elif isinstance(key, Array):
if key.dtype in _integer_dtypes:
if key.ndim != 0:
- raise IndexError("Non-zero dimensional integer array indices are not allowed in the array API namespace")
+ raise IndexError(
+ "Non-zero dimensional integer array indices are not allowed in the array API namespace"
+ )
return key._array
elif key is Ellipsis:
return key
elif key is None:
- raise IndexError("newaxis indices are not allowed in the array API namespace")
+ raise IndexError(
+ "newaxis indices are not allowed in the array API namespace"
+ )
try:
return operator.index(key)
except TypeError:
# Note: This also omits boolean arrays that are not already in
# Array() form, like a list of booleans.
- raise IndexError("Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace")
+ raise IndexError(
+ "Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace"
+ )
# Everything below this line is required by the spec.
@@ -311,7 +351,7 @@ class Array:
Performs the operation __abs__.
"""
if self.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in __abs__')
+ raise TypeError("Only numeric dtypes are allowed in __abs__")
res = self._array.__abs__()
return self.__class__._new(res)
@@ -319,7 +359,7 @@ class Array:
"""
Performs the operation __add__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__add__')
+ other = self._check_allowed_dtypes(other, "numeric", "__add__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -330,15 +370,17 @@ class Array:
"""
Performs the operation __and__.
"""
- other = self._check_allowed_dtypes(other, 'integer or boolean', '__and__')
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__and__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__and__(other._array)
return self.__class__._new(res)
- def __array_namespace__(self: Array, /, *, api_version: Optional[str] = None) -> object:
- if api_version is not None and not api_version.startswith('2021.'):
+ def __array_namespace__(
+ self: Array, /, *, api_version: Optional[str] = None
+ ) -> object:
+ if api_version is not None and not api_version.startswith("2021."):
raise ValueError(f"Unrecognized array API version: {api_version!r}")
return array_api
@@ -373,7 +415,7 @@ class Array:
"""
# Even though "all" dtypes are allowed, we still require them to be
# promotable with each other.
- other = self._check_allowed_dtypes(other, 'all', '__eq__')
+ other = self._check_allowed_dtypes(other, "all", "__eq__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -394,7 +436,7 @@ class Array:
"""
Performs the operation __floordiv__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__floordiv__')
+ other = self._check_allowed_dtypes(other, "numeric", "__floordiv__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -405,14 +447,20 @@ class Array:
"""
Performs the operation __ge__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__ge__')
+ other = self._check_allowed_dtypes(other, "numeric", "__ge__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__ge__(other._array)
return self.__class__._new(res)
- def __getitem__(self: Array, key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], /) -> Array:
+ def __getitem__(
+ self: Array,
+ key: Union[
+ int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array
+ ],
+ /,
+ ) -> Array:
"""
Performs the operation __getitem__.
"""
@@ -426,7 +474,7 @@ class Array:
"""
Performs the operation __gt__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__gt__')
+ other = self._check_allowed_dtypes(other, "numeric", "__gt__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -448,7 +496,7 @@ class Array:
Performs the operation __invert__.
"""
if self.dtype not in _integer_or_boolean_dtypes:
- raise TypeError('Only integer or boolean dtypes are allowed in __invert__')
+ raise TypeError("Only integer or boolean dtypes are allowed in __invert__")
res = self._array.__invert__()
return self.__class__._new(res)
@@ -456,7 +504,7 @@ class Array:
"""
Performs the operation __le__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__le__')
+ other = self._check_allowed_dtypes(other, "numeric", "__le__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -474,7 +522,7 @@ class Array:
"""
Performs the operation __lshift__.
"""
- other = self._check_allowed_dtypes(other, 'integer', '__lshift__')
+ other = self._check_allowed_dtypes(other, "integer", "__lshift__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -485,7 +533,7 @@ class Array:
"""
Performs the operation __lt__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__lt__')
+ other = self._check_allowed_dtypes(other, "numeric", "__lt__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -498,7 +546,7 @@ class Array:
"""
# matmul is not defined for scalars, but without this, we may get
# the wrong error message from asarray.
- other = self._check_allowed_dtypes(other, 'numeric', '__matmul__')
+ other = self._check_allowed_dtypes(other, "numeric", "__matmul__")
if other is NotImplemented:
return other
res = self._array.__matmul__(other._array)
@@ -508,7 +556,7 @@ class Array:
"""
Performs the operation __mod__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__mod__')
+ other = self._check_allowed_dtypes(other, "numeric", "__mod__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -519,7 +567,7 @@ class Array:
"""
Performs the operation __mul__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__mul__')
+ other = self._check_allowed_dtypes(other, "numeric", "__mul__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -530,7 +578,7 @@ class Array:
"""
Performs the operation __ne__.
"""
- other = self._check_allowed_dtypes(other, 'all', '__ne__')
+ other = self._check_allowed_dtypes(other, "all", "__ne__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -542,7 +590,7 @@ class Array:
Performs the operation __neg__.
"""
if self.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in __neg__')
+ raise TypeError("Only numeric dtypes are allowed in __neg__")
res = self._array.__neg__()
return self.__class__._new(res)
@@ -550,7 +598,7 @@ class Array:
"""
Performs the operation __or__.
"""
- other = self._check_allowed_dtypes(other, 'integer or boolean', '__or__')
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__or__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -562,7 +610,7 @@ class Array:
Performs the operation __pos__.
"""
if self.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in __pos__')
+ raise TypeError("Only numeric dtypes are allowed in __pos__")
res = self._array.__pos__()
return self.__class__._new(res)
@@ -574,7 +622,7 @@ class Array:
"""
from ._elementwise_functions import pow
- other = self._check_allowed_dtypes(other, 'floating-point', '__pow__')
+ other = self._check_allowed_dtypes(other, "floating-point", "__pow__")
if other is NotImplemented:
return other
# Note: NumPy's __pow__ does not follow type promotion rules for 0-d
@@ -585,14 +633,21 @@ class Array:
"""
Performs the operation __rshift__.
"""
- other = self._check_allowed_dtypes(other, 'integer', '__rshift__')
+ other = self._check_allowed_dtypes(other, "integer", "__rshift__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__rshift__(other._array)
return self.__class__._new(res)
- def __setitem__(self, key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], value: Union[int, float, bool, Array], /) -> Array:
+ def __setitem__(
+ self,
+ key: Union[
+ int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array
+ ],
+ value: Union[int, float, bool, Array],
+ /,
+ ) -> Array:
"""
Performs the operation __setitem__.
"""
@@ -605,7 +660,7 @@ class Array:
"""
Performs the operation __sub__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__sub__')
+ other = self._check_allowed_dtypes(other, "numeric", "__sub__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -618,7 +673,7 @@ class Array:
"""
Performs the operation __truediv__.
"""
- other = self._check_allowed_dtypes(other, 'floating-point', '__truediv__')
+ other = self._check_allowed_dtypes(other, "floating-point", "__truediv__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -629,7 +684,7 @@ class Array:
"""
Performs the operation __xor__.
"""
- other = self._check_allowed_dtypes(other, 'integer or boolean', '__xor__')
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -640,7 +695,7 @@ class Array:
"""
Performs the operation __iadd__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__iadd__')
+ other = self._check_allowed_dtypes(other, "numeric", "__iadd__")
if other is NotImplemented:
return other
self._array.__iadd__(other._array)
@@ -650,7 +705,7 @@ class Array:
"""
Performs the operation __radd__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__radd__')
+ other = self._check_allowed_dtypes(other, "numeric", "__radd__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -661,7 +716,7 @@ class Array:
"""
Performs the operation __iand__.
"""
- other = self._check_allowed_dtypes(other, 'integer or boolean', '__iand__')
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__")
if other is NotImplemented:
return other
self._array.__iand__(other._array)
@@ -671,7 +726,7 @@ class Array:
"""
Performs the operation __rand__.
"""
- other = self._check_allowed_dtypes(other, 'integer or boolean', '__rand__')
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -682,7 +737,7 @@ class Array:
"""
Performs the operation __ifloordiv__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__ifloordiv__')
+ other = self._check_allowed_dtypes(other, "numeric", "__ifloordiv__")
if other is NotImplemented:
return other
self._array.__ifloordiv__(other._array)
@@ -692,7 +747,7 @@ class Array:
"""
Performs the operation __rfloordiv__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__rfloordiv__')
+ other = self._check_allowed_dtypes(other, "numeric", "__rfloordiv__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -703,7 +758,7 @@ class Array:
"""
Performs the operation __ilshift__.
"""
- other = self._check_allowed_dtypes(other, 'integer', '__ilshift__')
+ other = self._check_allowed_dtypes(other, "integer", "__ilshift__")
if other is NotImplemented:
return other
self._array.__ilshift__(other._array)
@@ -713,7 +768,7 @@ class Array:
"""
Performs the operation __rlshift__.
"""
- other = self._check_allowed_dtypes(other, 'integer', '__rlshift__')
+ other = self._check_allowed_dtypes(other, "integer", "__rlshift__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -728,7 +783,7 @@ class Array:
# matmul is not defined for scalars, but without this, we may get
# the wrong error message from asarray.
- other = self._check_allowed_dtypes(other, 'numeric', '__imatmul__')
+ other = self._check_allowed_dtypes(other, "numeric", "__imatmul__")
if other is NotImplemented:
return other
@@ -748,7 +803,7 @@ class Array:
"""
# matmul is not defined for scalars, but without this, we may get
# the wrong error message from asarray.
- other = self._check_allowed_dtypes(other, 'numeric', '__rmatmul__')
+ other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__")
if other is NotImplemented:
return other
res = self._array.__rmatmul__(other._array)
@@ -758,7 +813,7 @@ class Array:
"""
Performs the operation __imod__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__imod__')
+ other = self._check_allowed_dtypes(other, "numeric", "__imod__")
if other is NotImplemented:
return other
self._array.__imod__(other._array)
@@ -768,7 +823,7 @@ class Array:
"""
Performs the operation __rmod__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__rmod__')
+ other = self._check_allowed_dtypes(other, "numeric", "__rmod__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -779,7 +834,7 @@ class Array:
"""
Performs the operation __imul__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__imul__')
+ other = self._check_allowed_dtypes(other, "numeric", "__imul__")
if other is NotImplemented:
return other
self._array.__imul__(other._array)
@@ -789,7 +844,7 @@ class Array:
"""
Performs the operation __rmul__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__rmul__')
+ other = self._check_allowed_dtypes(other, "numeric", "__rmul__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -800,7 +855,7 @@ class Array:
"""
Performs the operation __ior__.
"""
- other = self._check_allowed_dtypes(other, 'integer or boolean', '__ior__')
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__")
if other is NotImplemented:
return other
self._array.__ior__(other._array)
@@ -810,7 +865,7 @@ class Array:
"""
Performs the operation __ror__.
"""
- other = self._check_allowed_dtypes(other, 'integer or boolean', '__ror__')
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -821,7 +876,7 @@ class Array:
"""
Performs the operation __ipow__.
"""
- other = self._check_allowed_dtypes(other, 'floating-point', '__ipow__')
+ other = self._check_allowed_dtypes(other, "floating-point", "__ipow__")
if other is NotImplemented:
return other
self._array.__ipow__(other._array)
@@ -833,7 +888,7 @@ class Array:
"""
from ._elementwise_functions import pow
- other = self._check_allowed_dtypes(other, 'floating-point', '__rpow__')
+ other = self._check_allowed_dtypes(other, "floating-point", "__rpow__")
if other is NotImplemented:
return other
# Note: NumPy's __pow__ does not follow the spec type promotion rules
@@ -844,7 +899,7 @@ class Array:
"""
Performs the operation __irshift__.
"""
- other = self._check_allowed_dtypes(other, 'integer', '__irshift__')
+ other = self._check_allowed_dtypes(other, "integer", "__irshift__")
if other is NotImplemented:
return other
self._array.__irshift__(other._array)
@@ -854,7 +909,7 @@ class Array:
"""
Performs the operation __rrshift__.
"""
- other = self._check_allowed_dtypes(other, 'integer', '__rrshift__')
+ other = self._check_allowed_dtypes(other, "integer", "__rrshift__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -865,7 +920,7 @@ class Array:
"""
Performs the operation __isub__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__isub__')
+ other = self._check_allowed_dtypes(other, "numeric", "__isub__")
if other is NotImplemented:
return other
self._array.__isub__(other._array)
@@ -875,7 +930,7 @@ class Array:
"""
Performs the operation __rsub__.
"""
- other = self._check_allowed_dtypes(other, 'numeric', '__rsub__')
+ other = self._check_allowed_dtypes(other, "numeric", "__rsub__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -886,7 +941,7 @@ class Array:
"""
Performs the operation __itruediv__.
"""
- other = self._check_allowed_dtypes(other, 'floating-point', '__itruediv__')
+ other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__")
if other is NotImplemented:
return other
self._array.__itruediv__(other._array)
@@ -896,7 +951,7 @@ class Array:
"""
Performs the operation __rtruediv__.
"""
- other = self._check_allowed_dtypes(other, 'floating-point', '__rtruediv__')
+ other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -907,7 +962,7 @@ class Array:
"""
Performs the operation __ixor__.
"""
- other = self._check_allowed_dtypes(other, 'integer or boolean', '__ixor__')
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__")
if other is NotImplemented:
return other
self._array.__ixor__(other._array)
@@ -917,7 +972,7 @@ class Array:
"""
Performs the operation __rxor__.
"""
- other = self._check_allowed_dtypes(other, 'integer or boolean', '__rxor__')
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
@@ -935,7 +990,7 @@ class Array:
@property
def device(self) -> Device:
- return 'cpu'
+ return "cpu"
@property
def ndim(self) -> int:
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py
index acf78056a..e9c01e7e6 100644
--- a/numpy/array_api/_creation_functions.py
+++ b/numpy/array_api/_creation_functions.py
@@ -2,14 +2,22 @@ from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
if TYPE_CHECKING:
- from ._typing import (Array, Device, Dtype, NestedSequence,
- SupportsDLPack, SupportsBufferProtocol)
+ from ._typing import (
+ Array,
+ Device,
+ Dtype,
+ NestedSequence,
+ SupportsDLPack,
+ SupportsBufferProtocol,
+ )
from collections.abc import Sequence
from ._dtypes import _all_dtypes
import numpy as np
+
def _check_valid_dtype(dtype):
# Note: Only spelling dtypes as the dtype objects is supported.
@@ -20,7 +28,23 @@ def _check_valid_dtype(dtype):
return
raise ValueError("dtype must be one of the supported dtypes")
-def asarray(obj: Union[Array, bool, int, float, NestedSequence[bool|int|float], SupportsDLPack, SupportsBufferProtocol], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: Optional[bool] = None) -> Array:
+
+def asarray(
+ obj: Union[
+ Array,
+ bool,
+ int,
+ float,
+ NestedSequence[bool | int | float],
+ SupportsDLPack,
+ SupportsBufferProtocol,
+ ],
+ /,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ copy: Optional[bool] = None,
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
@@ -31,7 +55,7 @@ def asarray(obj: Union[Array, bool, int, float, NestedSequence[bool|int|float],
from ._array_object import Array
_check_valid_dtype(dtype)
- if device not in ['cpu', None]:
+ if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
if copy is False:
# Note: copy=False is not yet implemented in np.asarray
@@ -40,14 +64,23 @@ def asarray(obj: Union[Array, bool, int, float, NestedSequence[bool|int|float],
if copy is True:
return Array._new(np.array(obj._array, copy=True, dtype=dtype))
return obj
- if dtype is None and isinstance(obj, int) and (obj > 2**64 or obj < -2**63):
+ if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
# Give a better error message in this case. NumPy would convert this
# to an object array. TODO: This won't handle large integers in lists.
raise OverflowError("Integer out of bounds for array dtypes")
res = np.asarray(obj, dtype=dtype)
return Array._new(res)
-def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
+
+def arange(
+ start: Union[int, float],
+ /,
+ stop: Optional[Union[int, float]] = None,
+ step: Union[int, float] = 1,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.arange <numpy.arange>`.
@@ -56,11 +89,17 @@ def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None
from ._array_object import Array
_check_valid_dtype(dtype)
- if device not in ['cpu', None]:
+ if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype))
-def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
+
+def empty(
+ shape: Union[int, Tuple[int, ...]],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.empty <numpy.empty>`.
@@ -69,11 +108,14 @@ def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None,
from ._array_object import Array
_check_valid_dtype(dtype)
- if device not in ['cpu', None]:
+ if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.empty(shape, dtype=dtype))
-def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
+
+def empty_like(
+ x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.empty_like <numpy.empty_like>`.
@@ -82,11 +124,20 @@ def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[D
from ._array_object import Array
_check_valid_dtype(dtype)
- if device not in ['cpu', None]:
+ if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.empty_like(x._array, dtype=dtype))
-def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: Optional[int] = 0, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
+
+def eye(
+ n_rows: int,
+ n_cols: Optional[int] = None,
+ /,
+ *,
+ k: Optional[int] = 0,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.eye <numpy.eye>`.
@@ -95,15 +146,23 @@ def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: Optional[int] = 0, d
from ._array_object import Array
_check_valid_dtype(dtype)
- if device not in ['cpu', None]:
+ if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
+
def from_dlpack(x: object, /) -> Array:
# Note: dlpack support is not yet implemented on Array
raise NotImplementedError("DLPack support is not yet implemented")
-def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
+
+def full(
+ shape: Union[int, Tuple[int, ...]],
+ fill_value: Union[int, float],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.full <numpy.full>`.
@@ -112,7 +171,7 @@ def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, d
from ._array_object import Array
_check_valid_dtype(dtype)
- if device not in ['cpu', None]:
+ if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
if isinstance(fill_value, Array) and fill_value.ndim == 0:
fill_value = fill_value._array
@@ -123,7 +182,15 @@ def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, d
raise TypeError("Invalid input to full")
return Array._new(res)
-def full_like(x: Array, /, fill_value: Union[int, float], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
+
+def full_like(
+ x: Array,
+ /,
+ fill_value: Union[int, float],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.full_like <numpy.full_like>`.
@@ -132,7 +199,7 @@ def full_like(x: Array, /, fill_value: Union[int, float], *, dtype: Optional[Dty
from ._array_object import Array
_check_valid_dtype(dtype)
- if device not in ['cpu', None]:
+ if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
res = np.full_like(x._array, fill_value, dtype=dtype)
if res.dtype not in _all_dtypes:
@@ -141,7 +208,17 @@ def full_like(x: Array, /, fill_value: Union[int, float], *, dtype: Optional[Dty
raise TypeError("Invalid input to full_like")
return Array._new(res)
-def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, endpoint: bool = True) -> Array:
+
+def linspace(
+ start: Union[int, float],
+ stop: Union[int, float],
+ /,
+ num: int,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ endpoint: bool = True,
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.linspace <numpy.linspace>`.
@@ -150,20 +227,31 @@ def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *,
from ._array_object import Array
_check_valid_dtype(dtype)
- if device not in ['cpu', None]:
+ if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
-def meshgrid(*arrays: Sequence[Array], indexing: str = 'xy') -> List[Array, ...]:
+
+def meshgrid(*arrays: Sequence[Array], indexing: str = "xy") -> List[Array, ...]:
"""
Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.
See its docstring for more information.
"""
from ._array_object import Array
- return [Array._new(array) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)]
-def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
+ return [
+ Array._new(array)
+ for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
+ ]
+
+
+def ones(
+ shape: Union[int, Tuple[int, ...]],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.ones <numpy.ones>`.
@@ -172,11 +260,14 @@ def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, d
from ._array_object import Array
_check_valid_dtype(dtype)
- if device not in ['cpu', None]:
+ if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.ones(shape, dtype=dtype))
-def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
+
+def ones_like(
+ x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.ones_like <numpy.ones_like>`.
@@ -185,11 +276,17 @@ def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[De
from ._array_object import Array
_check_valid_dtype(dtype)
- if device not in ['cpu', None]:
+ if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.ones_like(x._array, dtype=dtype))
-def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
+
+def zeros(
+ shape: Union[int, Tuple[int, ...]],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.zeros <numpy.zeros>`.
@@ -198,11 +295,14 @@ def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None,
from ._array_object import Array
_check_valid_dtype(dtype)
- if device not in ['cpu', None]:
+ if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.zeros(shape, dtype=dtype))
-def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
+
+def zeros_like(
+ x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.zeros_like <numpy.zeros_like>`.
@@ -211,6 +311,6 @@ def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[D
from ._array_object import Array
_check_valid_dtype(dtype)
- if device not in ['cpu', None]:
+ if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.zeros_like(x._array, dtype=dtype))
diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py
index 17a00cc6d..e6121a8a4 100644
--- a/numpy/array_api/_data_type_functions.py
+++ b/numpy/array_api/_data_type_functions.py
@@ -5,12 +5,14 @@ from ._dtypes import _all_dtypes, _result_type
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Tuple, Union
+
if TYPE_CHECKING:
from ._typing import Dtype
from collections.abc import Sequence
import numpy as np
+
def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]:
"""
Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.
@@ -18,7 +20,11 @@ def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]:
See its docstring for more information.
"""
from ._array_object import Array
- return [Array._new(array) for array in np.broadcast_arrays(*[a._array for a in arrays])]
+
+ return [
+ Array._new(array) for array in np.broadcast_arrays(*[a._array for a in arrays])
+ ]
+
def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array:
"""
@@ -27,8 +33,10 @@ def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array:
See its docstring for more information.
"""
from ._array_object import Array
+
return Array._new(np.broadcast_to(x._array, shape))
+
def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool:
"""
Array API compatible wrapper for :py:func:`np.can_cast <numpy.can_cast>`.
@@ -36,10 +44,12 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool:
See its docstring for more information.
"""
from ._array_object import Array
+
if isinstance(from_, Array):
from_ = from_._array
return np.can_cast(from_, to)
+
# These are internal objects for the return types of finfo and iinfo, since
# the NumPy versions contain extra data that isn't part of the spec.
@dataclass
@@ -55,12 +65,14 @@ class finfo_object:
# smallest_normal: float
+
@dataclass
class iinfo_object:
bits: int
max: int
min: int
+
def finfo(type: Union[Dtype, Array], /) -> finfo_object:
"""
Array API compatible wrapper for :py:func:`np.finfo <numpy.finfo>`.
@@ -79,6 +91,7 @@ def finfo(type: Union[Dtype, Array], /) -> finfo_object:
# float(fi.smallest_normal),
)
+
def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
"""
Array API compatible wrapper for :py:func:`np.iinfo <numpy.iinfo>`.
@@ -88,6 +101,7 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
ii = np.iinfo(type)
return iinfo_object(ii.bits, ii.max, ii.min)
+
def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype:
"""
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
diff --git a/numpy/array_api/_dtypes.py b/numpy/array_api/_dtypes.py
index 07be267da..476d619fe 100644
--- a/numpy/array_api/_dtypes.py
+++ b/numpy/array_api/_dtypes.py
@@ -2,34 +2,66 @@ import numpy as np
# Note: we use dtype objects instead of dtype classes. The spec does not
# require any behavior on dtypes other than equality.
-int8 = np.dtype('int8')
-int16 = np.dtype('int16')
-int32 = np.dtype('int32')
-int64 = np.dtype('int64')
-uint8 = np.dtype('uint8')
-uint16 = np.dtype('uint16')
-uint32 = np.dtype('uint32')
-uint64 = np.dtype('uint64')
-float32 = np.dtype('float32')
-float64 = np.dtype('float64')
+int8 = np.dtype("int8")
+int16 = np.dtype("int16")
+int32 = np.dtype("int32")
+int64 = np.dtype("int64")
+uint8 = np.dtype("uint8")
+uint16 = np.dtype("uint16")
+uint32 = np.dtype("uint32")
+uint64 = np.dtype("uint64")
+float32 = np.dtype("float32")
+float64 = np.dtype("float64")
# Note: This name is changed
-bool = np.dtype('bool')
+bool = np.dtype("bool")
-_all_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64,
- float32, float64, bool)
+_all_dtypes = (
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+ bool,
+)
_boolean_dtypes = (bool,)
_floating_dtypes = (float32, float64)
_integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64)
-_integer_or_boolean_dtypes = (bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64)
-_numeric_dtypes = (float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64)
+_integer_or_boolean_dtypes = (
+ bool,
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+)
+_numeric_dtypes = (
+ float32,
+ float64,
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+)
_dtype_categories = {
- 'all': _all_dtypes,
- 'numeric': _numeric_dtypes,
- 'integer': _integer_dtypes,
- 'integer or boolean': _integer_or_boolean_dtypes,
- 'boolean': _boolean_dtypes,
- 'floating-point': _floating_dtypes,
+ "all": _all_dtypes,
+ "numeric": _numeric_dtypes,
+ "integer": _integer_dtypes,
+ "integer or boolean": _integer_or_boolean_dtypes,
+ "boolean": _boolean_dtypes,
+ "floating-point": _floating_dtypes,
}
@@ -104,6 +136,7 @@ _promotion_table = {
(bool, bool): bool,
}
+
def _result_type(type1, type2):
if (type1, type2) in _promotion_table:
return _promotion_table[type1, type2]
diff --git a/numpy/array_api/_elementwise_functions.py b/numpy/array_api/_elementwise_functions.py
index 7833ebe54..4408fe833 100644
--- a/numpy/array_api/_elementwise_functions.py
+++ b/numpy/array_api/_elementwise_functions.py
@@ -1,12 +1,18 @@
from __future__ import annotations
-from ._dtypes import (_boolean_dtypes, _floating_dtypes,
- _integer_dtypes, _integer_or_boolean_dtypes,
- _numeric_dtypes, _result_type)
+from ._dtypes import (
+ _boolean_dtypes,
+ _floating_dtypes,
+ _integer_dtypes,
+ _integer_or_boolean_dtypes,
+ _numeric_dtypes,
+ _result_type,
+)
from ._array_object import Array
import numpy as np
+
def abs(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.abs <numpy.abs>`.
@@ -14,9 +20,10 @@ def abs(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in abs')
+ raise TypeError("Only numeric dtypes are allowed in abs")
return Array._new(np.abs(x._array))
+
# Note: the function name is different here
def acos(x: Array, /) -> Array:
"""
@@ -25,9 +32,10 @@ def acos(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in acos')
+ raise TypeError("Only floating-point dtypes are allowed in acos")
return Array._new(np.arccos(x._array))
+
# Note: the function name is different here
def acosh(x: Array, /) -> Array:
"""
@@ -36,9 +44,10 @@ def acosh(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in acosh')
+ raise TypeError("Only floating-point dtypes are allowed in acosh")
return Array._new(np.arccosh(x._array))
+
def add(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.add <numpy.add>`.
@@ -46,12 +55,13 @@ def add(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in add')
+ raise TypeError("Only numeric dtypes are allowed in add")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.add(x1._array, x2._array))
+
# Note: the function name is different here
def asin(x: Array, /) -> Array:
"""
@@ -60,9 +70,10 @@ def asin(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in asin')
+ raise TypeError("Only floating-point dtypes are allowed in asin")
return Array._new(np.arcsin(x._array))
+
# Note: the function name is different here
def asinh(x: Array, /) -> Array:
"""
@@ -71,9 +82,10 @@ def asinh(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in asinh')
+ raise TypeError("Only floating-point dtypes are allowed in asinh")
return Array._new(np.arcsinh(x._array))
+
# Note: the function name is different here
def atan(x: Array, /) -> Array:
"""
@@ -82,9 +94,10 @@ def atan(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in atan')
+ raise TypeError("Only floating-point dtypes are allowed in atan")
return Array._new(np.arctan(x._array))
+
# Note: the function name is different here
def atan2(x1: Array, x2: Array, /) -> Array:
"""
@@ -93,12 +106,13 @@ def atan2(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in atan2')
+ raise TypeError("Only floating-point dtypes are allowed in atan2")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.arctan2(x1._array, x2._array))
+
# Note: the function name is different here
def atanh(x: Array, /) -> Array:
"""
@@ -107,22 +121,27 @@ def atanh(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in atanh')
+ raise TypeError("Only floating-point dtypes are allowed in atanh")
return Array._new(np.arctanh(x._array))
+
def bitwise_and(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.bitwise_and <numpy.bitwise_and>`.
See its docstring for more information.
"""
- if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
- raise TypeError('Only integer or boolean dtypes are allowed in bitwise_and')
+ if (
+ x1.dtype not in _integer_or_boolean_dtypes
+ or x2.dtype not in _integer_or_boolean_dtypes
+ ):
+ raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.bitwise_and(x1._array, x2._array))
+
# Note: the function name is different here
def bitwise_left_shift(x1: Array, x2: Array, /) -> Array:
"""
@@ -131,15 +150,16 @@ def bitwise_left_shift(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
- raise TypeError('Only integer dtypes are allowed in bitwise_left_shift')
+ raise TypeError("Only integer dtypes are allowed in bitwise_left_shift")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
# Note: bitwise_left_shift is only defined for x2 nonnegative.
if np.any(x2._array < 0):
- raise ValueError('bitwise_left_shift(x1, x2) is only defined for x2 >= 0')
+ raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0")
return Array._new(np.left_shift(x1._array, x2._array))
+
# Note: the function name is different here
def bitwise_invert(x: Array, /) -> Array:
"""
@@ -148,22 +168,27 @@ def bitwise_invert(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _integer_or_boolean_dtypes:
- raise TypeError('Only integer or boolean dtypes are allowed in bitwise_invert')
+ raise TypeError("Only integer or boolean dtypes are allowed in bitwise_invert")
return Array._new(np.invert(x._array))
+
def bitwise_or(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.bitwise_or <numpy.bitwise_or>`.
See its docstring for more information.
"""
- if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
- raise TypeError('Only integer or boolean dtypes are allowed in bitwise_or')
+ if (
+ x1.dtype not in _integer_or_boolean_dtypes
+ or x2.dtype not in _integer_or_boolean_dtypes
+ ):
+ raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.bitwise_or(x1._array, x2._array))
+
# Note: the function name is different here
def bitwise_right_shift(x1: Array, x2: Array, /) -> Array:
"""
@@ -172,28 +197,33 @@ def bitwise_right_shift(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
- raise TypeError('Only integer dtypes are allowed in bitwise_right_shift')
+ raise TypeError("Only integer dtypes are allowed in bitwise_right_shift")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
# Note: bitwise_right_shift is only defined for x2 nonnegative.
if np.any(x2._array < 0):
- raise ValueError('bitwise_right_shift(x1, x2) is only defined for x2 >= 0')
+ raise ValueError("bitwise_right_shift(x1, x2) is only defined for x2 >= 0")
return Array._new(np.right_shift(x1._array, x2._array))
+
def bitwise_xor(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.bitwise_xor <numpy.bitwise_xor>`.
See its docstring for more information.
"""
- if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
- raise TypeError('Only integer or boolean dtypes are allowed in bitwise_xor')
+ if (
+ x1.dtype not in _integer_or_boolean_dtypes
+ or x2.dtype not in _integer_or_boolean_dtypes
+ ):
+ raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.bitwise_xor(x1._array, x2._array))
+
def ceil(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.ceil <numpy.ceil>`.
@@ -201,12 +231,13 @@ def ceil(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in ceil')
+ raise TypeError("Only numeric dtypes are allowed in ceil")
if x.dtype in _integer_dtypes:
# Note: The return dtype of ceil is the same as the input
return x
return Array._new(np.ceil(x._array))
+
def cos(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.cos <numpy.cos>`.
@@ -214,9 +245,10 @@ def cos(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in cos')
+ raise TypeError("Only floating-point dtypes are allowed in cos")
return Array._new(np.cos(x._array))
+
def cosh(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.cosh <numpy.cosh>`.
@@ -224,9 +256,10 @@ def cosh(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in cosh')
+ raise TypeError("Only floating-point dtypes are allowed in cosh")
return Array._new(np.cosh(x._array))
+
def divide(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.divide <numpy.divide>`.
@@ -234,12 +267,13 @@ def divide(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in divide')
+ raise TypeError("Only floating-point dtypes are allowed in divide")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.divide(x1._array, x2._array))
+
def equal(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.equal <numpy.equal>`.
@@ -251,6 +285,7 @@ def equal(x1: Array, x2: Array, /) -> Array:
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.equal(x1._array, x2._array))
+
def exp(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.exp <numpy.exp>`.
@@ -258,9 +293,10 @@ def exp(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in exp')
+ raise TypeError("Only floating-point dtypes are allowed in exp")
return Array._new(np.exp(x._array))
+
def expm1(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.expm1 <numpy.expm1>`.
@@ -268,9 +304,10 @@ def expm1(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in expm1')
+ raise TypeError("Only floating-point dtypes are allowed in expm1")
return Array._new(np.expm1(x._array))
+
def floor(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.floor <numpy.floor>`.
@@ -278,12 +315,13 @@ def floor(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in floor')
+ raise TypeError("Only numeric dtypes are allowed in floor")
if x.dtype in _integer_dtypes:
# Note: The return dtype of floor is the same as the input
return x
return Array._new(np.floor(x._array))
+
def floor_divide(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.floor_divide <numpy.floor_divide>`.
@@ -291,12 +329,13 @@ def floor_divide(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in floor_divide')
+ raise TypeError("Only numeric dtypes are allowed in floor_divide")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.floor_divide(x1._array, x2._array))
+
def greater(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.greater <numpy.greater>`.
@@ -304,12 +343,13 @@ def greater(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in greater')
+ raise TypeError("Only numeric dtypes are allowed in greater")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.greater(x1._array, x2._array))
+
def greater_equal(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.greater_equal <numpy.greater_equal>`.
@@ -317,12 +357,13 @@ def greater_equal(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in greater_equal')
+ raise TypeError("Only numeric dtypes are allowed in greater_equal")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.greater_equal(x1._array, x2._array))
+
def isfinite(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.isfinite <numpy.isfinite>`.
@@ -330,9 +371,10 @@ def isfinite(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in isfinite')
+ raise TypeError("Only numeric dtypes are allowed in isfinite")
return Array._new(np.isfinite(x._array))
+
def isinf(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.isinf <numpy.isinf>`.
@@ -340,9 +382,10 @@ def isinf(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in isinf')
+ raise TypeError("Only numeric dtypes are allowed in isinf")
return Array._new(np.isinf(x._array))
+
def isnan(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.isnan <numpy.isnan>`.
@@ -350,9 +393,10 @@ def isnan(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in isnan')
+ raise TypeError("Only numeric dtypes are allowed in isnan")
return Array._new(np.isnan(x._array))
+
def less(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.less <numpy.less>`.
@@ -360,12 +404,13 @@ def less(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in less')
+ raise TypeError("Only numeric dtypes are allowed in less")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.less(x1._array, x2._array))
+
def less_equal(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.less_equal <numpy.less_equal>`.
@@ -373,12 +418,13 @@ def less_equal(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in less_equal')
+ raise TypeError("Only numeric dtypes are allowed in less_equal")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.less_equal(x1._array, x2._array))
+
def log(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.log <numpy.log>`.
@@ -386,9 +432,10 @@ def log(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in log')
+ raise TypeError("Only floating-point dtypes are allowed in log")
return Array._new(np.log(x._array))
+
def log1p(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.log1p <numpy.log1p>`.
@@ -396,9 +443,10 @@ def log1p(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in log1p')
+ raise TypeError("Only floating-point dtypes are allowed in log1p")
return Array._new(np.log1p(x._array))
+
def log2(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.log2 <numpy.log2>`.
@@ -406,9 +454,10 @@ def log2(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in log2')
+ raise TypeError("Only floating-point dtypes are allowed in log2")
return Array._new(np.log2(x._array))
+
def log10(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.log10 <numpy.log10>`.
@@ -416,9 +465,10 @@ def log10(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in log10')
+ raise TypeError("Only floating-point dtypes are allowed in log10")
return Array._new(np.log10(x._array))
+
def logaddexp(x1: Array, x2: Array) -> Array:
"""
Array API compatible wrapper for :py:func:`np.logaddexp <numpy.logaddexp>`.
@@ -426,12 +476,13 @@ def logaddexp(x1: Array, x2: Array) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in logaddexp')
+ raise TypeError("Only floating-point dtypes are allowed in logaddexp")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.logaddexp(x1._array, x2._array))
+
def logical_and(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.logical_and <numpy.logical_and>`.
@@ -439,12 +490,13 @@ def logical_and(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
- raise TypeError('Only boolean dtypes are allowed in logical_and')
+ raise TypeError("Only boolean dtypes are allowed in logical_and")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.logical_and(x1._array, x2._array))
+
def logical_not(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.logical_not <numpy.logical_not>`.
@@ -452,9 +504,10 @@ def logical_not(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _boolean_dtypes:
- raise TypeError('Only boolean dtypes are allowed in logical_not')
+ raise TypeError("Only boolean dtypes are allowed in logical_not")
return Array._new(np.logical_not(x._array))
+
def logical_or(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.logical_or <numpy.logical_or>`.
@@ -462,12 +515,13 @@ def logical_or(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
- raise TypeError('Only boolean dtypes are allowed in logical_or')
+ raise TypeError("Only boolean dtypes are allowed in logical_or")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.logical_or(x1._array, x2._array))
+
def logical_xor(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.logical_xor <numpy.logical_xor>`.
@@ -475,12 +529,13 @@ def logical_xor(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
- raise TypeError('Only boolean dtypes are allowed in logical_xor')
+ raise TypeError("Only boolean dtypes are allowed in logical_xor")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.logical_xor(x1._array, x2._array))
+
def multiply(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.multiply <numpy.multiply>`.
@@ -488,12 +543,13 @@ def multiply(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in multiply')
+ raise TypeError("Only numeric dtypes are allowed in multiply")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.multiply(x1._array, x2._array))
+
def negative(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.negative <numpy.negative>`.
@@ -501,9 +557,10 @@ def negative(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in negative')
+ raise TypeError("Only numeric dtypes are allowed in negative")
return Array._new(np.negative(x._array))
+
def not_equal(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`.
@@ -515,6 +572,7 @@ def not_equal(x1: Array, x2: Array, /) -> Array:
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.not_equal(x1._array, x2._array))
+
def positive(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.positive <numpy.positive>`.
@@ -522,9 +580,10 @@ def positive(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in positive')
+ raise TypeError("Only numeric dtypes are allowed in positive")
return Array._new(np.positive(x._array))
+
# Note: the function name is different here
def pow(x1: Array, x2: Array, /) -> Array:
"""
@@ -533,12 +592,13 @@ def pow(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in pow')
+ raise TypeError("Only floating-point dtypes are allowed in pow")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.power(x1._array, x2._array))
+
def remainder(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`.
@@ -546,12 +606,13 @@ def remainder(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in remainder')
+ raise TypeError("Only numeric dtypes are allowed in remainder")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.remainder(x1._array, x2._array))
+
def round(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.round <numpy.round>`.
@@ -559,9 +620,10 @@ def round(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in round')
+ raise TypeError("Only numeric dtypes are allowed in round")
return Array._new(np.round(x._array))
+
def sign(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.sign <numpy.sign>`.
@@ -569,9 +631,10 @@ def sign(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in sign')
+ raise TypeError("Only numeric dtypes are allowed in sign")
return Array._new(np.sign(x._array))
+
def sin(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.sin <numpy.sin>`.
@@ -579,9 +642,10 @@ def sin(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in sin')
+ raise TypeError("Only floating-point dtypes are allowed in sin")
return Array._new(np.sin(x._array))
+
def sinh(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.sinh <numpy.sinh>`.
@@ -589,9 +653,10 @@ def sinh(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in sinh')
+ raise TypeError("Only floating-point dtypes are allowed in sinh")
return Array._new(np.sinh(x._array))
+
def square(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.square <numpy.square>`.
@@ -599,9 +664,10 @@ def square(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in square')
+ raise TypeError("Only numeric dtypes are allowed in square")
return Array._new(np.square(x._array))
+
def sqrt(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.sqrt <numpy.sqrt>`.
@@ -609,9 +675,10 @@ def sqrt(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in sqrt')
+ raise TypeError("Only floating-point dtypes are allowed in sqrt")
return Array._new(np.sqrt(x._array))
+
def subtract(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.subtract <numpy.subtract>`.
@@ -619,12 +686,13 @@ def subtract(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in subtract')
+ raise TypeError("Only numeric dtypes are allowed in subtract")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.subtract(x1._array, x2._array))
+
def tan(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.tan <numpy.tan>`.
@@ -632,9 +700,10 @@ def tan(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in tan')
+ raise TypeError("Only floating-point dtypes are allowed in tan")
return Array._new(np.tan(x._array))
+
def tanh(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.tanh <numpy.tanh>`.
@@ -642,9 +711,10 @@ def tanh(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
- raise TypeError('Only floating-point dtypes are allowed in tanh')
+ raise TypeError("Only floating-point dtypes are allowed in tanh")
return Array._new(np.tanh(x._array))
+
def trunc(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.trunc <numpy.trunc>`.
@@ -652,7 +722,7 @@ def trunc(x: Array, /) -> Array:
See its docstring for more information.
"""
if x.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in trunc')
+ raise TypeError("Only numeric dtypes are allowed in trunc")
if x.dtype in _integer_dtypes:
# Note: The return dtype of trunc is the same as the input
return x
diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py
index f13f9c541..089081725 100644
--- a/numpy/array_api/_linear_algebra_functions.py
+++ b/numpy/array_api/_linear_algebra_functions.py
@@ -17,6 +17,7 @@ import numpy as np
# """
# return np.einsum()
+
def matmul(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
@@ -26,23 +27,31 @@ def matmul(x1: Array, x2: Array, /) -> Array:
# Note: the restriction to numeric dtypes only is different from
# np.matmul.
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in matmul')
+ raise TypeError("Only numeric dtypes are allowed in matmul")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
return Array._new(np.matmul(x1._array, x2._array))
+
# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
-def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array:
+def tensordot(
+ x1: Array,
+ x2: Array,
+ /,
+ *,
+ axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
+) -> Array:
# Note: the restriction to numeric dtypes only is different from
# np.tensordot.
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError('Only numeric dtypes are allowed in tensordot')
+ raise TypeError("Only numeric dtypes are allowed in tensordot")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
+
def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array:
"""
Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`.
@@ -51,6 +60,7 @@ def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array:
"""
return Array._new(np.transpose(x._array, axes=axes))
+
# Note: vecdot is not in NumPy
def vecdot(x1: Array, x2: Array, /, *, axis: Optional[int] = None) -> Array:
if axis is None:
diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py
index 33f5d5a28..c11866261 100644
--- a/numpy/array_api/_manipulation_functions.py
+++ b/numpy/array_api/_manipulation_functions.py
@@ -8,7 +8,9 @@ from typing import List, Optional, Tuple, Union
import numpy as np
# Note: the function name is different here
-def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array:
+def concat(
+ arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`.
@@ -20,6 +22,7 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[i
arrays = tuple(a._array for a in arrays)
return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype))
+
def expand_dims(x: Array, /, *, axis: int) -> Array:
"""
Array API compatible wrapper for :py:func:`np.expand_dims <numpy.expand_dims>`.
@@ -28,6 +31,7 @@ def expand_dims(x: Array, /, *, axis: int) -> Array:
"""
return Array._new(np.expand_dims(x._array, axis))
+
def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
"""
Array API compatible wrapper for :py:func:`np.flip <numpy.flip>`.
@@ -36,6 +40,7 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
"""
return Array._new(np.flip(x._array, axis=axis))
+
def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array:
"""
Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.
@@ -44,7 +49,14 @@ def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array:
"""
return Array._new(np.reshape(x._array, shape))
-def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
+
+def roll(
+ x: Array,
+ /,
+ shift: Union[int, Tuple[int, ...]],
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.roll <numpy.roll>`.
@@ -52,6 +64,7 @@ def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Unio
"""
return Array._new(np.roll(x._array, shift, axis=axis))
+
def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
"""
Array API compatible wrapper for :py:func:`np.squeeze <numpy.squeeze>`.
@@ -60,6 +73,7 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
"""
return Array._new(np.squeeze(x._array, axis=axis))
+
def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
"""
Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`.
diff --git a/numpy/array_api/_searching_functions.py b/numpy/array_api/_searching_functions.py
index 9dcc76b2d..3dcef61c3 100644
--- a/numpy/array_api/_searching_functions.py
+++ b/numpy/array_api/_searching_functions.py
@@ -7,6 +7,7 @@ from typing import Optional, Tuple
import numpy as np
+
def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array:
"""
Array API compatible wrapper for :py:func:`np.argmax <numpy.argmax>`.
@@ -15,6 +16,7 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
"""
return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims)))
+
def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array:
"""
Array API compatible wrapper for :py:func:`np.argmin <numpy.argmin>`.
@@ -23,6 +25,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
"""
return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims)))
+
def nonzero(x: Array, /) -> Tuple[Array, ...]:
"""
Array API compatible wrapper for :py:func:`np.nonzero <numpy.nonzero>`.
@@ -31,6 +34,7 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]:
"""
return tuple(Array._new(i) for i in np.nonzero(x._array))
+
def where(condition: Array, x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py
index acd59f597..357f238f5 100644
--- a/numpy/array_api/_set_functions.py
+++ b/numpy/array_api/_set_functions.py
@@ -6,14 +6,26 @@ from typing import Tuple, Union
import numpy as np
-def unique(x: Array, /, *, return_counts: bool = False, return_index: bool = False, return_inverse: bool = False) -> Union[Array, Tuple[Array, ...]]:
+
+def unique(
+ x: Array,
+ /,
+ *,
+ return_counts: bool = False,
+ return_index: bool = False,
+ return_inverse: bool = False,
+) -> Union[Array, Tuple[Array, ...]]:
"""
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
See its docstring for more information.
"""
- res = np.unique(x._array, return_counts=return_counts,
- return_index=return_index, return_inverse=return_inverse)
+ res = np.unique(
+ x._array,
+ return_counts=return_counts,
+ return_index=return_index,
+ return_inverse=return_inverse,
+ )
if isinstance(res, tuple):
return tuple(Array._new(i) for i in res)
return Array._new(res)
diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py
index a125e0718..9cd49786c 100644
--- a/numpy/array_api/_sorting_functions.py
+++ b/numpy/array_api/_sorting_functions.py
@@ -4,27 +4,33 @@ from ._array_object import Array
import numpy as np
-def argsort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array:
+
+def argsort(
+ x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.argsort <numpy.argsort>`.
See its docstring for more information.
"""
# Note: this keyword argument is different, and the default is different.
- kind = 'stable' if stable else 'quicksort'
+ kind = "stable" if stable else "quicksort"
res = np.argsort(x._array, axis=axis, kind=kind)
if descending:
res = np.flip(res, axis=axis)
return Array._new(res)
-def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array:
+
+def sort(
+ x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.sort <numpy.sort>`.
See its docstring for more information.
"""
# Note: this keyword argument is different, and the default is different.
- kind = 'stable' if stable else 'quicksort'
+ kind = "stable" if stable else "quicksort"
res = np.sort(x._array, axis=axis, kind=kind)
if descending:
res = np.flip(res, axis=axis)
diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py
index a606203bc..63790b447 100644
--- a/numpy/array_api/_statistical_functions.py
+++ b/numpy/array_api/_statistical_functions.py
@@ -6,25 +6,76 @@ from typing import Optional, Tuple, Union
import numpy as np
-def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
+
+def max(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> Array:
return Array._new(np.max(x._array, axis=axis, keepdims=keepdims))
-def mean(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
+
+def mean(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> Array:
return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims))
-def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
+
+def min(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> Array:
return Array._new(np.min(x._array, axis=axis, keepdims=keepdims))
-def prod(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
+
+def prod(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> Array:
return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims))
-def std(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> Array:
+
+def std(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ correction: Union[int, float] = 0.0,
+ keepdims: bool = False,
+) -> Array:
# Note: the keyword argument correction is different here
return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))
-def sum(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
+
+def sum(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> Array:
return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims))
-def var(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> Array:
+
+def var(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ correction: Union[int, float] = 0.0,
+ keepdims: bool = False,
+) -> Array:
# Note: the keyword argument correction is different here
return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))
diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py
index 4ff718205..d530a91ae 100644
--- a/numpy/array_api/_typing.py
+++ b/numpy/array_api/_typing.py
@@ -6,21 +6,39 @@ annotations in the function signatures. The functions in the module are only
valid for inputs that match the given type annotations.
"""
-__all__ = ['Array', 'Device', 'Dtype', 'SupportsDLPack',
- 'SupportsBufferProtocol', 'PyCapsule']
+__all__ = [
+ "Array",
+ "Device",
+ "Dtype",
+ "SupportsDLPack",
+ "SupportsBufferProtocol",
+ "PyCapsule",
+]
from typing import Any, Sequence, Type, Union
-from . import (Array, int8, int16, int32, int64, uint8, uint16, uint32,
- uint64, float32, float64)
+from . import (
+ Array,
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+)
# This should really be recursive, but that isn't supported yet. See the
# similar comment in numpy/typing/_array_like.py
NestedSequence = Sequence[Sequence[Any]]
Device = Any
-Dtype = Type[Union[[int8, int16, int32, int64, uint8, uint16,
- uint32, uint64, float32, float64]]]
+Dtype = Type[
+ Union[[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64]]
+]
SupportsDLPack = Any
SupportsBufferProtocol = Any
PyCapsule = Any
diff --git a/numpy/array_api/_utility_functions.py b/numpy/array_api/_utility_functions.py
index f243bfe68..5ecb4bd9f 100644
--- a/numpy/array_api/_utility_functions.py
+++ b/numpy/array_api/_utility_functions.py
@@ -6,7 +6,14 @@ from typing import Optional, Tuple, Union
import numpy as np
-def all(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
+
+def all(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.all <numpy.all>`.
@@ -14,7 +21,14 @@ def all(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
"""
return Array._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims)))
-def any(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
+
+def any(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.any <numpy.any>`.
diff --git a/numpy/array_api/setup.py b/numpy/array_api/setup.py
index da2350c8f..c8bc29102 100644
--- a/numpy/array_api/setup.py
+++ b/numpy/array_api/setup.py
@@ -1,10 +1,12 @@
-def configuration(parent_package='', top_path=None):
+def configuration(parent_package="", top_path=None):
from numpy.distutils.misc_util import Configuration
- config = Configuration('array_api', parent_package, top_path)
- config.add_subpackage('tests')
+
+ config = Configuration("array_api", parent_package, top_path)
+ config.add_subpackage("tests")
return config
-if __name__ == '__main__':
+if __name__ == "__main__":
from numpy.distutils.core import setup
+
setup(configuration=configuration)
diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py
index 22078bbee..088e09b9f 100644
--- a/numpy/array_api/tests/test_array_object.py
+++ b/numpy/array_api/tests/test_array_object.py
@@ -2,9 +2,20 @@ from numpy.testing import assert_raises
import numpy as np
from .. import ones, asarray, result_type
-from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes,
- _integer_dtypes, _integer_or_boolean_dtypes,
- _numeric_dtypes, int8, int16, int32, int64, uint64)
+from .._dtypes import (
+ _all_dtypes,
+ _boolean_dtypes,
+ _floating_dtypes,
+ _integer_dtypes,
+ _integer_or_boolean_dtypes,
+ _numeric_dtypes,
+ int8,
+ int16,
+ int32,
+ int64,
+ uint64,
+)
+
def test_validate_index():
# The indexing tests in the official array API test suite test that the
@@ -61,28 +72,29 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[None, ...])
assert_raises(IndexError, lambda: a[..., None])
+
def test_operators():
# For every operator, we test that it works for the required type
# combinations and raises TypeError otherwise
- binary_op_dtypes ={
- '__add__': 'numeric',
- '__and__': 'integer_or_boolean',
- '__eq__': 'all',
- '__floordiv__': 'numeric',
- '__ge__': 'numeric',
- '__gt__': 'numeric',
- '__le__': 'numeric',
- '__lshift__': 'integer',
- '__lt__': 'numeric',
- '__mod__': 'numeric',
- '__mul__': 'numeric',
- '__ne__': 'all',
- '__or__': 'integer_or_boolean',
- '__pow__': 'floating',
- '__rshift__': 'integer',
- '__sub__': 'numeric',
- '__truediv__': 'floating',
- '__xor__': 'integer_or_boolean',
+ binary_op_dtypes = {
+ "__add__": "numeric",
+ "__and__": "integer_or_boolean",
+ "__eq__": "all",
+ "__floordiv__": "numeric",
+ "__ge__": "numeric",
+ "__gt__": "numeric",
+ "__le__": "numeric",
+ "__lshift__": "integer",
+ "__lt__": "numeric",
+ "__mod__": "numeric",
+ "__mul__": "numeric",
+ "__ne__": "all",
+ "__or__": "integer_or_boolean",
+ "__pow__": "floating",
+ "__rshift__": "integer",
+ "__sub__": "numeric",
+ "__truediv__": "floating",
+ "__xor__": "integer_or_boolean",
}
# Recompute each time because of in-place ops
@@ -92,15 +104,15 @@ def test_operators():
for d in _boolean_dtypes:
yield asarray(False, dtype=d)
for d in _floating_dtypes:
- yield asarray(1., dtype=d)
+ yield asarray(1.0, dtype=d)
for op, dtypes in binary_op_dtypes.items():
ops = [op]
- if op not in ['__eq__', '__ne__', '__le__', '__ge__', '__lt__', '__gt__']:
- rop = '__r' + op[2:]
- iop = '__i' + op[2:]
+ if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]:
+ rop = "__r" + op[2:]
+ iop = "__i" + op[2:]
ops += [rop, iop]
- for s in [1, 1., False]:
+ for s in [1, 1.0, False]:
for _op in ops:
for a in _array_vals():
# Test array op scalar. From the spec, the following combinations
@@ -149,7 +161,10 @@ def test_operators():
):
assert_raises(TypeError, lambda: getattr(x, _op)(y))
# Ensure in-place operators only promote to the same dtype as the left operand.
- elif _op.startswith('__i') and result_type(x.dtype, y.dtype) != x.dtype:
+ elif (
+ _op.startswith("__i")
+ and result_type(x.dtype, y.dtype) != x.dtype
+ ):
assert_raises(TypeError, lambda: getattr(x, _op)(y))
# Ensure only those dtypes that are required for every operator are allowed.
elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
@@ -165,17 +180,20 @@ def test_operators():
else:
assert_raises(TypeError, lambda: getattr(x, _op)(y))
- unary_op_dtypes ={
- '__abs__': 'numeric',
- '__invert__': 'integer_or_boolean',
- '__neg__': 'numeric',
- '__pos__': 'numeric',
+ unary_op_dtypes = {
+ "__abs__": "numeric",
+ "__invert__": "integer_or_boolean",
+ "__neg__": "numeric",
+ "__pos__": "numeric",
}
for op, dtypes in unary_op_dtypes.items():
for a in _array_vals():
- if (dtypes == "numeric" and a.dtype in _numeric_dtypes
- or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes
- ):
+ if (
+ dtypes == "numeric"
+ and a.dtype in _numeric_dtypes
+ or dtypes == "integer_or_boolean"
+ and a.dtype in _integer_or_boolean_dtypes
+ ):
# Only test for no error
getattr(a, op)()
else:
@@ -192,8 +210,8 @@ def test_operators():
yield ones((4, 4), dtype=d)
# Scalars always error
- for _op in ['__matmul__', '__rmatmul__', '__imatmul__']:
- for s in [1, 1., False]:
+ for _op in ["__matmul__", "__rmatmul__", "__imatmul__"]:
+ for s in [1, 1.0, False]:
for a in _matmul_array_vals():
if (type(s) in [float, int] and a.dtype in _floating_dtypes
or type(s) == int and a.dtype in _integer_dtypes):
@@ -235,16 +253,17 @@ def test_operators():
else:
x.__imatmul__(y)
+
def test_python_scalar_construtors():
a = asarray(False)
b = asarray(0)
- c = asarray(0.)
+ c = asarray(0.0)
assert bool(a) == bool(b) == bool(c) == False
assert int(a) == int(b) == int(c) == 0
- assert float(a) == float(b) == float(c) == 0.
+ assert float(a) == float(b) == float(c) == 0.0
# bool/int/float should only be allowed on 0-D arrays.
assert_raises(TypeError, lambda: bool(asarray([False])))
assert_raises(TypeError, lambda: int(asarray([0])))
- assert_raises(TypeError, lambda: float(asarray([0.])))
+ assert_raises(TypeError, lambda: float(asarray([0.0])))
diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py
index 654f1d9b3..3cb8865cd 100644
--- a/numpy/array_api/tests/test_creation_functions.py
+++ b/numpy/array_api/tests/test_creation_functions.py
@@ -2,26 +2,53 @@ from numpy.testing import assert_raises
import numpy as np
from .. import all
-from .._creation_functions import (asarray, arange, empty, empty_like, eye, from_dlpack, full, full_like, linspace, meshgrid, ones, ones_like, zeros, zeros_like)
+from .._creation_functions import (
+ asarray,
+ arange,
+ empty,
+ empty_like,
+ eye,
+ from_dlpack,
+ full,
+ full_like,
+ linspace,
+ meshgrid,
+ ones,
+ ones_like,
+ zeros,
+ zeros_like,
+)
from .._array_object import Array
-from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes,
- _integer_dtypes, _integer_or_boolean_dtypes,
- _numeric_dtypes, int8, int16, int32, int64, uint64)
+from .._dtypes import (
+ _all_dtypes,
+ _boolean_dtypes,
+ _floating_dtypes,
+ _integer_dtypes,
+ _integer_or_boolean_dtypes,
+ _numeric_dtypes,
+ int8,
+ int16,
+ int32,
+ int64,
+ uint64,
+)
+
def test_asarray_errors():
# Test various protections against incorrect usage
assert_raises(TypeError, lambda: Array([1]))
- assert_raises(TypeError, lambda: asarray(['a']))
- assert_raises(ValueError, lambda: asarray([1.], dtype=np.float16))
+ assert_raises(TypeError, lambda: asarray(["a"]))
+ assert_raises(ValueError, lambda: asarray([1.0], dtype=np.float16))
assert_raises(OverflowError, lambda: asarray(2**100))
# Preferably this would be OverflowError
# assert_raises(OverflowError, lambda: asarray([2**100]))
assert_raises(TypeError, lambda: asarray([2**100]))
- asarray([1], device='cpu') # Doesn't error
- assert_raises(ValueError, lambda: asarray([1], device='gpu'))
+ asarray([1], device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: asarray([1], device="gpu"))
assert_raises(ValueError, lambda: asarray([1], dtype=int))
- assert_raises(ValueError, lambda: asarray([1], dtype='i'))
+ assert_raises(ValueError, lambda: asarray([1], dtype="i"))
+
def test_asarray_copy():
a = asarray([1])
@@ -36,68 +63,79 @@ def test_asarray_copy():
# assert all(b[0] == 0)
assert_raises(NotImplementedError, lambda: asarray(a, copy=False))
+
def test_arange_errors():
- arange(1, device='cpu') # Doesn't error
- assert_raises(ValueError, lambda: arange(1, device='gpu'))
+ arange(1, device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: arange(1, device="gpu"))
assert_raises(ValueError, lambda: arange(1, dtype=int))
- assert_raises(ValueError, lambda: arange(1, dtype='i'))
+ assert_raises(ValueError, lambda: arange(1, dtype="i"))
+
def test_empty_errors():
- empty((1,), device='cpu') # Doesn't error
- assert_raises(ValueError, lambda: empty((1,), device='gpu'))
+ empty((1,), device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: empty((1,), device="gpu"))
assert_raises(ValueError, lambda: empty((1,), dtype=int))
- assert_raises(ValueError, lambda: empty((1,), dtype='i'))
+ assert_raises(ValueError, lambda: empty((1,), dtype="i"))
+
def test_empty_like_errors():
- empty_like(asarray(1), device='cpu') # Doesn't error
- assert_raises(ValueError, lambda: empty_like(asarray(1), device='gpu'))
+ empty_like(asarray(1), device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: empty_like(asarray(1), device="gpu"))
assert_raises(ValueError, lambda: empty_like(asarray(1), dtype=int))
- assert_raises(ValueError, lambda: empty_like(asarray(1), dtype='i'))
+ assert_raises(ValueError, lambda: empty_like(asarray(1), dtype="i"))
+
def test_eye_errors():
- eye(1, device='cpu') # Doesn't error
- assert_raises(ValueError, lambda: eye(1, device='gpu'))
+ eye(1, device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: eye(1, device="gpu"))
assert_raises(ValueError, lambda: eye(1, dtype=int))
- assert_raises(ValueError, lambda: eye(1, dtype='i'))
+ assert_raises(ValueError, lambda: eye(1, dtype="i"))
+
def test_full_errors():
- full((1,), 0, device='cpu') # Doesn't error
- assert_raises(ValueError, lambda: full((1,), 0, device='gpu'))
+ full((1,), 0, device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: full((1,), 0, device="gpu"))
assert_raises(ValueError, lambda: full((1,), 0, dtype=int))
- assert_raises(ValueError, lambda: full((1,), 0, dtype='i'))
+ assert_raises(ValueError, lambda: full((1,), 0, dtype="i"))
+
def test_full_like_errors():
- full_like(asarray(1), 0, device='cpu') # Doesn't error
- assert_raises(ValueError, lambda: full_like(asarray(1), 0, device='gpu'))
+ full_like(asarray(1), 0, device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: full_like(asarray(1), 0, device="gpu"))
assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype=int))
- assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype='i'))
+ assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype="i"))
+
def test_linspace_errors():
- linspace(0, 1, 10, device='cpu') # Doesn't error
- assert_raises(ValueError, lambda: linspace(0, 1, 10, device='gpu'))
+ linspace(0, 1, 10, device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: linspace(0, 1, 10, device="gpu"))
assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype=float))
- assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype='f'))
+ assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype="f"))
+
def test_ones_errors():
- ones((1,), device='cpu') # Doesn't error
- assert_raises(ValueError, lambda: ones((1,), device='gpu'))
+ ones((1,), device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: ones((1,), device="gpu"))
assert_raises(ValueError, lambda: ones((1,), dtype=int))
- assert_raises(ValueError, lambda: ones((1,), dtype='i'))
+ assert_raises(ValueError, lambda: ones((1,), dtype="i"))
+
def test_ones_like_errors():
- ones_like(asarray(1), device='cpu') # Doesn't error
- assert_raises(ValueError, lambda: ones_like(asarray(1), device='gpu'))
+ ones_like(asarray(1), device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: ones_like(asarray(1), device="gpu"))
assert_raises(ValueError, lambda: ones_like(asarray(1), dtype=int))
- assert_raises(ValueError, lambda: ones_like(asarray(1), dtype='i'))
+ assert_raises(ValueError, lambda: ones_like(asarray(1), dtype="i"))
+
def test_zeros_errors():
- zeros((1,), device='cpu') # Doesn't error
- assert_raises(ValueError, lambda: zeros((1,), device='gpu'))
+ zeros((1,), device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: zeros((1,), device="gpu"))
assert_raises(ValueError, lambda: zeros((1,), dtype=int))
- assert_raises(ValueError, lambda: zeros((1,), dtype='i'))
+ assert_raises(ValueError, lambda: zeros((1,), dtype="i"))
+
def test_zeros_like_errors():
- zeros_like(asarray(1), device='cpu') # Doesn't error
- assert_raises(ValueError, lambda: zeros_like(asarray(1), device='gpu'))
+ zeros_like(asarray(1), device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: zeros_like(asarray(1), device="gpu"))
assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int))
- assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype='i'))
+ assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype="i"))
diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py
index ec76cb7a7..a9274aec9 100644
--- a/numpy/array_api/tests/test_elementwise_functions.py
+++ b/numpy/array_api/tests/test_elementwise_functions.py
@@ -4,74 +4,80 @@ from numpy.testing import assert_raises
from .. import asarray, _elementwise_functions
from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift
-from .._dtypes import (_dtype_categories, _boolean_dtypes, _floating_dtypes,
- _integer_dtypes)
+from .._dtypes import (
+ _dtype_categories,
+ _boolean_dtypes,
+ _floating_dtypes,
+ _integer_dtypes,
+)
+
def nargs(func):
return len(getfullargspec(func).args)
+
def test_function_types():
# Test that every function accepts only the required input types. We only
# test the negative cases here (error). The positive cases are tested in
# the array API test suite.
elementwise_function_input_types = {
- 'abs': 'numeric',
- 'acos': 'floating-point',
- 'acosh': 'floating-point',
- 'add': 'numeric',
- 'asin': 'floating-point',
- 'asinh': 'floating-point',
- 'atan': 'floating-point',
- 'atan2': 'floating-point',
- 'atanh': 'floating-point',
- 'bitwise_and': 'integer or boolean',
- 'bitwise_invert': 'integer or boolean',
- 'bitwise_left_shift': 'integer',
- 'bitwise_or': 'integer or boolean',
- 'bitwise_right_shift': 'integer',
- 'bitwise_xor': 'integer or boolean',
- 'ceil': 'numeric',
- 'cos': 'floating-point',
- 'cosh': 'floating-point',
- 'divide': 'floating-point',
- 'equal': 'all',
- 'exp': 'floating-point',
- 'expm1': 'floating-point',
- 'floor': 'numeric',
- 'floor_divide': 'numeric',
- 'greater': 'numeric',
- 'greater_equal': 'numeric',
- 'isfinite': 'numeric',
- 'isinf': 'numeric',
- 'isnan': 'numeric',
- 'less': 'numeric',
- 'less_equal': 'numeric',
- 'log': 'floating-point',
- 'logaddexp': 'floating-point',
- 'log10': 'floating-point',
- 'log1p': 'floating-point',
- 'log2': 'floating-point',
- 'logical_and': 'boolean',
- 'logical_not': 'boolean',
- 'logical_or': 'boolean',
- 'logical_xor': 'boolean',
- 'multiply': 'numeric',
- 'negative': 'numeric',
- 'not_equal': 'all',
- 'positive': 'numeric',
- 'pow': 'floating-point',
- 'remainder': 'numeric',
- 'round': 'numeric',
- 'sign': 'numeric',
- 'sin': 'floating-point',
- 'sinh': 'floating-point',
- 'sqrt': 'floating-point',
- 'square': 'numeric',
- 'subtract': 'numeric',
- 'tan': 'floating-point',
- 'tanh': 'floating-point',
- 'trunc': 'numeric',
+ "abs": "numeric",
+ "acos": "floating-point",
+ "acosh": "floating-point",
+ "add": "numeric",
+ "asin": "floating-point",
+ "asinh": "floating-point",
+ "atan": "floating-point",
+ "atan2": "floating-point",
+ "atanh": "floating-point",
+ "bitwise_and": "integer or boolean",
+ "bitwise_invert": "integer or boolean",
+ "bitwise_left_shift": "integer",
+ "bitwise_or": "integer or boolean",
+ "bitwise_right_shift": "integer",
+ "bitwise_xor": "integer or boolean",
+ "ceil": "numeric",
+ "cos": "floating-point",
+ "cosh": "floating-point",
+ "divide": "floating-point",
+ "equal": "all",
+ "exp": "floating-point",
+ "expm1": "floating-point",
+ "floor": "numeric",
+ "floor_divide": "numeric",
+ "greater": "numeric",
+ "greater_equal": "numeric",
+ "isfinite": "numeric",
+ "isinf": "numeric",
+ "isnan": "numeric",
+ "less": "numeric",
+ "less_equal": "numeric",
+ "log": "floating-point",
+ "logaddexp": "floating-point",
+ "log10": "floating-point",
+ "log1p": "floating-point",
+ "log2": "floating-point",
+ "logical_and": "boolean",
+ "logical_not": "boolean",
+ "logical_or": "boolean",
+ "logical_xor": "boolean",
+ "multiply": "numeric",
+ "negative": "numeric",
+ "not_equal": "all",
+ "positive": "numeric",
+ "pow": "floating-point",
+ "remainder": "numeric",
+ "round": "numeric",
+ "sign": "numeric",
+ "sin": "floating-point",
+ "sinh": "floating-point",
+ "sqrt": "floating-point",
+ "square": "numeric",
+ "subtract": "numeric",
+ "tan": "floating-point",
+ "tanh": "floating-point",
+ "trunc": "numeric",
}
def _array_vals():
@@ -80,7 +86,7 @@ def test_function_types():
for d in _boolean_dtypes:
yield asarray(False, dtype=d)
for d in _floating_dtypes:
- yield asarray(1., dtype=d)
+ yield asarray(1.0, dtype=d)
for x in _array_vals():
for func_name, types in elementwise_function_input_types.items():
@@ -94,7 +100,12 @@ def test_function_types():
if x.dtype not in dtypes:
assert_raises(TypeError, lambda: func(x))
+
def test_bitwise_shift_error():
# bitwise shift functions should raise when the second argument is negative
- assert_raises(ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1])))
- assert_raises(ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1])))
+ assert_raises(
+ ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1]))
+ )
+ assert_raises(
+ ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))
+ )