diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-08-06 18:22:00 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-08-06 18:23:04 -0600 |
commit | 8f7d00ed447174d9398af3365709222b529c1cad (patch) | |
tree | 9de0a3a757a8c8a7393787ee1449e087c284d6e1 /numpy/array_api | |
parent | 21923a5fa71bfadf7dee0bb5b110cc2a5719eaac (diff) | |
download | numpy-8f7d00ed447174d9398af3365709222b529c1cad.tar.gz |
Run (selective) black on the array_api submodule
I've omitted a few changes from black that messed up the readability of some
complicated if statements that were organized logically line-by-line, and some
changes that use unnecessary operator spacing.
Diffstat (limited to 'numpy/array_api')
-rw-r--r-- | numpy/array_api/__init__.py | 247 | ||||
-rw-r--r-- | numpy/array_api/_array_object.py | 205 | ||||
-rw-r--r-- | numpy/array_api/_creation_functions.py | 158 | ||||
-rw-r--r-- | numpy/array_api/_data_type_functions.py | 16 | ||||
-rw-r--r-- | numpy/array_api/_dtypes.py | 75 | ||||
-rw-r--r-- | numpy/array_api/_elementwise_functions.py | 194 | ||||
-rw-r--r-- | numpy/array_api/_linear_algebra_functions.py | 16 | ||||
-rw-r--r-- | numpy/array_api/_manipulation_functions.py | 18 | ||||
-rw-r--r-- | numpy/array_api/_searching_functions.py | 4 | ||||
-rw-r--r-- | numpy/array_api/_set_functions.py | 18 | ||||
-rw-r--r-- | numpy/array_api/_sorting_functions.py | 14 | ||||
-rw-r--r-- | numpy/array_api/_statistical_functions.py | 65 | ||||
-rw-r--r-- | numpy/array_api/_typing.py | 30 | ||||
-rw-r--r-- | numpy/array_api/_utility_functions.py | 18 | ||||
-rw-r--r-- | numpy/array_api/setup.py | 10 | ||||
-rw-r--r-- | numpy/array_api/tests/test_array_object.py | 101 | ||||
-rw-r--r-- | numpy/array_api/tests/test_creation_functions.py | 122 | ||||
-rw-r--r-- | numpy/array_api/tests/test_elementwise_functions.py | 133 |
18 files changed, 1054 insertions, 390 deletions
diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 4dc931732..53c1f3850 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -119,36 +119,221 @@ Still TODO in this module are: """ import sys + # numpy.array_api is 3.8+ because it makes extensive use of positional-only # arguments. if sys.version_info < (3, 8): raise ImportError("The numpy.array_api submodule requires Python 3.8 or greater.") import warnings -warnings.warn("The numpy.array_api submodule is still experimental. See NEP 47.", - stacklevel=2) + +warnings.warn( + "The numpy.array_api submodule is still experimental. See NEP 47.", stacklevel=2 +) __all__ = [] from ._constants import e, inf, nan, pi -__all__ += ['e', 'inf', 'nan', 'pi'] - -from ._creation_functions import asarray, arange, empty, empty_like, eye, from_dlpack, full, full_like, linspace, meshgrid, ones, ones_like, zeros, zeros_like - -__all__ += ['asarray', 'arange', 'empty', 'empty_like', 'eye', 'from_dlpack', 'full', 'full_like', 'linspace', 'meshgrid', 'ones', 'ones_like', 'zeros', 'zeros_like'] - -from ._data_type_functions import broadcast_arrays, broadcast_to, can_cast, finfo, iinfo, result_type - -__all__ += ['broadcast_arrays', 'broadcast_to', 'can_cast', 'finfo', 'iinfo', 'result_type'] - -from ._dtypes import int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, bool - -__all__ += ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', 'float32', 'float64', 'bool'] - -from ._elementwise_functions import abs, acos, acosh, add, asin, asinh, atan, atan2, atanh, bitwise_and, bitwise_left_shift, bitwise_invert, bitwise_or, bitwise_right_shift, bitwise_xor, ceil, cos, cosh, divide, equal, exp, expm1, floor, floor_divide, greater, greater_equal, isfinite, isinf, isnan, less, less_equal, log, log1p, log2, log10, logaddexp, logical_and, logical_not, logical_or, logical_xor, multiply, negative, not_equal, positive, pow, remainder, round, sign, sin, sinh, square, sqrt, subtract, tan, tanh, trunc - -__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logaddexp', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc'] +__all__ += ["e", "inf", "nan", "pi"] + +from ._creation_functions import ( + asarray, + arange, + empty, + empty_like, + eye, + from_dlpack, + full, + full_like, + linspace, + meshgrid, + ones, + ones_like, + zeros, + zeros_like, +) + +__all__ += [ + "asarray", + "arange", + "empty", + "empty_like", + "eye", + "from_dlpack", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "zeros", + "zeros_like", +] + +from ._data_type_functions import ( + broadcast_arrays, + broadcast_to, + can_cast, + finfo, + iinfo, + result_type, +) + +__all__ += [ + "broadcast_arrays", + "broadcast_to", + "can_cast", + "finfo", + "iinfo", + "result_type", +] + +from ._dtypes import ( + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + bool, +) + +__all__ += [ + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "bool", +] + +from ._elementwise_functions import ( + abs, + acos, + acosh, + add, + asin, + asinh, + atan, + atan2, + atanh, + bitwise_and, + bitwise_left_shift, + bitwise_invert, + bitwise_or, + bitwise_right_shift, + bitwise_xor, + ceil, + cos, + cosh, + divide, + equal, + exp, + expm1, + floor, + floor_divide, + greater, + greater_equal, + isfinite, + isinf, + isnan, + less, + less_equal, + log, + log1p, + log2, + log10, + logaddexp, + logical_and, + logical_not, + logical_or, + logical_xor, + multiply, + negative, + not_equal, + positive, + pow, + remainder, + round, + sign, + sin, + sinh, + square, + sqrt, + subtract, + tan, + tanh, + trunc, +) + +__all__ += [ + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "multiply", + "negative", + "not_equal", + "positive", + "pow", + "remainder", + "round", + "sign", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", +] # einsum is not yet implemented in the array API spec. @@ -157,28 +342,36 @@ __all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'at from ._linear_algebra_functions import matmul, tensordot, transpose, vecdot -__all__ += ['matmul', 'tensordot', 'transpose', 'vecdot'] +__all__ += ["matmul", "tensordot", "transpose", "vecdot"] -from ._manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack +from ._manipulation_functions import ( + concat, + expand_dims, + flip, + reshape, + roll, + squeeze, + stack, +) -__all__ += ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack'] +__all__ += ["concat", "expand_dims", "flip", "reshape", "roll", "squeeze", "stack"] from ._searching_functions import argmax, argmin, nonzero, where -__all__ += ['argmax', 'argmin', 'nonzero', 'where'] +__all__ += ["argmax", "argmin", "nonzero", "where"] from ._set_functions import unique -__all__ += ['unique'] +__all__ += ["unique"] from ._sorting_functions import argsort, sort -__all__ += ['argsort', 'sort'] +__all__ += ["argsort", "sort"] from ._statistical_functions import max, mean, min, prod, std, sum, var -__all__ += ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var'] +__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"] from ._utility_functions import all, any -__all__ += ['all', 'any'] +__all__ += ["all", "any"] diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 00f50eade..0f511a577 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -18,11 +18,19 @@ from __future__ import annotations import operator from enum import IntEnum from ._creation_functions import asarray -from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes, - _integer_or_boolean_dtypes, _floating_dtypes, - _numeric_dtypes, _result_type, _dtype_categories) +from ._dtypes import ( + _all_dtypes, + _boolean_dtypes, + _integer_dtypes, + _integer_or_boolean_dtypes, + _floating_dtypes, + _numeric_dtypes, + _result_type, + _dtype_categories, +) from typing import TYPE_CHECKING, Optional, Tuple, Union + if TYPE_CHECKING: from ._typing import PyCapsule, Device, Dtype @@ -30,6 +38,7 @@ import numpy as np from numpy import array_api + class Array: """ n-d array object for the array API namespace. @@ -45,6 +54,7 @@ class Array: functions, such as asarray(). """ + # Use a custom constructor instead of __init__, as manually initializing # this class is not supported API. @classmethod @@ -64,13 +74,17 @@ class Array: # Convert the array scalar to a 0-D array x = np.asarray(x) if x.dtype not in _all_dtypes: - raise TypeError(f"The array_api namespace does not support the dtype '{x.dtype}'") + raise TypeError( + f"The array_api namespace does not support the dtype '{x.dtype}'" + ) obj._array = x return obj # Prevent Array() from working def __new__(cls, *args, **kwargs): - raise TypeError("The array_api Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead.") + raise TypeError( + "The array_api Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead." + ) # These functions are not required by the spec, but are implemented for # the sake of usability. @@ -79,7 +93,7 @@ class Array: """ Performs the operation __str__. """ - return self._array.__str__().replace('array', 'Array') + return self._array.__str__().replace("array", "Array") def __repr__(self: Array, /) -> str: """ @@ -103,12 +117,12 @@ class Array: """ if self.dtype not in _dtype_categories[dtype_category]: - raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) elif isinstance(other, Array): if other.dtype not in _dtype_categories[dtype_category]: - raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") else: return NotImplemented @@ -116,7 +130,7 @@ class Array: # to promote in the spec (even if the NumPy array operator would # promote them). res_dtype = _result_type(self.dtype, other.dtype) - if op.startswith('__i'): + if op.startswith("__i"): # Note: NumPy will allow in-place operators in some cases where # the type promoted operator does not match the left-hand side # operand. For example, @@ -126,7 +140,9 @@ class Array: # The spec explicitly disallows this. if res_dtype != self.dtype: - raise TypeError(f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}") + raise TypeError( + f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}" + ) return other @@ -142,13 +158,19 @@ class Array: """ if isinstance(scalar, bool): if self.dtype not in _boolean_dtypes: - raise TypeError("Python bool scalars can only be promoted with bool arrays") + raise TypeError( + "Python bool scalars can only be promoted with bool arrays" + ) elif isinstance(scalar, int): if self.dtype in _boolean_dtypes: - raise TypeError("Python int scalars cannot be promoted with bool arrays") + raise TypeError( + "Python int scalars cannot be promoted with bool arrays" + ) elif isinstance(scalar, float): if self.dtype not in _floating_dtypes: - raise TypeError("Python float scalars can only be promoted with floating-point arrays.") + raise TypeError( + "Python float scalars can only be promoted with floating-point arrays." + ) else: raise TypeError("'scalar' must be a Python scalar") @@ -253,7 +275,9 @@ class Array: except TypeError: return key if not (-size <= key.start <= max(0, size - 1)): - raise IndexError("Slices with out-of-bounds start are not allowed in the array API namespace") + raise IndexError( + "Slices with out-of-bounds start are not allowed in the array API namespace" + ) if key.stop is not None: try: operator.index(key.stop) @@ -269,12 +293,20 @@ class Array: key = tuple(Array._validate_index(idx, None) for idx in key) for idx in key: - if isinstance(idx, np.ndarray) and idx.dtype in _boolean_dtypes or isinstance(idx, (bool, np.bool_)): + if ( + isinstance(idx, np.ndarray) + and idx.dtype in _boolean_dtypes + or isinstance(idx, (bool, np.bool_)) + ): if len(key) == 1: return key - raise IndexError("Boolean array indices combined with other indices are not allowed in the array API namespace") + raise IndexError( + "Boolean array indices combined with other indices are not allowed in the array API namespace" + ) if isinstance(idx, tuple): - raise IndexError("Nested tuple indices are not allowed in the array API namespace") + raise IndexError( + "Nested tuple indices are not allowed in the array API namespace" + ) if shape is None: return key @@ -283,7 +315,9 @@ class Array: return key ellipsis_i = key.index(...) if n_ellipsis else len(key) - for idx, size in list(zip(key[:ellipsis_i], shape)) + list(zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])): + for idx, size in list(zip(key[:ellipsis_i], shape)) + list( + zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1]) + ): Array._validate_index(idx, (size,)) return key elif isinstance(key, bool): @@ -291,18 +325,24 @@ class Array: elif isinstance(key, Array): if key.dtype in _integer_dtypes: if key.ndim != 0: - raise IndexError("Non-zero dimensional integer array indices are not allowed in the array API namespace") + raise IndexError( + "Non-zero dimensional integer array indices are not allowed in the array API namespace" + ) return key._array elif key is Ellipsis: return key elif key is None: - raise IndexError("newaxis indices are not allowed in the array API namespace") + raise IndexError( + "newaxis indices are not allowed in the array API namespace" + ) try: return operator.index(key) except TypeError: # Note: This also omits boolean arrays that are not already in # Array() form, like a list of booleans. - raise IndexError("Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace") + raise IndexError( + "Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace" + ) # Everything below this line is required by the spec. @@ -311,7 +351,7 @@ class Array: Performs the operation __abs__. """ if self.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in __abs__') + raise TypeError("Only numeric dtypes are allowed in __abs__") res = self._array.__abs__() return self.__class__._new(res) @@ -319,7 +359,7 @@ class Array: """ Performs the operation __add__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__add__') + other = self._check_allowed_dtypes(other, "numeric", "__add__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -330,15 +370,17 @@ class Array: """ Performs the operation __and__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__and__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__and__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__and__(other._array) return self.__class__._new(res) - def __array_namespace__(self: Array, /, *, api_version: Optional[str] = None) -> object: - if api_version is not None and not api_version.startswith('2021.'): + def __array_namespace__( + self: Array, /, *, api_version: Optional[str] = None + ) -> object: + if api_version is not None and not api_version.startswith("2021."): raise ValueError(f"Unrecognized array API version: {api_version!r}") return array_api @@ -373,7 +415,7 @@ class Array: """ # Even though "all" dtypes are allowed, we still require them to be # promotable with each other. - other = self._check_allowed_dtypes(other, 'all', '__eq__') + other = self._check_allowed_dtypes(other, "all", "__eq__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -394,7 +436,7 @@ class Array: """ Performs the operation __floordiv__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__floordiv__') + other = self._check_allowed_dtypes(other, "numeric", "__floordiv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -405,14 +447,20 @@ class Array: """ Performs the operation __ge__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__ge__') + other = self._check_allowed_dtypes(other, "numeric", "__ge__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__ge__(other._array) return self.__class__._new(res) - def __getitem__(self: Array, key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], /) -> Array: + def __getitem__( + self: Array, + key: Union[ + int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array + ], + /, + ) -> Array: """ Performs the operation __getitem__. """ @@ -426,7 +474,7 @@ class Array: """ Performs the operation __gt__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__gt__') + other = self._check_allowed_dtypes(other, "numeric", "__gt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -448,7 +496,7 @@ class Array: Performs the operation __invert__. """ if self.dtype not in _integer_or_boolean_dtypes: - raise TypeError('Only integer or boolean dtypes are allowed in __invert__') + raise TypeError("Only integer or boolean dtypes are allowed in __invert__") res = self._array.__invert__() return self.__class__._new(res) @@ -456,7 +504,7 @@ class Array: """ Performs the operation __le__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__le__') + other = self._check_allowed_dtypes(other, "numeric", "__le__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -474,7 +522,7 @@ class Array: """ Performs the operation __lshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__lshift__') + other = self._check_allowed_dtypes(other, "integer", "__lshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -485,7 +533,7 @@ class Array: """ Performs the operation __lt__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__lt__') + other = self._check_allowed_dtypes(other, "numeric", "__lt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -498,7 +546,7 @@ class Array: """ # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. - other = self._check_allowed_dtypes(other, 'numeric', '__matmul__') + other = self._check_allowed_dtypes(other, "numeric", "__matmul__") if other is NotImplemented: return other res = self._array.__matmul__(other._array) @@ -508,7 +556,7 @@ class Array: """ Performs the operation __mod__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__mod__') + other = self._check_allowed_dtypes(other, "numeric", "__mod__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -519,7 +567,7 @@ class Array: """ Performs the operation __mul__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__mul__') + other = self._check_allowed_dtypes(other, "numeric", "__mul__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -530,7 +578,7 @@ class Array: """ Performs the operation __ne__. """ - other = self._check_allowed_dtypes(other, 'all', '__ne__') + other = self._check_allowed_dtypes(other, "all", "__ne__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -542,7 +590,7 @@ class Array: Performs the operation __neg__. """ if self.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in __neg__') + raise TypeError("Only numeric dtypes are allowed in __neg__") res = self._array.__neg__() return self.__class__._new(res) @@ -550,7 +598,7 @@ class Array: """ Performs the operation __or__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__or__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__or__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -562,7 +610,7 @@ class Array: Performs the operation __pos__. """ if self.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in __pos__') + raise TypeError("Only numeric dtypes are allowed in __pos__") res = self._array.__pos__() return self.__class__._new(res) @@ -574,7 +622,7 @@ class Array: """ from ._elementwise_functions import pow - other = self._check_allowed_dtypes(other, 'floating-point', '__pow__') + other = self._check_allowed_dtypes(other, "floating-point", "__pow__") if other is NotImplemented: return other # Note: NumPy's __pow__ does not follow type promotion rules for 0-d @@ -585,14 +633,21 @@ class Array: """ Performs the operation __rshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__rshift__') + other = self._check_allowed_dtypes(other, "integer", "__rshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__rshift__(other._array) return self.__class__._new(res) - def __setitem__(self, key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], value: Union[int, float, bool, Array], /) -> Array: + def __setitem__( + self, + key: Union[ + int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array + ], + value: Union[int, float, bool, Array], + /, + ) -> Array: """ Performs the operation __setitem__. """ @@ -605,7 +660,7 @@ class Array: """ Performs the operation __sub__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__sub__') + other = self._check_allowed_dtypes(other, "numeric", "__sub__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -618,7 +673,7 @@ class Array: """ Performs the operation __truediv__. """ - other = self._check_allowed_dtypes(other, 'floating-point', '__truediv__') + other = self._check_allowed_dtypes(other, "floating-point", "__truediv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -629,7 +684,7 @@ class Array: """ Performs the operation __xor__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__xor__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -640,7 +695,7 @@ class Array: """ Performs the operation __iadd__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__iadd__') + other = self._check_allowed_dtypes(other, "numeric", "__iadd__") if other is NotImplemented: return other self._array.__iadd__(other._array) @@ -650,7 +705,7 @@ class Array: """ Performs the operation __radd__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__radd__') + other = self._check_allowed_dtypes(other, "numeric", "__radd__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -661,7 +716,7 @@ class Array: """ Performs the operation __iand__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__iand__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__") if other is NotImplemented: return other self._array.__iand__(other._array) @@ -671,7 +726,7 @@ class Array: """ Performs the operation __rand__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__rand__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -682,7 +737,7 @@ class Array: """ Performs the operation __ifloordiv__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__ifloordiv__') + other = self._check_allowed_dtypes(other, "numeric", "__ifloordiv__") if other is NotImplemented: return other self._array.__ifloordiv__(other._array) @@ -692,7 +747,7 @@ class Array: """ Performs the operation __rfloordiv__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__rfloordiv__') + other = self._check_allowed_dtypes(other, "numeric", "__rfloordiv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -703,7 +758,7 @@ class Array: """ Performs the operation __ilshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__ilshift__') + other = self._check_allowed_dtypes(other, "integer", "__ilshift__") if other is NotImplemented: return other self._array.__ilshift__(other._array) @@ -713,7 +768,7 @@ class Array: """ Performs the operation __rlshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__rlshift__') + other = self._check_allowed_dtypes(other, "integer", "__rlshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -728,7 +783,7 @@ class Array: # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. - other = self._check_allowed_dtypes(other, 'numeric', '__imatmul__') + other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") if other is NotImplemented: return other @@ -748,7 +803,7 @@ class Array: """ # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. - other = self._check_allowed_dtypes(other, 'numeric', '__rmatmul__') + other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__") if other is NotImplemented: return other res = self._array.__rmatmul__(other._array) @@ -758,7 +813,7 @@ class Array: """ Performs the operation __imod__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__imod__') + other = self._check_allowed_dtypes(other, "numeric", "__imod__") if other is NotImplemented: return other self._array.__imod__(other._array) @@ -768,7 +823,7 @@ class Array: """ Performs the operation __rmod__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__rmod__') + other = self._check_allowed_dtypes(other, "numeric", "__rmod__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -779,7 +834,7 @@ class Array: """ Performs the operation __imul__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__imul__') + other = self._check_allowed_dtypes(other, "numeric", "__imul__") if other is NotImplemented: return other self._array.__imul__(other._array) @@ -789,7 +844,7 @@ class Array: """ Performs the operation __rmul__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__rmul__') + other = self._check_allowed_dtypes(other, "numeric", "__rmul__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -800,7 +855,7 @@ class Array: """ Performs the operation __ior__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__ior__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__") if other is NotImplemented: return other self._array.__ior__(other._array) @@ -810,7 +865,7 @@ class Array: """ Performs the operation __ror__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__ror__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -821,7 +876,7 @@ class Array: """ Performs the operation __ipow__. """ - other = self._check_allowed_dtypes(other, 'floating-point', '__ipow__') + other = self._check_allowed_dtypes(other, "floating-point", "__ipow__") if other is NotImplemented: return other self._array.__ipow__(other._array) @@ -833,7 +888,7 @@ class Array: """ from ._elementwise_functions import pow - other = self._check_allowed_dtypes(other, 'floating-point', '__rpow__') + other = self._check_allowed_dtypes(other, "floating-point", "__rpow__") if other is NotImplemented: return other # Note: NumPy's __pow__ does not follow the spec type promotion rules @@ -844,7 +899,7 @@ class Array: """ Performs the operation __irshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__irshift__') + other = self._check_allowed_dtypes(other, "integer", "__irshift__") if other is NotImplemented: return other self._array.__irshift__(other._array) @@ -854,7 +909,7 @@ class Array: """ Performs the operation __rrshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__rrshift__') + other = self._check_allowed_dtypes(other, "integer", "__rrshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -865,7 +920,7 @@ class Array: """ Performs the operation __isub__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__isub__') + other = self._check_allowed_dtypes(other, "numeric", "__isub__") if other is NotImplemented: return other self._array.__isub__(other._array) @@ -875,7 +930,7 @@ class Array: """ Performs the operation __rsub__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__rsub__') + other = self._check_allowed_dtypes(other, "numeric", "__rsub__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -886,7 +941,7 @@ class Array: """ Performs the operation __itruediv__. """ - other = self._check_allowed_dtypes(other, 'floating-point', '__itruediv__') + other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__") if other is NotImplemented: return other self._array.__itruediv__(other._array) @@ -896,7 +951,7 @@ class Array: """ Performs the operation __rtruediv__. """ - other = self._check_allowed_dtypes(other, 'floating-point', '__rtruediv__') + other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -907,7 +962,7 @@ class Array: """ Performs the operation __ixor__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__ixor__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__") if other is NotImplemented: return other self._array.__ixor__(other._array) @@ -917,7 +972,7 @@ class Array: """ Performs the operation __rxor__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__rxor__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -935,7 +990,7 @@ class Array: @property def device(self) -> Device: - return 'cpu' + return "cpu" @property def ndim(self) -> int: diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index acf78056a..e9c01e7e6 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -2,14 +2,22 @@ from __future__ import annotations from typing import TYPE_CHECKING, List, Optional, Tuple, Union + if TYPE_CHECKING: - from ._typing import (Array, Device, Dtype, NestedSequence, - SupportsDLPack, SupportsBufferProtocol) + from ._typing import ( + Array, + Device, + Dtype, + NestedSequence, + SupportsDLPack, + SupportsBufferProtocol, + ) from collections.abc import Sequence from ._dtypes import _all_dtypes import numpy as np + def _check_valid_dtype(dtype): # Note: Only spelling dtypes as the dtype objects is supported. @@ -20,7 +28,23 @@ def _check_valid_dtype(dtype): return raise ValueError("dtype must be one of the supported dtypes") -def asarray(obj: Union[Array, bool, int, float, NestedSequence[bool|int|float], SupportsDLPack, SupportsBufferProtocol], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: Optional[bool] = None) -> Array: + +def asarray( + obj: Union[ + Array, + bool, + int, + float, + NestedSequence[bool | int | float], + SupportsDLPack, + SupportsBufferProtocol, + ], + /, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + copy: Optional[bool] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`. @@ -31,7 +55,7 @@ def asarray(obj: Union[Array, bool, int, float, NestedSequence[bool|int|float], from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") if copy is False: # Note: copy=False is not yet implemented in np.asarray @@ -40,14 +64,23 @@ def asarray(obj: Union[Array, bool, int, float, NestedSequence[bool|int|float], if copy is True: return Array._new(np.array(obj._array, copy=True, dtype=dtype)) return obj - if dtype is None and isinstance(obj, int) and (obj > 2**64 or obj < -2**63): + if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)): # Give a better error message in this case. NumPy would convert this # to an object array. TODO: This won't handle large integers in lists. raise OverflowError("Integer out of bounds for array dtypes") res = np.asarray(obj, dtype=dtype) return Array._new(res) -def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def arange( + start: Union[int, float], + /, + stop: Optional[Union[int, float]] = None, + step: Union[int, float] = 1, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.arange <numpy.arange>`. @@ -56,11 +89,17 @@ def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype)) -def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def empty( + shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.empty <numpy.empty>`. @@ -69,11 +108,14 @@ def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.empty(shape, dtype=dtype)) -def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def empty_like( + x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> Array: """ Array API compatible wrapper for :py:func:`np.empty_like <numpy.empty_like>`. @@ -82,11 +124,20 @@ def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[D from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.empty_like(x._array, dtype=dtype)) -def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: Optional[int] = 0, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def eye( + n_rows: int, + n_cols: Optional[int] = None, + /, + *, + k: Optional[int] = 0, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.eye <numpy.eye>`. @@ -95,15 +146,23 @@ def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: Optional[int] = 0, d from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype)) + def from_dlpack(x: object, /) -> Array: # Note: dlpack support is not yet implemented on Array raise NotImplementedError("DLPack support is not yet implemented") -def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def full( + shape: Union[int, Tuple[int, ...]], + fill_value: Union[int, float], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.full <numpy.full>`. @@ -112,7 +171,7 @@ def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, d from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") if isinstance(fill_value, Array) and fill_value.ndim == 0: fill_value = fill_value._array @@ -123,7 +182,15 @@ def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, d raise TypeError("Invalid input to full") return Array._new(res) -def full_like(x: Array, /, fill_value: Union[int, float], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def full_like( + x: Array, + /, + fill_value: Union[int, float], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.full_like <numpy.full_like>`. @@ -132,7 +199,7 @@ def full_like(x: Array, /, fill_value: Union[int, float], *, dtype: Optional[Dty from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = np.full_like(x._array, fill_value, dtype=dtype) if res.dtype not in _all_dtypes: @@ -141,7 +208,17 @@ def full_like(x: Array, /, fill_value: Union[int, float], *, dtype: Optional[Dty raise TypeError("Invalid input to full_like") return Array._new(res) -def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, endpoint: bool = True) -> Array: + +def linspace( + start: Union[int, float], + stop: Union[int, float], + /, + num: int, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + endpoint: bool = True, +) -> Array: """ Array API compatible wrapper for :py:func:`np.linspace <numpy.linspace>`. @@ -150,20 +227,31 @@ def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *, from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)) -def meshgrid(*arrays: Sequence[Array], indexing: str = 'xy') -> List[Array, ...]: + +def meshgrid(*arrays: Sequence[Array], indexing: str = "xy") -> List[Array, ...]: """ Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`. See its docstring for more information. """ from ._array_object import Array - return [Array._new(array) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)] -def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + return [ + Array._new(array) + for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing) + ] + + +def ones( + shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.ones <numpy.ones>`. @@ -172,11 +260,14 @@ def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, d from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.ones(shape, dtype=dtype)) -def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def ones_like( + x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> Array: """ Array API compatible wrapper for :py:func:`np.ones_like <numpy.ones_like>`. @@ -185,11 +276,17 @@ def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[De from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.ones_like(x._array, dtype=dtype)) -def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def zeros( + shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.zeros <numpy.zeros>`. @@ -198,11 +295,14 @@ def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.zeros(shape, dtype=dtype)) -def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def zeros_like( + x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> Array: """ Array API compatible wrapper for :py:func:`np.zeros_like <numpy.zeros_like>`. @@ -211,6 +311,6 @@ def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[D from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.zeros_like(x._array, dtype=dtype)) diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index 17a00cc6d..e6121a8a4 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -5,12 +5,14 @@ from ._dtypes import _all_dtypes, _result_type from dataclasses import dataclass from typing import TYPE_CHECKING, List, Tuple, Union + if TYPE_CHECKING: from ._typing import Dtype from collections.abc import Sequence import numpy as np + def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]: """ Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`. @@ -18,7 +20,11 @@ def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]: See its docstring for more information. """ from ._array_object import Array - return [Array._new(array) for array in np.broadcast_arrays(*[a._array for a in arrays])] + + return [ + Array._new(array) for array in np.broadcast_arrays(*[a._array for a in arrays]) + ] + def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: """ @@ -27,8 +33,10 @@ def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: See its docstring for more information. """ from ._array_object import Array + return Array._new(np.broadcast_to(x._array, shape)) + def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: """ Array API compatible wrapper for :py:func:`np.can_cast <numpy.can_cast>`. @@ -36,10 +44,12 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: See its docstring for more information. """ from ._array_object import Array + if isinstance(from_, Array): from_ = from_._array return np.can_cast(from_, to) + # These are internal objects for the return types of finfo and iinfo, since # the NumPy versions contain extra data that isn't part of the spec. @dataclass @@ -55,12 +65,14 @@ class finfo_object: # smallest_normal: float + @dataclass class iinfo_object: bits: int max: int min: int + def finfo(type: Union[Dtype, Array], /) -> finfo_object: """ Array API compatible wrapper for :py:func:`np.finfo <numpy.finfo>`. @@ -79,6 +91,7 @@ def finfo(type: Union[Dtype, Array], /) -> finfo_object: # float(fi.smallest_normal), ) + def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: """ Array API compatible wrapper for :py:func:`np.iinfo <numpy.iinfo>`. @@ -88,6 +101,7 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: ii = np.iinfo(type) return iinfo_object(ii.bits, ii.max, ii.min) + def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype: """ Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`. diff --git a/numpy/array_api/_dtypes.py b/numpy/array_api/_dtypes.py index 07be267da..476d619fe 100644 --- a/numpy/array_api/_dtypes.py +++ b/numpy/array_api/_dtypes.py @@ -2,34 +2,66 @@ import numpy as np # Note: we use dtype objects instead of dtype classes. The spec does not # require any behavior on dtypes other than equality. -int8 = np.dtype('int8') -int16 = np.dtype('int16') -int32 = np.dtype('int32') -int64 = np.dtype('int64') -uint8 = np.dtype('uint8') -uint16 = np.dtype('uint16') -uint32 = np.dtype('uint32') -uint64 = np.dtype('uint64') -float32 = np.dtype('float32') -float64 = np.dtype('float64') +int8 = np.dtype("int8") +int16 = np.dtype("int16") +int32 = np.dtype("int32") +int64 = np.dtype("int64") +uint8 = np.dtype("uint8") +uint16 = np.dtype("uint16") +uint32 = np.dtype("uint32") +uint64 = np.dtype("uint64") +float32 = np.dtype("float32") +float64 = np.dtype("float64") # Note: This name is changed -bool = np.dtype('bool') +bool = np.dtype("bool") -_all_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64, - float32, float64, bool) +_all_dtypes = ( + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + bool, +) _boolean_dtypes = (bool,) _floating_dtypes = (float32, float64) _integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) -_integer_or_boolean_dtypes = (bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64) -_numeric_dtypes = (float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64) +_integer_or_boolean_dtypes = ( + bool, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) +_numeric_dtypes = ( + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) _dtype_categories = { - 'all': _all_dtypes, - 'numeric': _numeric_dtypes, - 'integer': _integer_dtypes, - 'integer or boolean': _integer_or_boolean_dtypes, - 'boolean': _boolean_dtypes, - 'floating-point': _floating_dtypes, + "all": _all_dtypes, + "numeric": _numeric_dtypes, + "integer": _integer_dtypes, + "integer or boolean": _integer_or_boolean_dtypes, + "boolean": _boolean_dtypes, + "floating-point": _floating_dtypes, } @@ -104,6 +136,7 @@ _promotion_table = { (bool, bool): bool, } + def _result_type(type1, type2): if (type1, type2) in _promotion_table: return _promotion_table[type1, type2] diff --git a/numpy/array_api/_elementwise_functions.py b/numpy/array_api/_elementwise_functions.py index 7833ebe54..4408fe833 100644 --- a/numpy/array_api/_elementwise_functions.py +++ b/numpy/array_api/_elementwise_functions.py @@ -1,12 +1,18 @@ from __future__ import annotations -from ._dtypes import (_boolean_dtypes, _floating_dtypes, - _integer_dtypes, _integer_or_boolean_dtypes, - _numeric_dtypes, _result_type) +from ._dtypes import ( + _boolean_dtypes, + _floating_dtypes, + _integer_dtypes, + _integer_or_boolean_dtypes, + _numeric_dtypes, + _result_type, +) from ._array_object import Array import numpy as np + def abs(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.abs <numpy.abs>`. @@ -14,9 +20,10 @@ def abs(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in abs') + raise TypeError("Only numeric dtypes are allowed in abs") return Array._new(np.abs(x._array)) + # Note: the function name is different here def acos(x: Array, /) -> Array: """ @@ -25,9 +32,10 @@ def acos(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in acos') + raise TypeError("Only floating-point dtypes are allowed in acos") return Array._new(np.arccos(x._array)) + # Note: the function name is different here def acosh(x: Array, /) -> Array: """ @@ -36,9 +44,10 @@ def acosh(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in acosh') + raise TypeError("Only floating-point dtypes are allowed in acosh") return Array._new(np.arccosh(x._array)) + def add(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.add <numpy.add>`. @@ -46,12 +55,13 @@ def add(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in add') + raise TypeError("Only numeric dtypes are allowed in add") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.add(x1._array, x2._array)) + # Note: the function name is different here def asin(x: Array, /) -> Array: """ @@ -60,9 +70,10 @@ def asin(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in asin') + raise TypeError("Only floating-point dtypes are allowed in asin") return Array._new(np.arcsin(x._array)) + # Note: the function name is different here def asinh(x: Array, /) -> Array: """ @@ -71,9 +82,10 @@ def asinh(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in asinh') + raise TypeError("Only floating-point dtypes are allowed in asinh") return Array._new(np.arcsinh(x._array)) + # Note: the function name is different here def atan(x: Array, /) -> Array: """ @@ -82,9 +94,10 @@ def atan(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in atan') + raise TypeError("Only floating-point dtypes are allowed in atan") return Array._new(np.arctan(x._array)) + # Note: the function name is different here def atan2(x1: Array, x2: Array, /) -> Array: """ @@ -93,12 +106,13 @@ def atan2(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in atan2') + raise TypeError("Only floating-point dtypes are allowed in atan2") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.arctan2(x1._array, x2._array)) + # Note: the function name is different here def atanh(x: Array, /) -> Array: """ @@ -107,22 +121,27 @@ def atanh(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in atanh') + raise TypeError("Only floating-point dtypes are allowed in atanh") return Array._new(np.arctanh(x._array)) + def bitwise_and(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.bitwise_and <numpy.bitwise_and>`. See its docstring for more information. """ - if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: - raise TypeError('Only integer or boolean dtypes are allowed in bitwise_and') + if ( + x1.dtype not in _integer_or_boolean_dtypes + or x2.dtype not in _integer_or_boolean_dtypes + ): + raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.bitwise_and(x1._array, x2._array)) + # Note: the function name is different here def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: """ @@ -131,15 +150,16 @@ def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: - raise TypeError('Only integer dtypes are allowed in bitwise_left_shift') + raise TypeError("Only integer dtypes are allowed in bitwise_left_shift") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) # Note: bitwise_left_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): - raise ValueError('bitwise_left_shift(x1, x2) is only defined for x2 >= 0') + raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") return Array._new(np.left_shift(x1._array, x2._array)) + # Note: the function name is different here def bitwise_invert(x: Array, /) -> Array: """ @@ -148,22 +168,27 @@ def bitwise_invert(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _integer_or_boolean_dtypes: - raise TypeError('Only integer or boolean dtypes are allowed in bitwise_invert') + raise TypeError("Only integer or boolean dtypes are allowed in bitwise_invert") return Array._new(np.invert(x._array)) + def bitwise_or(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.bitwise_or <numpy.bitwise_or>`. See its docstring for more information. """ - if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: - raise TypeError('Only integer or boolean dtypes are allowed in bitwise_or') + if ( + x1.dtype not in _integer_or_boolean_dtypes + or x2.dtype not in _integer_or_boolean_dtypes + ): + raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.bitwise_or(x1._array, x2._array)) + # Note: the function name is different here def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: """ @@ -172,28 +197,33 @@ def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: - raise TypeError('Only integer dtypes are allowed in bitwise_right_shift') + raise TypeError("Only integer dtypes are allowed in bitwise_right_shift") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) # Note: bitwise_right_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): - raise ValueError('bitwise_right_shift(x1, x2) is only defined for x2 >= 0') + raise ValueError("bitwise_right_shift(x1, x2) is only defined for x2 >= 0") return Array._new(np.right_shift(x1._array, x2._array)) + def bitwise_xor(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.bitwise_xor <numpy.bitwise_xor>`. See its docstring for more information. """ - if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: - raise TypeError('Only integer or boolean dtypes are allowed in bitwise_xor') + if ( + x1.dtype not in _integer_or_boolean_dtypes + or x2.dtype not in _integer_or_boolean_dtypes + ): + raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.bitwise_xor(x1._array, x2._array)) + def ceil(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.ceil <numpy.ceil>`. @@ -201,12 +231,13 @@ def ceil(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in ceil') + raise TypeError("Only numeric dtypes are allowed in ceil") if x.dtype in _integer_dtypes: # Note: The return dtype of ceil is the same as the input return x return Array._new(np.ceil(x._array)) + def cos(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.cos <numpy.cos>`. @@ -214,9 +245,10 @@ def cos(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in cos') + raise TypeError("Only floating-point dtypes are allowed in cos") return Array._new(np.cos(x._array)) + def cosh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.cosh <numpy.cosh>`. @@ -224,9 +256,10 @@ def cosh(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in cosh') + raise TypeError("Only floating-point dtypes are allowed in cosh") return Array._new(np.cosh(x._array)) + def divide(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.divide <numpy.divide>`. @@ -234,12 +267,13 @@ def divide(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in divide') + raise TypeError("Only floating-point dtypes are allowed in divide") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.divide(x1._array, x2._array)) + def equal(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.equal <numpy.equal>`. @@ -251,6 +285,7 @@ def equal(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.equal(x1._array, x2._array)) + def exp(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.exp <numpy.exp>`. @@ -258,9 +293,10 @@ def exp(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in exp') + raise TypeError("Only floating-point dtypes are allowed in exp") return Array._new(np.exp(x._array)) + def expm1(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.expm1 <numpy.expm1>`. @@ -268,9 +304,10 @@ def expm1(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in expm1') + raise TypeError("Only floating-point dtypes are allowed in expm1") return Array._new(np.expm1(x._array)) + def floor(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.floor <numpy.floor>`. @@ -278,12 +315,13 @@ def floor(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in floor') + raise TypeError("Only numeric dtypes are allowed in floor") if x.dtype in _integer_dtypes: # Note: The return dtype of floor is the same as the input return x return Array._new(np.floor(x._array)) + def floor_divide(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.floor_divide <numpy.floor_divide>`. @@ -291,12 +329,13 @@ def floor_divide(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in floor_divide') + raise TypeError("Only numeric dtypes are allowed in floor_divide") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.floor_divide(x1._array, x2._array)) + def greater(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.greater <numpy.greater>`. @@ -304,12 +343,13 @@ def greater(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in greater') + raise TypeError("Only numeric dtypes are allowed in greater") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater(x1._array, x2._array)) + def greater_equal(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.greater_equal <numpy.greater_equal>`. @@ -317,12 +357,13 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in greater_equal') + raise TypeError("Only numeric dtypes are allowed in greater_equal") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater_equal(x1._array, x2._array)) + def isfinite(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.isfinite <numpy.isfinite>`. @@ -330,9 +371,10 @@ def isfinite(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in isfinite') + raise TypeError("Only numeric dtypes are allowed in isfinite") return Array._new(np.isfinite(x._array)) + def isinf(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.isinf <numpy.isinf>`. @@ -340,9 +382,10 @@ def isinf(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in isinf') + raise TypeError("Only numeric dtypes are allowed in isinf") return Array._new(np.isinf(x._array)) + def isnan(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.isnan <numpy.isnan>`. @@ -350,9 +393,10 @@ def isnan(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in isnan') + raise TypeError("Only numeric dtypes are allowed in isnan") return Array._new(np.isnan(x._array)) + def less(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.less <numpy.less>`. @@ -360,12 +404,13 @@ def less(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in less') + raise TypeError("Only numeric dtypes are allowed in less") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less(x1._array, x2._array)) + def less_equal(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.less_equal <numpy.less_equal>`. @@ -373,12 +418,13 @@ def less_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in less_equal') + raise TypeError("Only numeric dtypes are allowed in less_equal") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less_equal(x1._array, x2._array)) + def log(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.log <numpy.log>`. @@ -386,9 +432,10 @@ def log(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in log') + raise TypeError("Only floating-point dtypes are allowed in log") return Array._new(np.log(x._array)) + def log1p(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.log1p <numpy.log1p>`. @@ -396,9 +443,10 @@ def log1p(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in log1p') + raise TypeError("Only floating-point dtypes are allowed in log1p") return Array._new(np.log1p(x._array)) + def log2(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.log2 <numpy.log2>`. @@ -406,9 +454,10 @@ def log2(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in log2') + raise TypeError("Only floating-point dtypes are allowed in log2") return Array._new(np.log2(x._array)) + def log10(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.log10 <numpy.log10>`. @@ -416,9 +465,10 @@ def log10(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in log10') + raise TypeError("Only floating-point dtypes are allowed in log10") return Array._new(np.log10(x._array)) + def logaddexp(x1: Array, x2: Array) -> Array: """ Array API compatible wrapper for :py:func:`np.logaddexp <numpy.logaddexp>`. @@ -426,12 +476,13 @@ def logaddexp(x1: Array, x2: Array) -> Array: See its docstring for more information. """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in logaddexp') + raise TypeError("Only floating-point dtypes are allowed in logaddexp") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logaddexp(x1._array, x2._array)) + def logical_and(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.logical_and <numpy.logical_and>`. @@ -439,12 +490,13 @@ def logical_and(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError('Only boolean dtypes are allowed in logical_and') + raise TypeError("Only boolean dtypes are allowed in logical_and") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logical_and(x1._array, x2._array)) + def logical_not(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.logical_not <numpy.logical_not>`. @@ -452,9 +504,10 @@ def logical_not(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _boolean_dtypes: - raise TypeError('Only boolean dtypes are allowed in logical_not') + raise TypeError("Only boolean dtypes are allowed in logical_not") return Array._new(np.logical_not(x._array)) + def logical_or(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.logical_or <numpy.logical_or>`. @@ -462,12 +515,13 @@ def logical_or(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError('Only boolean dtypes are allowed in logical_or') + raise TypeError("Only boolean dtypes are allowed in logical_or") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logical_or(x1._array, x2._array)) + def logical_xor(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.logical_xor <numpy.logical_xor>`. @@ -475,12 +529,13 @@ def logical_xor(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError('Only boolean dtypes are allowed in logical_xor') + raise TypeError("Only boolean dtypes are allowed in logical_xor") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logical_xor(x1._array, x2._array)) + def multiply(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.multiply <numpy.multiply>`. @@ -488,12 +543,13 @@ def multiply(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in multiply') + raise TypeError("Only numeric dtypes are allowed in multiply") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.multiply(x1._array, x2._array)) + def negative(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.negative <numpy.negative>`. @@ -501,9 +557,10 @@ def negative(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in negative') + raise TypeError("Only numeric dtypes are allowed in negative") return Array._new(np.negative(x._array)) + def not_equal(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`. @@ -515,6 +572,7 @@ def not_equal(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.not_equal(x1._array, x2._array)) + def positive(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.positive <numpy.positive>`. @@ -522,9 +580,10 @@ def positive(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in positive') + raise TypeError("Only numeric dtypes are allowed in positive") return Array._new(np.positive(x._array)) + # Note: the function name is different here def pow(x1: Array, x2: Array, /) -> Array: """ @@ -533,12 +592,13 @@ def pow(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in pow') + raise TypeError("Only floating-point dtypes are allowed in pow") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.power(x1._array, x2._array)) + def remainder(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`. @@ -546,12 +606,13 @@ def remainder(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in remainder') + raise TypeError("Only numeric dtypes are allowed in remainder") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.remainder(x1._array, x2._array)) + def round(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.round <numpy.round>`. @@ -559,9 +620,10 @@ def round(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in round') + raise TypeError("Only numeric dtypes are allowed in round") return Array._new(np.round(x._array)) + def sign(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.sign <numpy.sign>`. @@ -569,9 +631,10 @@ def sign(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in sign') + raise TypeError("Only numeric dtypes are allowed in sign") return Array._new(np.sign(x._array)) + def sin(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.sin <numpy.sin>`. @@ -579,9 +642,10 @@ def sin(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in sin') + raise TypeError("Only floating-point dtypes are allowed in sin") return Array._new(np.sin(x._array)) + def sinh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.sinh <numpy.sinh>`. @@ -589,9 +653,10 @@ def sinh(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in sinh') + raise TypeError("Only floating-point dtypes are allowed in sinh") return Array._new(np.sinh(x._array)) + def square(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.square <numpy.square>`. @@ -599,9 +664,10 @@ def square(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in square') + raise TypeError("Only numeric dtypes are allowed in square") return Array._new(np.square(x._array)) + def sqrt(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.sqrt <numpy.sqrt>`. @@ -609,9 +675,10 @@ def sqrt(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in sqrt') + raise TypeError("Only floating-point dtypes are allowed in sqrt") return Array._new(np.sqrt(x._array)) + def subtract(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.subtract <numpy.subtract>`. @@ -619,12 +686,13 @@ def subtract(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in subtract') + raise TypeError("Only numeric dtypes are allowed in subtract") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.subtract(x1._array, x2._array)) + def tan(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.tan <numpy.tan>`. @@ -632,9 +700,10 @@ def tan(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in tan') + raise TypeError("Only floating-point dtypes are allowed in tan") return Array._new(np.tan(x._array)) + def tanh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.tanh <numpy.tanh>`. @@ -642,9 +711,10 @@ def tanh(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in tanh') + raise TypeError("Only floating-point dtypes are allowed in tanh") return Array._new(np.tanh(x._array)) + def trunc(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.trunc <numpy.trunc>`. @@ -652,7 +722,7 @@ def trunc(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in trunc') + raise TypeError("Only numeric dtypes are allowed in trunc") if x.dtype in _integer_dtypes: # Note: The return dtype of trunc is the same as the input return x diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py index f13f9c541..089081725 100644 --- a/numpy/array_api/_linear_algebra_functions.py +++ b/numpy/array_api/_linear_algebra_functions.py @@ -17,6 +17,7 @@ import numpy as np # """ # return np.einsum() + def matmul(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`. @@ -26,23 +27,31 @@ def matmul(x1: Array, x2: Array, /) -> Array: # Note: the restriction to numeric dtypes only is different from # np.matmul. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in matmul') + raise TypeError("Only numeric dtypes are allowed in matmul") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) return Array._new(np.matmul(x1._array, x2._array)) + # Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. -def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: +def tensordot( + x1: Array, + x2: Array, + /, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, +) -> Array: # Note: the restriction to numeric dtypes only is different from # np.tensordot. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in tensordot') + raise TypeError("Only numeric dtypes are allowed in tensordot") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) + def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`. @@ -51,6 +60,7 @@ def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array: """ return Array._new(np.transpose(x._array, axes=axes)) + # Note: vecdot is not in NumPy def vecdot(x1: Array, x2: Array, /, *, axis: Optional[int] = None) -> Array: if axis is None: diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py index 33f5d5a28..c11866261 100644 --- a/numpy/array_api/_manipulation_functions.py +++ b/numpy/array_api/_manipulation_functions.py @@ -8,7 +8,9 @@ from typing import List, Optional, Tuple, Union import numpy as np # Note: the function name is different here -def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array: +def concat( + arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0 +) -> Array: """ Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`. @@ -20,6 +22,7 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[i arrays = tuple(a._array for a in arrays) return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype)) + def expand_dims(x: Array, /, *, axis: int) -> Array: """ Array API compatible wrapper for :py:func:`np.expand_dims <numpy.expand_dims>`. @@ -28,6 +31,7 @@ def expand_dims(x: Array, /, *, axis: int) -> Array: """ return Array._new(np.expand_dims(x._array, axis)) + def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.flip <numpy.flip>`. @@ -36,6 +40,7 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> """ return Array._new(np.flip(x._array, axis=axis)) + def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array: """ Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`. @@ -44,7 +49,14 @@ def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array: """ return Array._new(np.reshape(x._array, shape)) -def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + +def roll( + x: Array, + /, + shift: Union[int, Tuple[int, ...]], + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.roll <numpy.roll>`. @@ -52,6 +64,7 @@ def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Unio """ return Array._new(np.roll(x._array, shift, axis=axis)) + def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: """ Array API compatible wrapper for :py:func:`np.squeeze <numpy.squeeze>`. @@ -60,6 +73,7 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: """ return Array._new(np.squeeze(x._array, axis=axis)) + def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array: """ Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`. diff --git a/numpy/array_api/_searching_functions.py b/numpy/array_api/_searching_functions.py index 9dcc76b2d..3dcef61c3 100644 --- a/numpy/array_api/_searching_functions.py +++ b/numpy/array_api/_searching_functions.py @@ -7,6 +7,7 @@ from typing import Optional, Tuple import numpy as np + def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: """ Array API compatible wrapper for :py:func:`np.argmax <numpy.argmax>`. @@ -15,6 +16,7 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - """ return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims))) + def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: """ Array API compatible wrapper for :py:func:`np.argmin <numpy.argmin>`. @@ -23,6 +25,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - """ return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) + def nonzero(x: Array, /) -> Tuple[Array, ...]: """ Array API compatible wrapper for :py:func:`np.nonzero <numpy.nonzero>`. @@ -31,6 +34,7 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: """ return tuple(Array._new(i) for i in np.nonzero(x._array)) + def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.where <numpy.where>`. diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py index acd59f597..357f238f5 100644 --- a/numpy/array_api/_set_functions.py +++ b/numpy/array_api/_set_functions.py @@ -6,14 +6,26 @@ from typing import Tuple, Union import numpy as np -def unique(x: Array, /, *, return_counts: bool = False, return_index: bool = False, return_inverse: bool = False) -> Union[Array, Tuple[Array, ...]]: + +def unique( + x: Array, + /, + *, + return_counts: bool = False, + return_index: bool = False, + return_inverse: bool = False, +) -> Union[Array, Tuple[Array, ...]]: """ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`. See its docstring for more information. """ - res = np.unique(x._array, return_counts=return_counts, - return_index=return_index, return_inverse=return_inverse) + res = np.unique( + x._array, + return_counts=return_counts, + return_index=return_index, + return_inverse=return_inverse, + ) if isinstance(res, tuple): return tuple(Array._new(i) for i in res) return Array._new(res) diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py index a125e0718..9cd49786c 100644 --- a/numpy/array_api/_sorting_functions.py +++ b/numpy/array_api/_sorting_functions.py @@ -4,27 +4,33 @@ from ._array_object import Array import numpy as np -def argsort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array: + +def argsort( + x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True +) -> Array: """ Array API compatible wrapper for :py:func:`np.argsort <numpy.argsort>`. See its docstring for more information. """ # Note: this keyword argument is different, and the default is different. - kind = 'stable' if stable else 'quicksort' + kind = "stable" if stable else "quicksort" res = np.argsort(x._array, axis=axis, kind=kind) if descending: res = np.flip(res, axis=axis) return Array._new(res) -def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array: + +def sort( + x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True +) -> Array: """ Array API compatible wrapper for :py:func:`np.sort <numpy.sort>`. See its docstring for more information. """ # Note: this keyword argument is different, and the default is different. - kind = 'stable' if stable else 'quicksort' + kind = "stable" if stable else "quicksort" res = np.sort(x._array, axis=axis, kind=kind) if descending: res = np.flip(res, axis=axis) diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py index a606203bc..63790b447 100644 --- a/numpy/array_api/_statistical_functions.py +++ b/numpy/array_api/_statistical_functions.py @@ -6,25 +6,76 @@ from typing import Optional, Tuple, Union import numpy as np -def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def max( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: return Array._new(np.max(x._array, axis=axis, keepdims=keepdims)) -def mean(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def mean( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims)) -def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def min( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: return Array._new(np.min(x._array, axis=axis, keepdims=keepdims)) -def prod(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def prod( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims)) -def std(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> Array: + +def std( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, + keepdims: bool = False, +) -> Array: # Note: the keyword argument correction is different here return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims)) -def sum(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def sum( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims)) -def var(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> Array: + +def var( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, + keepdims: bool = False, +) -> Array: # Note: the keyword argument correction is different here return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims)) diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 4ff718205..d530a91ae 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -6,21 +6,39 @@ annotations in the function signatures. The functions in the module are only valid for inputs that match the given type annotations. """ -__all__ = ['Array', 'Device', 'Dtype', 'SupportsDLPack', - 'SupportsBufferProtocol', 'PyCapsule'] +__all__ = [ + "Array", + "Device", + "Dtype", + "SupportsDLPack", + "SupportsBufferProtocol", + "PyCapsule", +] from typing import Any, Sequence, Type, Union -from . import (Array, int8, int16, int32, int64, uint8, uint16, uint32, - uint64, float32, float64) +from . import ( + Array, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, +) # This should really be recursive, but that isn't supported yet. See the # similar comment in numpy/typing/_array_like.py NestedSequence = Sequence[Sequence[Any]] Device = Any -Dtype = Type[Union[[int8, int16, int32, int64, uint8, uint16, - uint32, uint64, float32, float64]]] +Dtype = Type[ + Union[[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64]] +] SupportsDLPack = Any SupportsBufferProtocol = Any PyCapsule = Any diff --git a/numpy/array_api/_utility_functions.py b/numpy/array_api/_utility_functions.py index f243bfe68..5ecb4bd9f 100644 --- a/numpy/array_api/_utility_functions.py +++ b/numpy/array_api/_utility_functions.py @@ -6,7 +6,14 @@ from typing import Optional, Tuple, Union import numpy as np -def all(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def all( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: """ Array API compatible wrapper for :py:func:`np.all <numpy.all>`. @@ -14,7 +21,14 @@ def all(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep """ return Array._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims))) -def any(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def any( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: """ Array API compatible wrapper for :py:func:`np.any <numpy.any>`. diff --git a/numpy/array_api/setup.py b/numpy/array_api/setup.py index da2350c8f..c8bc29102 100644 --- a/numpy/array_api/setup.py +++ b/numpy/array_api/setup.py @@ -1,10 +1,12 @@ -def configuration(parent_package='', top_path=None): +def configuration(parent_package="", top_path=None): from numpy.distutils.misc_util import Configuration - config = Configuration('array_api', parent_package, top_path) - config.add_subpackage('tests') + + config = Configuration("array_api", parent_package, top_path) + config.add_subpackage("tests") return config -if __name__ == '__main__': +if __name__ == "__main__": from numpy.distutils.core import setup + setup(configuration=configuration) diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index 22078bbee..088e09b9f 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -2,9 +2,20 @@ from numpy.testing import assert_raises import numpy as np from .. import ones, asarray, result_type -from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, - _integer_dtypes, _integer_or_boolean_dtypes, - _numeric_dtypes, int8, int16, int32, int64, uint64) +from .._dtypes import ( + _all_dtypes, + _boolean_dtypes, + _floating_dtypes, + _integer_dtypes, + _integer_or_boolean_dtypes, + _numeric_dtypes, + int8, + int16, + int32, + int64, + uint64, +) + def test_validate_index(): # The indexing tests in the official array API test suite test that the @@ -61,28 +72,29 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[None, ...]) assert_raises(IndexError, lambda: a[..., None]) + def test_operators(): # For every operator, we test that it works for the required type # combinations and raises TypeError otherwise - binary_op_dtypes ={ - '__add__': 'numeric', - '__and__': 'integer_or_boolean', - '__eq__': 'all', - '__floordiv__': 'numeric', - '__ge__': 'numeric', - '__gt__': 'numeric', - '__le__': 'numeric', - '__lshift__': 'integer', - '__lt__': 'numeric', - '__mod__': 'numeric', - '__mul__': 'numeric', - '__ne__': 'all', - '__or__': 'integer_or_boolean', - '__pow__': 'floating', - '__rshift__': 'integer', - '__sub__': 'numeric', - '__truediv__': 'floating', - '__xor__': 'integer_or_boolean', + binary_op_dtypes = { + "__add__": "numeric", + "__and__": "integer_or_boolean", + "__eq__": "all", + "__floordiv__": "numeric", + "__ge__": "numeric", + "__gt__": "numeric", + "__le__": "numeric", + "__lshift__": "integer", + "__lt__": "numeric", + "__mod__": "numeric", + "__mul__": "numeric", + "__ne__": "all", + "__or__": "integer_or_boolean", + "__pow__": "floating", + "__rshift__": "integer", + "__sub__": "numeric", + "__truediv__": "floating", + "__xor__": "integer_or_boolean", } # Recompute each time because of in-place ops @@ -92,15 +104,15 @@ def test_operators(): for d in _boolean_dtypes: yield asarray(False, dtype=d) for d in _floating_dtypes: - yield asarray(1., dtype=d) + yield asarray(1.0, dtype=d) for op, dtypes in binary_op_dtypes.items(): ops = [op] - if op not in ['__eq__', '__ne__', '__le__', '__ge__', '__lt__', '__gt__']: - rop = '__r' + op[2:] - iop = '__i' + op[2:] + if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]: + rop = "__r" + op[2:] + iop = "__i" + op[2:] ops += [rop, iop] - for s in [1, 1., False]: + for s in [1, 1.0, False]: for _op in ops: for a in _array_vals(): # Test array op scalar. From the spec, the following combinations @@ -149,7 +161,10 @@ def test_operators(): ): assert_raises(TypeError, lambda: getattr(x, _op)(y)) # Ensure in-place operators only promote to the same dtype as the left operand. - elif _op.startswith('__i') and result_type(x.dtype, y.dtype) != x.dtype: + elif ( + _op.startswith("__i") + and result_type(x.dtype, y.dtype) != x.dtype + ): assert_raises(TypeError, lambda: getattr(x, _op)(y)) # Ensure only those dtypes that are required for every operator are allowed. elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes @@ -165,17 +180,20 @@ def test_operators(): else: assert_raises(TypeError, lambda: getattr(x, _op)(y)) - unary_op_dtypes ={ - '__abs__': 'numeric', - '__invert__': 'integer_or_boolean', - '__neg__': 'numeric', - '__pos__': 'numeric', + unary_op_dtypes = { + "__abs__": "numeric", + "__invert__": "integer_or_boolean", + "__neg__": "numeric", + "__pos__": "numeric", } for op, dtypes in unary_op_dtypes.items(): for a in _array_vals(): - if (dtypes == "numeric" and a.dtype in _numeric_dtypes - or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes - ): + if ( + dtypes == "numeric" + and a.dtype in _numeric_dtypes + or dtypes == "integer_or_boolean" + and a.dtype in _integer_or_boolean_dtypes + ): # Only test for no error getattr(a, op)() else: @@ -192,8 +210,8 @@ def test_operators(): yield ones((4, 4), dtype=d) # Scalars always error - for _op in ['__matmul__', '__rmatmul__', '__imatmul__']: - for s in [1, 1., False]: + for _op in ["__matmul__", "__rmatmul__", "__imatmul__"]: + for s in [1, 1.0, False]: for a in _matmul_array_vals(): if (type(s) in [float, int] and a.dtype in _floating_dtypes or type(s) == int and a.dtype in _integer_dtypes): @@ -235,16 +253,17 @@ def test_operators(): else: x.__imatmul__(y) + def test_python_scalar_construtors(): a = asarray(False) b = asarray(0) - c = asarray(0.) + c = asarray(0.0) assert bool(a) == bool(b) == bool(c) == False assert int(a) == int(b) == int(c) == 0 - assert float(a) == float(b) == float(c) == 0. + assert float(a) == float(b) == float(c) == 0.0 # bool/int/float should only be allowed on 0-D arrays. assert_raises(TypeError, lambda: bool(asarray([False]))) assert_raises(TypeError, lambda: int(asarray([0]))) - assert_raises(TypeError, lambda: float(asarray([0.]))) + assert_raises(TypeError, lambda: float(asarray([0.0]))) diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index 654f1d9b3..3cb8865cd 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -2,26 +2,53 @@ from numpy.testing import assert_raises import numpy as np from .. import all -from .._creation_functions import (asarray, arange, empty, empty_like, eye, from_dlpack, full, full_like, linspace, meshgrid, ones, ones_like, zeros, zeros_like) +from .._creation_functions import ( + asarray, + arange, + empty, + empty_like, + eye, + from_dlpack, + full, + full_like, + linspace, + meshgrid, + ones, + ones_like, + zeros, + zeros_like, +) from .._array_object import Array -from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, - _integer_dtypes, _integer_or_boolean_dtypes, - _numeric_dtypes, int8, int16, int32, int64, uint64) +from .._dtypes import ( + _all_dtypes, + _boolean_dtypes, + _floating_dtypes, + _integer_dtypes, + _integer_or_boolean_dtypes, + _numeric_dtypes, + int8, + int16, + int32, + int64, + uint64, +) + def test_asarray_errors(): # Test various protections against incorrect usage assert_raises(TypeError, lambda: Array([1])) - assert_raises(TypeError, lambda: asarray(['a'])) - assert_raises(ValueError, lambda: asarray([1.], dtype=np.float16)) + assert_raises(TypeError, lambda: asarray(["a"])) + assert_raises(ValueError, lambda: asarray([1.0], dtype=np.float16)) assert_raises(OverflowError, lambda: asarray(2**100)) # Preferably this would be OverflowError # assert_raises(OverflowError, lambda: asarray([2**100])) assert_raises(TypeError, lambda: asarray([2**100])) - asarray([1], device='cpu') # Doesn't error - assert_raises(ValueError, lambda: asarray([1], device='gpu')) + asarray([1], device="cpu") # Doesn't error + assert_raises(ValueError, lambda: asarray([1], device="gpu")) assert_raises(ValueError, lambda: asarray([1], dtype=int)) - assert_raises(ValueError, lambda: asarray([1], dtype='i')) + assert_raises(ValueError, lambda: asarray([1], dtype="i")) + def test_asarray_copy(): a = asarray([1]) @@ -36,68 +63,79 @@ def test_asarray_copy(): # assert all(b[0] == 0) assert_raises(NotImplementedError, lambda: asarray(a, copy=False)) + def test_arange_errors(): - arange(1, device='cpu') # Doesn't error - assert_raises(ValueError, lambda: arange(1, device='gpu')) + arange(1, device="cpu") # Doesn't error + assert_raises(ValueError, lambda: arange(1, device="gpu")) assert_raises(ValueError, lambda: arange(1, dtype=int)) - assert_raises(ValueError, lambda: arange(1, dtype='i')) + assert_raises(ValueError, lambda: arange(1, dtype="i")) + def test_empty_errors(): - empty((1,), device='cpu') # Doesn't error - assert_raises(ValueError, lambda: empty((1,), device='gpu')) + empty((1,), device="cpu") # Doesn't error + assert_raises(ValueError, lambda: empty((1,), device="gpu")) assert_raises(ValueError, lambda: empty((1,), dtype=int)) - assert_raises(ValueError, lambda: empty((1,), dtype='i')) + assert_raises(ValueError, lambda: empty((1,), dtype="i")) + def test_empty_like_errors(): - empty_like(asarray(1), device='cpu') # Doesn't error - assert_raises(ValueError, lambda: empty_like(asarray(1), device='gpu')) + empty_like(asarray(1), device="cpu") # Doesn't error + assert_raises(ValueError, lambda: empty_like(asarray(1), device="gpu")) assert_raises(ValueError, lambda: empty_like(asarray(1), dtype=int)) - assert_raises(ValueError, lambda: empty_like(asarray(1), dtype='i')) + assert_raises(ValueError, lambda: empty_like(asarray(1), dtype="i")) + def test_eye_errors(): - eye(1, device='cpu') # Doesn't error - assert_raises(ValueError, lambda: eye(1, device='gpu')) + eye(1, device="cpu") # Doesn't error + assert_raises(ValueError, lambda: eye(1, device="gpu")) assert_raises(ValueError, lambda: eye(1, dtype=int)) - assert_raises(ValueError, lambda: eye(1, dtype='i')) + assert_raises(ValueError, lambda: eye(1, dtype="i")) + def test_full_errors(): - full((1,), 0, device='cpu') # Doesn't error - assert_raises(ValueError, lambda: full((1,), 0, device='gpu')) + full((1,), 0, device="cpu") # Doesn't error + assert_raises(ValueError, lambda: full((1,), 0, device="gpu")) assert_raises(ValueError, lambda: full((1,), 0, dtype=int)) - assert_raises(ValueError, lambda: full((1,), 0, dtype='i')) + assert_raises(ValueError, lambda: full((1,), 0, dtype="i")) + def test_full_like_errors(): - full_like(asarray(1), 0, device='cpu') # Doesn't error - assert_raises(ValueError, lambda: full_like(asarray(1), 0, device='gpu')) + full_like(asarray(1), 0, device="cpu") # Doesn't error + assert_raises(ValueError, lambda: full_like(asarray(1), 0, device="gpu")) assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype=int)) - assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype='i')) + assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype="i")) + def test_linspace_errors(): - linspace(0, 1, 10, device='cpu') # Doesn't error - assert_raises(ValueError, lambda: linspace(0, 1, 10, device='gpu')) + linspace(0, 1, 10, device="cpu") # Doesn't error + assert_raises(ValueError, lambda: linspace(0, 1, 10, device="gpu")) assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype=float)) - assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype='f')) + assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype="f")) + def test_ones_errors(): - ones((1,), device='cpu') # Doesn't error - assert_raises(ValueError, lambda: ones((1,), device='gpu')) + ones((1,), device="cpu") # Doesn't error + assert_raises(ValueError, lambda: ones((1,), device="gpu")) assert_raises(ValueError, lambda: ones((1,), dtype=int)) - assert_raises(ValueError, lambda: ones((1,), dtype='i')) + assert_raises(ValueError, lambda: ones((1,), dtype="i")) + def test_ones_like_errors(): - ones_like(asarray(1), device='cpu') # Doesn't error - assert_raises(ValueError, lambda: ones_like(asarray(1), device='gpu')) + ones_like(asarray(1), device="cpu") # Doesn't error + assert_raises(ValueError, lambda: ones_like(asarray(1), device="gpu")) assert_raises(ValueError, lambda: ones_like(asarray(1), dtype=int)) - assert_raises(ValueError, lambda: ones_like(asarray(1), dtype='i')) + assert_raises(ValueError, lambda: ones_like(asarray(1), dtype="i")) + def test_zeros_errors(): - zeros((1,), device='cpu') # Doesn't error - assert_raises(ValueError, lambda: zeros((1,), device='gpu')) + zeros((1,), device="cpu") # Doesn't error + assert_raises(ValueError, lambda: zeros((1,), device="gpu")) assert_raises(ValueError, lambda: zeros((1,), dtype=int)) - assert_raises(ValueError, lambda: zeros((1,), dtype='i')) + assert_raises(ValueError, lambda: zeros((1,), dtype="i")) + def test_zeros_like_errors(): - zeros_like(asarray(1), device='cpu') # Doesn't error - assert_raises(ValueError, lambda: zeros_like(asarray(1), device='gpu')) + zeros_like(asarray(1), device="cpu") # Doesn't error + assert_raises(ValueError, lambda: zeros_like(asarray(1), device="gpu")) assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int)) - assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype='i')) + assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype="i")) diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py index ec76cb7a7..a9274aec9 100644 --- a/numpy/array_api/tests/test_elementwise_functions.py +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -4,74 +4,80 @@ from numpy.testing import assert_raises from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift -from .._dtypes import (_dtype_categories, _boolean_dtypes, _floating_dtypes, - _integer_dtypes) +from .._dtypes import ( + _dtype_categories, + _boolean_dtypes, + _floating_dtypes, + _integer_dtypes, +) + def nargs(func): return len(getfullargspec(func).args) + def test_function_types(): # Test that every function accepts only the required input types. We only # test the negative cases here (error). The positive cases are tested in # the array API test suite. elementwise_function_input_types = { - 'abs': 'numeric', - 'acos': 'floating-point', - 'acosh': 'floating-point', - 'add': 'numeric', - 'asin': 'floating-point', - 'asinh': 'floating-point', - 'atan': 'floating-point', - 'atan2': 'floating-point', - 'atanh': 'floating-point', - 'bitwise_and': 'integer or boolean', - 'bitwise_invert': 'integer or boolean', - 'bitwise_left_shift': 'integer', - 'bitwise_or': 'integer or boolean', - 'bitwise_right_shift': 'integer', - 'bitwise_xor': 'integer or boolean', - 'ceil': 'numeric', - 'cos': 'floating-point', - 'cosh': 'floating-point', - 'divide': 'floating-point', - 'equal': 'all', - 'exp': 'floating-point', - 'expm1': 'floating-point', - 'floor': 'numeric', - 'floor_divide': 'numeric', - 'greater': 'numeric', - 'greater_equal': 'numeric', - 'isfinite': 'numeric', - 'isinf': 'numeric', - 'isnan': 'numeric', - 'less': 'numeric', - 'less_equal': 'numeric', - 'log': 'floating-point', - 'logaddexp': 'floating-point', - 'log10': 'floating-point', - 'log1p': 'floating-point', - 'log2': 'floating-point', - 'logical_and': 'boolean', - 'logical_not': 'boolean', - 'logical_or': 'boolean', - 'logical_xor': 'boolean', - 'multiply': 'numeric', - 'negative': 'numeric', - 'not_equal': 'all', - 'positive': 'numeric', - 'pow': 'floating-point', - 'remainder': 'numeric', - 'round': 'numeric', - 'sign': 'numeric', - 'sin': 'floating-point', - 'sinh': 'floating-point', - 'sqrt': 'floating-point', - 'square': 'numeric', - 'subtract': 'numeric', - 'tan': 'floating-point', - 'tanh': 'floating-point', - 'trunc': 'numeric', + "abs": "numeric", + "acos": "floating-point", + "acosh": "floating-point", + "add": "numeric", + "asin": "floating-point", + "asinh": "floating-point", + "atan": "floating-point", + "atan2": "floating-point", + "atanh": "floating-point", + "bitwise_and": "integer or boolean", + "bitwise_invert": "integer or boolean", + "bitwise_left_shift": "integer", + "bitwise_or": "integer or boolean", + "bitwise_right_shift": "integer", + "bitwise_xor": "integer or boolean", + "ceil": "numeric", + "cos": "floating-point", + "cosh": "floating-point", + "divide": "floating-point", + "equal": "all", + "exp": "floating-point", + "expm1": "floating-point", + "floor": "numeric", + "floor_divide": "numeric", + "greater": "numeric", + "greater_equal": "numeric", + "isfinite": "numeric", + "isinf": "numeric", + "isnan": "numeric", + "less": "numeric", + "less_equal": "numeric", + "log": "floating-point", + "logaddexp": "floating-point", + "log10": "floating-point", + "log1p": "floating-point", + "log2": "floating-point", + "logical_and": "boolean", + "logical_not": "boolean", + "logical_or": "boolean", + "logical_xor": "boolean", + "multiply": "numeric", + "negative": "numeric", + "not_equal": "all", + "positive": "numeric", + "pow": "floating-point", + "remainder": "numeric", + "round": "numeric", + "sign": "numeric", + "sin": "floating-point", + "sinh": "floating-point", + "sqrt": "floating-point", + "square": "numeric", + "subtract": "numeric", + "tan": "floating-point", + "tanh": "floating-point", + "trunc": "numeric", } def _array_vals(): @@ -80,7 +86,7 @@ def test_function_types(): for d in _boolean_dtypes: yield asarray(False, dtype=d) for d in _floating_dtypes: - yield asarray(1., dtype=d) + yield asarray(1.0, dtype=d) for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): @@ -94,7 +100,12 @@ def test_function_types(): if x.dtype not in dtypes: assert_raises(TypeError, lambda: func(x)) + def test_bitwise_shift_error(): # bitwise shift functions should raise when the second argument is negative - assert_raises(ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1]))) - assert_raises(ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))) + assert_raises( + ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1])) + ) + assert_raises( + ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1])) + ) |