diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-08-06 18:22:00 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-08-06 18:23:04 -0600 |
commit | 8f7d00ed447174d9398af3365709222b529c1cad (patch) | |
tree | 9de0a3a757a8c8a7393787ee1449e087c284d6e1 /numpy/array_api/_array_object.py | |
parent | 21923a5fa71bfadf7dee0bb5b110cc2a5719eaac (diff) | |
download | numpy-8f7d00ed447174d9398af3365709222b529c1cad.tar.gz |
Run (selective) black on the array_api submodule
I've omitted a few changes from black that messed up the readability of some
complicated if statements that were organized logically line-by-line, and some
changes that use unnecessary operator spacing.
Diffstat (limited to 'numpy/array_api/_array_object.py')
-rw-r--r-- | numpy/array_api/_array_object.py | 205 |
1 files changed, 130 insertions, 75 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 00f50eade..0f511a577 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -18,11 +18,19 @@ from __future__ import annotations import operator from enum import IntEnum from ._creation_functions import asarray -from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes, - _integer_or_boolean_dtypes, _floating_dtypes, - _numeric_dtypes, _result_type, _dtype_categories) +from ._dtypes import ( + _all_dtypes, + _boolean_dtypes, + _integer_dtypes, + _integer_or_boolean_dtypes, + _floating_dtypes, + _numeric_dtypes, + _result_type, + _dtype_categories, +) from typing import TYPE_CHECKING, Optional, Tuple, Union + if TYPE_CHECKING: from ._typing import PyCapsule, Device, Dtype @@ -30,6 +38,7 @@ import numpy as np from numpy import array_api + class Array: """ n-d array object for the array API namespace. @@ -45,6 +54,7 @@ class Array: functions, such as asarray(). """ + # Use a custom constructor instead of __init__, as manually initializing # this class is not supported API. @classmethod @@ -64,13 +74,17 @@ class Array: # Convert the array scalar to a 0-D array x = np.asarray(x) if x.dtype not in _all_dtypes: - raise TypeError(f"The array_api namespace does not support the dtype '{x.dtype}'") + raise TypeError( + f"The array_api namespace does not support the dtype '{x.dtype}'" + ) obj._array = x return obj # Prevent Array() from working def __new__(cls, *args, **kwargs): - raise TypeError("The array_api Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead.") + raise TypeError( + "The array_api Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead." + ) # These functions are not required by the spec, but are implemented for # the sake of usability. @@ -79,7 +93,7 @@ class Array: """ Performs the operation __str__. """ - return self._array.__str__().replace('array', 'Array') + return self._array.__str__().replace("array", "Array") def __repr__(self: Array, /) -> str: """ @@ -103,12 +117,12 @@ class Array: """ if self.dtype not in _dtype_categories[dtype_category]: - raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) elif isinstance(other, Array): if other.dtype not in _dtype_categories[dtype_category]: - raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") else: return NotImplemented @@ -116,7 +130,7 @@ class Array: # to promote in the spec (even if the NumPy array operator would # promote them). res_dtype = _result_type(self.dtype, other.dtype) - if op.startswith('__i'): + if op.startswith("__i"): # Note: NumPy will allow in-place operators in some cases where # the type promoted operator does not match the left-hand side # operand. For example, @@ -126,7 +140,9 @@ class Array: # The spec explicitly disallows this. if res_dtype != self.dtype: - raise TypeError(f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}") + raise TypeError( + f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}" + ) return other @@ -142,13 +158,19 @@ class Array: """ if isinstance(scalar, bool): if self.dtype not in _boolean_dtypes: - raise TypeError("Python bool scalars can only be promoted with bool arrays") + raise TypeError( + "Python bool scalars can only be promoted with bool arrays" + ) elif isinstance(scalar, int): if self.dtype in _boolean_dtypes: - raise TypeError("Python int scalars cannot be promoted with bool arrays") + raise TypeError( + "Python int scalars cannot be promoted with bool arrays" + ) elif isinstance(scalar, float): if self.dtype not in _floating_dtypes: - raise TypeError("Python float scalars can only be promoted with floating-point arrays.") + raise TypeError( + "Python float scalars can only be promoted with floating-point arrays." + ) else: raise TypeError("'scalar' must be a Python scalar") @@ -253,7 +275,9 @@ class Array: except TypeError: return key if not (-size <= key.start <= max(0, size - 1)): - raise IndexError("Slices with out-of-bounds start are not allowed in the array API namespace") + raise IndexError( + "Slices with out-of-bounds start are not allowed in the array API namespace" + ) if key.stop is not None: try: operator.index(key.stop) @@ -269,12 +293,20 @@ class Array: key = tuple(Array._validate_index(idx, None) for idx in key) for idx in key: - if isinstance(idx, np.ndarray) and idx.dtype in _boolean_dtypes or isinstance(idx, (bool, np.bool_)): + if ( + isinstance(idx, np.ndarray) + and idx.dtype in _boolean_dtypes + or isinstance(idx, (bool, np.bool_)) + ): if len(key) == 1: return key - raise IndexError("Boolean array indices combined with other indices are not allowed in the array API namespace") + raise IndexError( + "Boolean array indices combined with other indices are not allowed in the array API namespace" + ) if isinstance(idx, tuple): - raise IndexError("Nested tuple indices are not allowed in the array API namespace") + raise IndexError( + "Nested tuple indices are not allowed in the array API namespace" + ) if shape is None: return key @@ -283,7 +315,9 @@ class Array: return key ellipsis_i = key.index(...) if n_ellipsis else len(key) - for idx, size in list(zip(key[:ellipsis_i], shape)) + list(zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])): + for idx, size in list(zip(key[:ellipsis_i], shape)) + list( + zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1]) + ): Array._validate_index(idx, (size,)) return key elif isinstance(key, bool): @@ -291,18 +325,24 @@ class Array: elif isinstance(key, Array): if key.dtype in _integer_dtypes: if key.ndim != 0: - raise IndexError("Non-zero dimensional integer array indices are not allowed in the array API namespace") + raise IndexError( + "Non-zero dimensional integer array indices are not allowed in the array API namespace" + ) return key._array elif key is Ellipsis: return key elif key is None: - raise IndexError("newaxis indices are not allowed in the array API namespace") + raise IndexError( + "newaxis indices are not allowed in the array API namespace" + ) try: return operator.index(key) except TypeError: # Note: This also omits boolean arrays that are not already in # Array() form, like a list of booleans. - raise IndexError("Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace") + raise IndexError( + "Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace" + ) # Everything below this line is required by the spec. @@ -311,7 +351,7 @@ class Array: Performs the operation __abs__. """ if self.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in __abs__') + raise TypeError("Only numeric dtypes are allowed in __abs__") res = self._array.__abs__() return self.__class__._new(res) @@ -319,7 +359,7 @@ class Array: """ Performs the operation __add__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__add__') + other = self._check_allowed_dtypes(other, "numeric", "__add__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -330,15 +370,17 @@ class Array: """ Performs the operation __and__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__and__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__and__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__and__(other._array) return self.__class__._new(res) - def __array_namespace__(self: Array, /, *, api_version: Optional[str] = None) -> object: - if api_version is not None and not api_version.startswith('2021.'): + def __array_namespace__( + self: Array, /, *, api_version: Optional[str] = None + ) -> object: + if api_version is not None and not api_version.startswith("2021."): raise ValueError(f"Unrecognized array API version: {api_version!r}") return array_api @@ -373,7 +415,7 @@ class Array: """ # Even though "all" dtypes are allowed, we still require them to be # promotable with each other. - other = self._check_allowed_dtypes(other, 'all', '__eq__') + other = self._check_allowed_dtypes(other, "all", "__eq__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -394,7 +436,7 @@ class Array: """ Performs the operation __floordiv__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__floordiv__') + other = self._check_allowed_dtypes(other, "numeric", "__floordiv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -405,14 +447,20 @@ class Array: """ Performs the operation __ge__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__ge__') + other = self._check_allowed_dtypes(other, "numeric", "__ge__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__ge__(other._array) return self.__class__._new(res) - def __getitem__(self: Array, key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], /) -> Array: + def __getitem__( + self: Array, + key: Union[ + int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array + ], + /, + ) -> Array: """ Performs the operation __getitem__. """ @@ -426,7 +474,7 @@ class Array: """ Performs the operation __gt__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__gt__') + other = self._check_allowed_dtypes(other, "numeric", "__gt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -448,7 +496,7 @@ class Array: Performs the operation __invert__. """ if self.dtype not in _integer_or_boolean_dtypes: - raise TypeError('Only integer or boolean dtypes are allowed in __invert__') + raise TypeError("Only integer or boolean dtypes are allowed in __invert__") res = self._array.__invert__() return self.__class__._new(res) @@ -456,7 +504,7 @@ class Array: """ Performs the operation __le__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__le__') + other = self._check_allowed_dtypes(other, "numeric", "__le__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -474,7 +522,7 @@ class Array: """ Performs the operation __lshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__lshift__') + other = self._check_allowed_dtypes(other, "integer", "__lshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -485,7 +533,7 @@ class Array: """ Performs the operation __lt__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__lt__') + other = self._check_allowed_dtypes(other, "numeric", "__lt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -498,7 +546,7 @@ class Array: """ # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. - other = self._check_allowed_dtypes(other, 'numeric', '__matmul__') + other = self._check_allowed_dtypes(other, "numeric", "__matmul__") if other is NotImplemented: return other res = self._array.__matmul__(other._array) @@ -508,7 +556,7 @@ class Array: """ Performs the operation __mod__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__mod__') + other = self._check_allowed_dtypes(other, "numeric", "__mod__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -519,7 +567,7 @@ class Array: """ Performs the operation __mul__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__mul__') + other = self._check_allowed_dtypes(other, "numeric", "__mul__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -530,7 +578,7 @@ class Array: """ Performs the operation __ne__. """ - other = self._check_allowed_dtypes(other, 'all', '__ne__') + other = self._check_allowed_dtypes(other, "all", "__ne__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -542,7 +590,7 @@ class Array: Performs the operation __neg__. """ if self.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in __neg__') + raise TypeError("Only numeric dtypes are allowed in __neg__") res = self._array.__neg__() return self.__class__._new(res) @@ -550,7 +598,7 @@ class Array: """ Performs the operation __or__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__or__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__or__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -562,7 +610,7 @@ class Array: Performs the operation __pos__. """ if self.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in __pos__') + raise TypeError("Only numeric dtypes are allowed in __pos__") res = self._array.__pos__() return self.__class__._new(res) @@ -574,7 +622,7 @@ class Array: """ from ._elementwise_functions import pow - other = self._check_allowed_dtypes(other, 'floating-point', '__pow__') + other = self._check_allowed_dtypes(other, "floating-point", "__pow__") if other is NotImplemented: return other # Note: NumPy's __pow__ does not follow type promotion rules for 0-d @@ -585,14 +633,21 @@ class Array: """ Performs the operation __rshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__rshift__') + other = self._check_allowed_dtypes(other, "integer", "__rshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__rshift__(other._array) return self.__class__._new(res) - def __setitem__(self, key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], value: Union[int, float, bool, Array], /) -> Array: + def __setitem__( + self, + key: Union[ + int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array + ], + value: Union[int, float, bool, Array], + /, + ) -> Array: """ Performs the operation __setitem__. """ @@ -605,7 +660,7 @@ class Array: """ Performs the operation __sub__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__sub__') + other = self._check_allowed_dtypes(other, "numeric", "__sub__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -618,7 +673,7 @@ class Array: """ Performs the operation __truediv__. """ - other = self._check_allowed_dtypes(other, 'floating-point', '__truediv__') + other = self._check_allowed_dtypes(other, "floating-point", "__truediv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -629,7 +684,7 @@ class Array: """ Performs the operation __xor__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__xor__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -640,7 +695,7 @@ class Array: """ Performs the operation __iadd__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__iadd__') + other = self._check_allowed_dtypes(other, "numeric", "__iadd__") if other is NotImplemented: return other self._array.__iadd__(other._array) @@ -650,7 +705,7 @@ class Array: """ Performs the operation __radd__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__radd__') + other = self._check_allowed_dtypes(other, "numeric", "__radd__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -661,7 +716,7 @@ class Array: """ Performs the operation __iand__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__iand__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__") if other is NotImplemented: return other self._array.__iand__(other._array) @@ -671,7 +726,7 @@ class Array: """ Performs the operation __rand__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__rand__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -682,7 +737,7 @@ class Array: """ Performs the operation __ifloordiv__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__ifloordiv__') + other = self._check_allowed_dtypes(other, "numeric", "__ifloordiv__") if other is NotImplemented: return other self._array.__ifloordiv__(other._array) @@ -692,7 +747,7 @@ class Array: """ Performs the operation __rfloordiv__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__rfloordiv__') + other = self._check_allowed_dtypes(other, "numeric", "__rfloordiv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -703,7 +758,7 @@ class Array: """ Performs the operation __ilshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__ilshift__') + other = self._check_allowed_dtypes(other, "integer", "__ilshift__") if other is NotImplemented: return other self._array.__ilshift__(other._array) @@ -713,7 +768,7 @@ class Array: """ Performs the operation __rlshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__rlshift__') + other = self._check_allowed_dtypes(other, "integer", "__rlshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -728,7 +783,7 @@ class Array: # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. - other = self._check_allowed_dtypes(other, 'numeric', '__imatmul__') + other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") if other is NotImplemented: return other @@ -748,7 +803,7 @@ class Array: """ # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. - other = self._check_allowed_dtypes(other, 'numeric', '__rmatmul__') + other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__") if other is NotImplemented: return other res = self._array.__rmatmul__(other._array) @@ -758,7 +813,7 @@ class Array: """ Performs the operation __imod__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__imod__') + other = self._check_allowed_dtypes(other, "numeric", "__imod__") if other is NotImplemented: return other self._array.__imod__(other._array) @@ -768,7 +823,7 @@ class Array: """ Performs the operation __rmod__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__rmod__') + other = self._check_allowed_dtypes(other, "numeric", "__rmod__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -779,7 +834,7 @@ class Array: """ Performs the operation __imul__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__imul__') + other = self._check_allowed_dtypes(other, "numeric", "__imul__") if other is NotImplemented: return other self._array.__imul__(other._array) @@ -789,7 +844,7 @@ class Array: """ Performs the operation __rmul__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__rmul__') + other = self._check_allowed_dtypes(other, "numeric", "__rmul__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -800,7 +855,7 @@ class Array: """ Performs the operation __ior__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__ior__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__") if other is NotImplemented: return other self._array.__ior__(other._array) @@ -810,7 +865,7 @@ class Array: """ Performs the operation __ror__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__ror__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -821,7 +876,7 @@ class Array: """ Performs the operation __ipow__. """ - other = self._check_allowed_dtypes(other, 'floating-point', '__ipow__') + other = self._check_allowed_dtypes(other, "floating-point", "__ipow__") if other is NotImplemented: return other self._array.__ipow__(other._array) @@ -833,7 +888,7 @@ class Array: """ from ._elementwise_functions import pow - other = self._check_allowed_dtypes(other, 'floating-point', '__rpow__') + other = self._check_allowed_dtypes(other, "floating-point", "__rpow__") if other is NotImplemented: return other # Note: NumPy's __pow__ does not follow the spec type promotion rules @@ -844,7 +899,7 @@ class Array: """ Performs the operation __irshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__irshift__') + other = self._check_allowed_dtypes(other, "integer", "__irshift__") if other is NotImplemented: return other self._array.__irshift__(other._array) @@ -854,7 +909,7 @@ class Array: """ Performs the operation __rrshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__rrshift__') + other = self._check_allowed_dtypes(other, "integer", "__rrshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -865,7 +920,7 @@ class Array: """ Performs the operation __isub__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__isub__') + other = self._check_allowed_dtypes(other, "numeric", "__isub__") if other is NotImplemented: return other self._array.__isub__(other._array) @@ -875,7 +930,7 @@ class Array: """ Performs the operation __rsub__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__rsub__') + other = self._check_allowed_dtypes(other, "numeric", "__rsub__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -886,7 +941,7 @@ class Array: """ Performs the operation __itruediv__. """ - other = self._check_allowed_dtypes(other, 'floating-point', '__itruediv__') + other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__") if other is NotImplemented: return other self._array.__itruediv__(other._array) @@ -896,7 +951,7 @@ class Array: """ Performs the operation __rtruediv__. """ - other = self._check_allowed_dtypes(other, 'floating-point', '__rtruediv__') + other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -907,7 +962,7 @@ class Array: """ Performs the operation __ixor__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__ixor__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__") if other is NotImplemented: return other self._array.__ixor__(other._array) @@ -917,7 +972,7 @@ class Array: """ Performs the operation __rxor__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__rxor__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -935,7 +990,7 @@ class Array: @property def device(self) -> Device: - return 'cpu' + return "cpu" @property def ndim(self) -> int: |