From 6e57d829cb6628610e163524f203245b247a2839 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 4 Aug 2021 16:47:05 -0600 Subject: Rename numpy._array_api to numpy.array_api Instead of the leading underscore, the experimentalness of the module will be indicated by omitting a warning on import. That we, we do not have to change the API from underscore to no underscore when the module is no longer experimental. --- numpy/array_api/__init__.py | 171 ++++ numpy/array_api/_array_object.py | 983 +++++++++++++++++++++ numpy/array_api/_constants.py | 6 + numpy/array_api/_creation_functions.py | 216 +++++ numpy/array_api/_data_type_functions.py | 117 +++ numpy/array_api/_dtypes.py | 100 +++ numpy/array_api/_elementwise_functions.py | 659 ++++++++++++++ numpy/array_api/_linear_algebra_functions.py | 58 ++ numpy/array_api/_manipulation_functions.py | 71 ++ numpy/array_api/_searching_functions.py | 44 + numpy/array_api/_set_functions.py | 15 + numpy/array_api/_sorting_functions.py | 31 + numpy/array_api/_statistical_functions.py | 30 + numpy/array_api/_typing.py | 26 + numpy/array_api/_utility_functions.py | 23 + numpy/array_api/tests/__init__.py | 7 + numpy/array_api/tests/test_array_object.py | 250 ++++++ numpy/array_api/tests/test_creation_functions.py | 103 +++ .../array_api/tests/test_elementwise_functions.py | 110 +++ 19 files changed, 3020 insertions(+) create mode 100644 numpy/array_api/__init__.py create mode 100644 numpy/array_api/_array_object.py create mode 100644 numpy/array_api/_constants.py create mode 100644 numpy/array_api/_creation_functions.py create mode 100644 numpy/array_api/_data_type_functions.py create mode 100644 numpy/array_api/_dtypes.py create mode 100644 numpy/array_api/_elementwise_functions.py create mode 100644 numpy/array_api/_linear_algebra_functions.py create mode 100644 numpy/array_api/_manipulation_functions.py create mode 100644 numpy/array_api/_searching_functions.py create mode 100644 numpy/array_api/_set_functions.py create mode 100644 numpy/array_api/_sorting_functions.py create mode 100644 numpy/array_api/_statistical_functions.py create mode 100644 numpy/array_api/_typing.py create mode 100644 numpy/array_api/_utility_functions.py create mode 100644 numpy/array_api/tests/__init__.py create mode 100644 numpy/array_api/tests/test_array_object.py create mode 100644 numpy/array_api/tests/test_creation_functions.py create mode 100644 numpy/array_api/tests/test_elementwise_functions.py (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py new file mode 100644 index 000000000..4650e3db8 --- /dev/null +++ b/numpy/array_api/__init__.py @@ -0,0 +1,171 @@ +""" +A NumPy sub-namespace that conforms to the Python array API standard. + +This submodule accompanies NEP 47, which proposes its inclusion in NumPy. + +This is a proof-of-concept namespace that wraps the corresponding NumPy +functions to give a conforming implementation of the Python array API standard +(https://data-apis.github.io/array-api/latest/). The standard is currently in +an RFC phase and comments on it are both welcome and encouraged. Comments +should be made either at https://github.com/data-apis/array-api or at +https://github.com/data-apis/consortium-feedback/discussions. + +NumPy already follows the proposed spec for the most part, so this module +serves mostly as a thin wrapper around it. However, NumPy also implements a +lot of behavior that is not included in the spec, so this serves as a +restricted subset of the API. Only those functions that are part of the spec +are included in this namespace, and all functions are given with the exact +signature given in the spec, including the use of position-only arguments, and +omitting any extra keyword arguments implemented by NumPy but not part of the +spec. The behavior of some functions is also modified from the NumPy behavior +to conform to the standard. Note that the underlying array object itself is +wrapped in a wrapper Array() class, but is otherwise unchanged. This submodule +is implemented in pure Python with no C extensions. + +The array API spec is designed as a "minimal API subset" and explicitly allows +libraries to include behaviors not specified by it. But users of this module +that intend to write portable code should be aware that only those behaviors +that are listed in the spec are guaranteed to be implemented across libraries. +Consequently, the NumPy implementation was chosen to be both conforming and +minimal, so that users can use this implementation of the array API namespace +and be sure that behaviors that it defines will be available in conforming +namespaces from other libraries. + +A few notes about the current state of this submodule: + +- There is a test suite that tests modules against the array API standard at + https://github.com/data-apis/array-api-tests. The test suite is still a work + in progress, but the existing tests pass on this module, with a few + exceptions: + + - Device support is not yet implemented in NumPy + (https://data-apis.github.io/array-api/latest/design_topics/device_support.html). + As a result, the `device` attribute of the array object is missing, and + array creation functions that take the `device` keyword argument will fail + with NotImplementedError. + + - DLPack support (see https://github.com/data-apis/array-api/pull/106) is + not included here, as it requires a full implementation in NumPy proper + first. + + - The linear algebra extension in the spec will be added in a future pull +request. + + The test suite is not yet complete, and even the tests that exist are not + guaranteed to give a comprehensive coverage of the spec. Therefore, those + reviewing this submodule should refer to the standard documents themselves. + +- There is a custom array object, numpy.array_api.Array, which is returned + by all functions in this module. All functions in the array API namespace + implicitly assume that they will only receive this object as input. The only + way to create instances of this object is to use one of the array creation + functions. It does not have a public constructor on the object itself. The + object is a small wrapper Python class around numpy.ndarray. The main + purpose of it is to restrict the namespace of the array object to only those + dtypes and only those methods that are required by the spec, as well as to + limit/change certain behavior that differs in the spec. In particular: + + - The array API namespace does not have scalar objects, only 0-d arrays. + Operations in on Array that would create a scalar in NumPy create a 0-d + array. + + - Indexing: Only a subset of indices supported by NumPy are required by the + spec. The Array object restricts indexing to only allow those types of + indices that are required by the spec. See the docstring of the + numpy.array_api.Array._validate_indices helper function for more + information. + + - Type promotion: Some type promotion rules are different in the spec. In + particular, the spec does not have any value-based casting. The + Array._promote_scalar method promotes Python scalars to arrays, + disallowing cross-type promotions like int -> float64 that are not allowed + in the spec. Array._normalize_two_args works around some type promotion + quirks in NumPy, particularly, value-based casting that occurs when one + argument of an operation is a 0-d array. + +- All functions include type annotations, corresponding to those given in the + spec (see _typing.py for definitions of some custom types). These do not + currently fully pass mypy due to some limitations in mypy. + +- Dtype objects are just the NumPy dtype objects, e.g., float64 = + np.dtype('float64'). The spec does not require any behavior on these dtype + objects other than that they be accessible by name and be comparable by + equality, but it was considered too much extra complexity to create custom + objects to represent dtypes. + +- The wrapper functions in this module do not do any type checking for things + that would be impossible without leaving the array_api namespace. For + example, since the array API dtype objects are just the NumPy dtype objects, + one could pass in a non-spec NumPy dtype into a function. + +- All places where the implementations in this submodule are known to deviate + from their corresponding functions in NumPy are marked with "# Note" + comments. Reviewers should make note of these comments. + +Still TODO in this module are: + +- Device support and DLPack support are not yet implemented. These require + support in NumPy itself first. + +- The a non-default value for the `copy` keyword argument is not yet + implemented on asarray. This requires support in numpy.asarray() first. + +- Some functions are not yet fully tested in the array API test suite, and may + require updates that are not yet known until the tests are written. + +""" + +__all__ = [] + +from ._constants import e, inf, nan, pi + +__all__ += ['e', 'inf', 'nan', 'pi'] + +from ._creation_functions import asarray, arange, empty, empty_like, eye, from_dlpack, full, full_like, linspace, meshgrid, ones, ones_like, zeros, zeros_like + +__all__ += ['asarray', 'arange', 'empty', 'empty_like', 'eye', 'from_dlpack', 'full', 'full_like', 'linspace', 'meshgrid', 'ones', 'ones_like', 'zeros', 'zeros_like'] + +from ._data_type_functions import broadcast_arrays, broadcast_to, can_cast, finfo, iinfo, result_type + +__all__ += ['broadcast_arrays', 'broadcast_to', 'can_cast', 'finfo', 'iinfo', 'result_type'] + +from ._dtypes import int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, bool + +__all__ += ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', 'float32', 'float64', 'bool'] + +from ._elementwise_functions import abs, acos, acosh, add, asin, asinh, atan, atan2, atanh, bitwise_and, bitwise_left_shift, bitwise_invert, bitwise_or, bitwise_right_shift, bitwise_xor, ceil, cos, cosh, divide, equal, exp, expm1, floor, floor_divide, greater, greater_equal, isfinite, isinf, isnan, less, less_equal, log, log1p, log2, log10, logaddexp, logical_and, logical_not, logical_or, logical_xor, multiply, negative, not_equal, positive, pow, remainder, round, sign, sin, sinh, square, sqrt, subtract, tan, tanh, trunc + +__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logaddexp', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc'] + +# einsum is not yet implemented in the array API spec. + +# from ._linear_algebra_functions import einsum +# __all__ += ['einsum'] + +from ._linear_algebra_functions import matmul, tensordot, transpose, vecdot + +__all__ += ['matmul', 'tensordot', 'transpose', 'vecdot'] + +from ._manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack + +__all__ += ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack'] + +from ._searching_functions import argmax, argmin, nonzero, where + +__all__ += ['argmax', 'argmin', 'nonzero', 'where'] + +from ._set_functions import unique + +__all__ += ['unique'] + +from ._sorting_functions import argsort, sort + +__all__ += ['argsort', 'sort'] + +from ._statistical_functions import max, mean, min, prod, std, sum, var + +__all__ += ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var'] + +from ._utility_functions import all, any + +__all__ += ['all', 'any'] diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py new file mode 100644 index 000000000..24957fde6 --- /dev/null +++ b/numpy/array_api/_array_object.py @@ -0,0 +1,983 @@ +""" +Wrapper class around the ndarray object for the array API standard. + +The array API standard defines some behaviors differently than ndarray, in +particular, type promotion rules are different (the standard has no +value-based casting). The standard also specifies a more limited subset of +array methods and functionalities than are implemented on ndarray. Since the +goal of the array_api namespace is to be a minimal implementation of the array +API standard, we need to define a separate wrapper class for the array_api +namespace. + +The standard compliant class is only a wrapper class. It is *not* a subclass +of ndarray. +""" + +from __future__ import annotations + +import operator +from enum import IntEnum +from ._creation_functions import asarray +from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes, + _integer_or_boolean_dtypes, _floating_dtypes, _numeric_dtypes) + +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union +if TYPE_CHECKING: + from ._typing import PyCapsule, Device, Dtype + +import numpy as np + +class Array: + """ + n-d array object for the array API namespace. + + See the docstring of :py:obj:`np.ndarray ` for more + information. + + This is a wrapper around numpy.ndarray that restricts the usage to only + those things that are required by the array API namespace. Note, + attributes on this object that start with a single underscore are not part + of the API specification and should only be used internally. This object + should not be constructed directly. Rather, use one of the creation + functions, such as asarray(). + + """ + # Use a custom constructor instead of __init__, as manually initializing + # this class is not supported API. + @classmethod + def _new(cls, x, /): + """ + This is a private method for initializing the array API Array + object. + + Functions outside of the array_api submodule should not use this + method. Use one of the creation functions instead, such as + ``asarray``. + + """ + obj = super().__new__(cls) + # Note: The spec does not have array scalars, only 0-D arrays. + if isinstance(x, np.generic): + # Convert the array scalar to a 0-D array + x = np.asarray(x) + if x.dtype not in _all_dtypes: + raise TypeError(f"The array_api namespace does not support the dtype '{x.dtype}'") + obj._array = x + return obj + + # Prevent Array() from working + def __new__(cls, *args, **kwargs): + raise TypeError("The array_api Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead.") + + # These functions are not required by the spec, but are implemented for + # the sake of usability. + + def __str__(self: Array, /) -> str: + """ + Performs the operation __str__. + """ + return self._array.__str__().replace('array', 'Array') + + def __repr__(self: Array, /) -> str: + """ + Performs the operation __repr__. + """ + return f"Array({np.array2string(self._array, separator=', ')}, dtype={self.dtype.name})" + + # These are various helper functions to make the array behavior match the + # spec in places where it either deviates from or is more strict than + # NumPy behavior + + def _check_allowed_dtypes(self, other, dtype_category, op): + """ + Helper function for operators to only allow specific input dtypes + + Use like + + other = self._check_allowed_dtypes(other, 'numeric', '__add__') + if other is NotImplemented: + return other + """ + from ._dtypes import _result_type + + _dtypes = { + 'all': _all_dtypes, + 'numeric': _numeric_dtypes, + 'integer': _integer_dtypes, + 'integer or boolean': _integer_or_boolean_dtypes, + 'boolean': _boolean_dtypes, + 'floating-point': _floating_dtypes, + } + + if self.dtype not in _dtypes[dtype_category]: + raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) + elif isinstance(other, Array): + if other.dtype not in _dtypes[dtype_category]: + raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + else: + return NotImplemented + + # This will raise TypeError for type combinations that are not allowed + # to promote in the spec (even if the NumPy array operator would + # promote them). + res_dtype = _result_type(self.dtype, other.dtype) + if op.startswith('__i'): + # Note: NumPy will allow in-place operators in some cases where + # the type promoted operator does not match the left-hand side + # operand. For example, + + # >>> a = np.array(1, dtype=np.int8) + # >>> a += np.array(1, dtype=np.int16) + + # The spec explicitly disallows this. + if res_dtype != self.dtype: + raise TypeError(f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}") + + return other + + # Helper function to match the type promotion rules in the spec + def _promote_scalar(self, scalar): + """ + Returns a promoted version of a Python scalar appropriate for use with + operations on self. + + This may raise an OverflowError in cases where the scalar is an + integer that is too large to fit in a NumPy integer dtype, or + TypeError when the scalar type is incompatible with the dtype of self. + """ + if isinstance(scalar, bool): + if self.dtype not in _boolean_dtypes: + raise TypeError("Python bool scalars can only be promoted with bool arrays") + elif isinstance(scalar, int): + if self.dtype in _boolean_dtypes: + raise TypeError("Python int scalars cannot be promoted with bool arrays") + elif isinstance(scalar, float): + if self.dtype not in _floating_dtypes: + raise TypeError("Python float scalars can only be promoted with floating-point arrays.") + else: + raise TypeError("'scalar' must be a Python scalar") + + # Note: the spec only specifies integer-dtype/int promotion + # behavior for integers within the bounds of the integer dtype. + # Outside of those bounds we use the default NumPy behavior (either + # cast or raise OverflowError). + return Array._new(np.array(scalar, self.dtype)) + + @staticmethod + def _normalize_two_args(x1, x2): + """ + Normalize inputs to two arg functions to fix type promotion rules + + NumPy deviates from the spec type promotion rules in cases where one + argument is 0-dimensional and the other is not. For example: + + >>> import numpy as np + >>> a = np.array([1.0], dtype=np.float32) + >>> b = np.array(1.0, dtype=np.float64) + >>> np.add(a, b) # The spec says this should be float64 + array([2.], dtype=float32) + + To fix this, we add a dimension to the 0-dimension array before passing it + through. This works because a dimension would be added anyway from + broadcasting, so the resulting shape is the same, but this prevents NumPy + from not promoting the dtype. + """ + # Another option would be to use signature=(x1.dtype, x2.dtype, None), + # but that only works for ufuncs, so we would have to call the ufuncs + # directly in the operator methods. One should also note that this + # sort of trick wouldn't work for functions like searchsorted, which + # don't do normal broadcasting, but there aren't any functions like + # that in the array API namespace. + if x1.ndim == 0 and x2.ndim != 0: + # The _array[None] workaround was chosen because it is relatively + # performant. broadcast_to(x1._array, x2.shape) is much slower. We + # could also manually type promote x2, but that is more complicated + # and about the same performance as this. + x1 = Array._new(x1._array[None]) + elif x2.ndim == 0 and x1.ndim != 0: + x2 = Array._new(x2._array[None]) + return (x1, x2) + + # Note: A large fraction of allowed indices are disallowed here (see the + # docstring below) + @staticmethod + def _validate_index(key, shape): + """ + Validate an index according to the array API. + + The array API specification only requires a subset of indices that are + supported by NumPy. This function will reject any index that is + allowed by NumPy but not required by the array API specification. We + always raise ``IndexError`` on such indices (the spec does not require + any specific behavior on them, but this makes the NumPy array API + namespace a minimal implementation of the spec). See + https://data-apis.org/array-api/latest/API_specification/indexing.html + for the full list of required indexing behavior + + This function either raises IndexError if the index ``key`` is + invalid, or a new key to be used in place of ``key`` in indexing. It + only raises ``IndexError`` on indices that are not already rejected by + NumPy, as NumPy will already raise the appropriate error on such + indices. ``shape`` may be None, in which case, only cases that are + independent of the array shape are checked. + + The following cases are allowed by NumPy, but not specified by the array + API specification: + + - The start and stop of a slice may not be out of bounds. In + particular, for a slice ``i:j:k`` on an axis of size ``n``, only the + following are allowed: + + - ``i`` or ``j`` omitted (``None``). + - ``-n <= i <= max(0, n - 1)``. + - For ``k > 0`` or ``k`` omitted (``None``), ``-n <= j <= n``. + - For ``k < 0``, ``-n - 1 <= j <= max(0, n - 1)``. + + - Boolean array indices are not allowed as part of a larger tuple + index. + + - Integer array indices are not allowed (with the exception of 0-D + arrays, which are treated the same as scalars). + + Additionally, it should be noted that indices that would return a + scalar in NumPy will return a 0-D array. Array scalars are not allowed + in the specification, only 0-D arrays. This is done in the + ``Array._new`` constructor, not this function. + + """ + if isinstance(key, slice): + if shape is None: + return key + if shape == (): + return key + size = shape[0] + # Ensure invalid slice entries are passed through. + if key.start is not None: + try: + operator.index(key.start) + except TypeError: + return key + if not (-size <= key.start <= max(0, size - 1)): + raise IndexError("Slices with out-of-bounds start are not allowed in the array API namespace") + if key.stop is not None: + try: + operator.index(key.stop) + except TypeError: + return key + step = 1 if key.step is None else key.step + if (step > 0 and not (-size <= key.stop <= size) + or step < 0 and not (-size - 1 <= key.stop <= max(0, size - 1))): + raise IndexError("Slices with out-of-bounds stop are not allowed in the array API namespace") + return key + + elif isinstance(key, tuple): + key = tuple(Array._validate_index(idx, None) for idx in key) + + for idx in key: + if isinstance(idx, np.ndarray) and idx.dtype in _boolean_dtypes or isinstance(idx, (bool, np.bool_)): + if len(key) == 1: + return key + raise IndexError("Boolean array indices combined with other indices are not allowed in the array API namespace") + if isinstance(idx, tuple): + raise IndexError("Nested tuple indices are not allowed in the array API namespace") + + if shape is None: + return key + n_ellipsis = key.count(...) + if n_ellipsis > 1: + return key + ellipsis_i = key.index(...) if n_ellipsis else len(key) + + for idx, size in list(zip(key[:ellipsis_i], shape)) + list(zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])): + Array._validate_index(idx, (size,)) + return key + elif isinstance(key, bool): + return key + elif isinstance(key, Array): + if key.dtype in _integer_dtypes: + if key.ndim != 0: + raise IndexError("Non-zero dimensional integer array indices are not allowed in the array API namespace") + return key._array + elif key is Ellipsis: + return key + elif key is None: + raise IndexError("newaxis indices are not allowed in the array API namespace") + try: + return operator.index(key) + except TypeError: + # Note: This also omits boolean arrays that are not already in + # Array() form, like a list of booleans. + raise IndexError("Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace") + + # Everything below this line is required by the spec. + + def __abs__(self: Array, /) -> Array: + """ + Performs the operation __abs__. + """ + if self.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in __abs__') + res = self._array.__abs__() + return self.__class__._new(res) + + def __add__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __add__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__add__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__add__(other._array) + return self.__class__._new(res) + + def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __and__. + """ + other = self._check_allowed_dtypes(other, 'integer or boolean', '__and__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__and__(other._array) + return self.__class__._new(res) + + def __array_namespace__(self: Array, /, *, api_version: Optional[str] = None) -> object: + if api_version is not None and not api_version.startswith('2021.'): + raise ValueError(f"Unrecognized array API version: {api_version!r}") + from numpy import array_api + return array_api + + def __bool__(self: Array, /) -> bool: + """ + Performs the operation __bool__. + """ + # Note: This is an error here. + if self._array.ndim != 0: + raise TypeError("bool is only allowed on arrays with 0 dimensions") + res = self._array.__bool__() + return res + + def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule: + """ + Performs the operation __dlpack__. + """ + res = self._array.__dlpack__(stream=stream) + return self.__class__._new(res) + + def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: + """ + Performs the operation __dlpack_device__. + """ + # Note: device support is required for this + res = self._array.__dlpack_device__() + return self.__class__._new(res) + + def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: + """ + Performs the operation __eq__. + """ + # Even though "all" dtypes are allowed, we still require them to be + # promotable with each other. + other = self._check_allowed_dtypes(other, 'all', '__eq__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__eq__(other._array) + return self.__class__._new(res) + + def __float__(self: Array, /) -> float: + """ + Performs the operation __float__. + """ + # Note: This is an error here. + if self._array.ndim != 0: + raise TypeError("float is only allowed on arrays with 0 dimensions") + res = self._array.__float__() + return res + + def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __floordiv__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__floordiv__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__floordiv__(other._array) + return self.__class__._new(res) + + def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __ge__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__ge__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__ge__(other._array) + return self.__class__._new(res) + + def __getitem__(self: Array, key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], /) -> Array: + """ + Performs the operation __getitem__. + """ + # Note: Only indices required by the spec are allowed. See the + # docstring of _validate_index + key = self._validate_index(key, self.shape) + res = self._array.__getitem__(key) + return self._new(res) + + def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __gt__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__gt__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__gt__(other._array) + return self.__class__._new(res) + + def __int__(self: Array, /) -> int: + """ + Performs the operation __int__. + """ + # Note: This is an error here. + if self._array.ndim != 0: + raise TypeError("int is only allowed on arrays with 0 dimensions") + res = self._array.__int__() + return res + + def __invert__(self: Array, /) -> Array: + """ + Performs the operation __invert__. + """ + if self.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in __invert__') + res = self._array.__invert__() + return self.__class__._new(res) + + def __le__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __le__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__le__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__le__(other._array) + return self.__class__._new(res) + + # Note: __len__ may end up being removed from the array API spec. + def __len__(self, /) -> int: + """ + Performs the operation __len__. + """ + res = self._array.__len__() + return self.__class__._new(res) + + def __lshift__(self: Array, other: Union[int, Array], /) -> Array: + """ + Performs the operation __lshift__. + """ + other = self._check_allowed_dtypes(other, 'integer', '__lshift__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__lshift__(other._array) + return self.__class__._new(res) + + def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __lt__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__lt__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__lt__(other._array) + return self.__class__._new(res) + + def __matmul__(self: Array, other: Array, /) -> Array: + """ + Performs the operation __matmul__. + """ + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__matmul__') + if other is NotImplemented: + return other + res = self._array.__matmul__(other._array) + return self.__class__._new(res) + + def __mod__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __mod__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__mod__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__mod__(other._array) + return self.__class__._new(res) + + def __mul__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __mul__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__mul__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__mul__(other._array) + return self.__class__._new(res) + + def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: + """ + Performs the operation __ne__. + """ + other = self._check_allowed_dtypes(other, 'all', '__ne__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__ne__(other._array) + return self.__class__._new(res) + + def __neg__(self: Array, /) -> Array: + """ + Performs the operation __neg__. + """ + if self.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in __neg__') + res = self._array.__neg__() + return self.__class__._new(res) + + def __or__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __or__. + """ + other = self._check_allowed_dtypes(other, 'integer or boolean', '__or__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__or__(other._array) + return self.__class__._new(res) + + def __pos__(self: Array, /) -> Array: + """ + Performs the operation __pos__. + """ + if self.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in __pos__') + res = self._array.__pos__() + return self.__class__._new(res) + + # PEP 484 requires int to be a subtype of float, but __pow__ should not + # accept int. + def __pow__(self: Array, other: Union[float, Array], /) -> Array: + """ + Performs the operation __pow__. + """ + from ._elementwise_functions import pow + + other = self._check_allowed_dtypes(other, 'floating-point', '__pow__') + if other is NotImplemented: + return other + # Note: NumPy's __pow__ does not follow type promotion rules for 0-d + # arrays, so we use pow() here instead. + return pow(self, other) + + def __rshift__(self: Array, other: Union[int, Array], /) -> Array: + """ + Performs the operation __rshift__. + """ + other = self._check_allowed_dtypes(other, 'integer', '__rshift__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rshift__(other._array) + return self.__class__._new(res) + + def __setitem__(self, key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], value: Union[int, float, bool, Array], /) -> Array: + """ + Performs the operation __setitem__. + """ + # Note: Only indices required by the spec are allowed. See the + # docstring of _validate_index + key = self._validate_index(key, self.shape) + self._array.__setitem__(key, asarray(value)._array) + + def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __sub__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__sub__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__sub__(other._array) + return self.__class__._new(res) + + # PEP 484 requires int to be a subtype of float, but __truediv__ should + # not accept int. + def __truediv__(self: Array, other: Union[float, Array], /) -> Array: + """ + Performs the operation __truediv__. + """ + other = self._check_allowed_dtypes(other, 'floating-point', '__truediv__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__truediv__(other._array) + return self.__class__._new(res) + + def __xor__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __xor__. + """ + other = self._check_allowed_dtypes(other, 'integer or boolean', '__xor__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__xor__(other._array) + return self.__class__._new(res) + + def __iadd__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __iadd__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__iadd__') + if other is NotImplemented: + return other + self._array.__iadd__(other._array) + return self + + def __radd__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __radd__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__radd__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__radd__(other._array) + return self.__class__._new(res) + + def __iand__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __iand__. + """ + other = self._check_allowed_dtypes(other, 'integer or boolean', '__iand__') + if other is NotImplemented: + return other + self._array.__iand__(other._array) + return self + + def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __rand__. + """ + other = self._check_allowed_dtypes(other, 'integer or boolean', '__rand__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rand__(other._array) + return self.__class__._new(res) + + def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __ifloordiv__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__ifloordiv__') + if other is NotImplemented: + return other + self._array.__ifloordiv__(other._array) + return self + + def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __rfloordiv__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__rfloordiv__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rfloordiv__(other._array) + return self.__class__._new(res) + + def __ilshift__(self: Array, other: Union[int, Array], /) -> Array: + """ + Performs the operation __ilshift__. + """ + other = self._check_allowed_dtypes(other, 'integer', '__ilshift__') + if other is NotImplemented: + return other + self._array.__ilshift__(other._array) + return self + + def __rlshift__(self: Array, other: Union[int, Array], /) -> Array: + """ + Performs the operation __rlshift__. + """ + other = self._check_allowed_dtypes(other, 'integer', '__rlshift__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rlshift__(other._array) + return self.__class__._new(res) + + def __imatmul__(self: Array, other: Array, /) -> Array: + """ + Performs the operation __imatmul__. + """ + # Note: NumPy does not implement __imatmul__. + + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__imatmul__') + if other is NotImplemented: + return other + + # __imatmul__ can only be allowed when it would not change the shape + # of self. + other_shape = other.shape + if self.shape == () or other_shape == (): + raise ValueError("@= requires at least one dimension") + if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]: + raise ValueError("@= cannot change the shape of the input array") + self._array[:] = self._array.__matmul__(other._array) + return self + + def __rmatmul__(self: Array, other: Array, /) -> Array: + """ + Performs the operation __rmatmul__. + """ + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__rmatmul__') + if other is NotImplemented: + return other + res = self._array.__rmatmul__(other._array) + return self.__class__._new(res) + + def __imod__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __imod__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__imod__') + if other is NotImplemented: + return other + self._array.__imod__(other._array) + return self + + def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __rmod__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__rmod__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rmod__(other._array) + return self.__class__._new(res) + + def __imul__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __imul__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__imul__') + if other is NotImplemented: + return other + self._array.__imul__(other._array) + return self + + def __rmul__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __rmul__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__rmul__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rmul__(other._array) + return self.__class__._new(res) + + def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __ior__. + """ + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ior__') + if other is NotImplemented: + return other + self._array.__ior__(other._array) + return self + + def __ror__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __ror__. + """ + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ror__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__ror__(other._array) + return self.__class__._new(res) + + def __ipow__(self: Array, other: Union[float, Array], /) -> Array: + """ + Performs the operation __ipow__. + """ + other = self._check_allowed_dtypes(other, 'floating-point', '__ipow__') + if other is NotImplemented: + return other + self._array.__ipow__(other._array) + return self + + def __rpow__(self: Array, other: Union[float, Array], /) -> Array: + """ + Performs the operation __rpow__. + """ + from ._elementwise_functions import pow + + other = self._check_allowed_dtypes(other, 'floating-point', '__rpow__') + if other is NotImplemented: + return other + # Note: NumPy's __pow__ does not follow the spec type promotion rules + # for 0-d arrays, so we use pow() here instead. + return pow(other, self) + + def __irshift__(self: Array, other: Union[int, Array], /) -> Array: + """ + Performs the operation __irshift__. + """ + other = self._check_allowed_dtypes(other, 'integer', '__irshift__') + if other is NotImplemented: + return other + self._array.__irshift__(other._array) + return self + + def __rrshift__(self: Array, other: Union[int, Array], /) -> Array: + """ + Performs the operation __rrshift__. + """ + other = self._check_allowed_dtypes(other, 'integer', '__rrshift__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rrshift__(other._array) + return self.__class__._new(res) + + def __isub__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __isub__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__isub__') + if other is NotImplemented: + return other + self._array.__isub__(other._array) + return self + + def __rsub__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __rsub__. + """ + other = self._check_allowed_dtypes(other, 'numeric', '__rsub__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rsub__(other._array) + return self.__class__._new(res) + + def __itruediv__(self: Array, other: Union[float, Array], /) -> Array: + """ + Performs the operation __itruediv__. + """ + other = self._check_allowed_dtypes(other, 'floating-point', '__itruediv__') + if other is NotImplemented: + return other + self._array.__itruediv__(other._array) + return self + + def __rtruediv__(self: Array, other: Union[float, Array], /) -> Array: + """ + Performs the operation __rtruediv__. + """ + other = self._check_allowed_dtypes(other, 'floating-point', '__rtruediv__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rtruediv__(other._array) + return self.__class__._new(res) + + def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __ixor__. + """ + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ixor__') + if other is NotImplemented: + return other + self._array.__ixor__(other._array) + return self + + def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __rxor__. + """ + other = self._check_allowed_dtypes(other, 'integer or boolean', '__rxor__') + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rxor__(other._array) + return self.__class__._new(res) + + @property + def dtype(self) -> Dtype: + """ + Array API compatible wrapper for :py:meth:`np.ndarray.dtype `. + + See its docstring for more information. + """ + return self._array.dtype + + @property + def device(self) -> Device: + return 'cpu' + + @property + def ndim(self) -> int: + """ + Array API compatible wrapper for :py:meth:`np.ndarray.ndim `. + + See its docstring for more information. + """ + return self._array.ndim + + @property + def shape(self) -> Tuple[int, ...]: + """ + Array API compatible wrapper for :py:meth:`np.ndarray.shape `. + + See its docstring for more information. + """ + return self._array.shape + + @property + def size(self) -> int: + """ + Array API compatible wrapper for :py:meth:`np.ndarray.size `. + + See its docstring for more information. + """ + return self._array.size + + @property + def T(self) -> Array: + """ + Array API compatible wrapper for :py:meth:`np.ndarray.T `. + + See its docstring for more information. + """ + return self._array.T diff --git a/numpy/array_api/_constants.py b/numpy/array_api/_constants.py new file mode 100644 index 000000000..9541941e7 --- /dev/null +++ b/numpy/array_api/_constants.py @@ -0,0 +1,6 @@ +import numpy as np + +e = np.e +inf = np.inf +nan = np.nan +pi = np.pi diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py new file mode 100644 index 000000000..acf78056a --- /dev/null +++ b/numpy/array_api/_creation_functions.py @@ -0,0 +1,216 @@ +from __future__ import annotations + + +from typing import TYPE_CHECKING, List, Optional, Tuple, Union +if TYPE_CHECKING: + from ._typing import (Array, Device, Dtype, NestedSequence, + SupportsDLPack, SupportsBufferProtocol) + from collections.abc import Sequence +from ._dtypes import _all_dtypes + +import numpy as np + +def _check_valid_dtype(dtype): + # Note: Only spelling dtypes as the dtype objects is supported. + + # We use this instead of "dtype in _all_dtypes" because the dtype objects + # define equality with the sorts of things we want to disallw. + for d in (None,) + _all_dtypes: + if dtype is d: + return + raise ValueError("dtype must be one of the supported dtypes") + +def asarray(obj: Union[Array, bool, int, float, NestedSequence[bool|int|float], SupportsDLPack, SupportsBufferProtocol], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: Optional[bool] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.asarray `. + + See its docstring for more information. + """ + # _array_object imports in this file are inside the functions to avoid + # circular imports + from ._array_object import Array + + _check_valid_dtype(dtype) + if device not in ['cpu', None]: + raise ValueError(f"Unsupported device {device!r}") + if copy is False: + # Note: copy=False is not yet implemented in np.asarray + raise NotImplementedError("copy=False is not yet implemented") + if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype): + if copy is True: + return Array._new(np.array(obj._array, copy=True, dtype=dtype)) + return obj + if dtype is None and isinstance(obj, int) and (obj > 2**64 or obj < -2**63): + # Give a better error message in this case. NumPy would convert this + # to an object array. TODO: This won't handle large integers in lists. + raise OverflowError("Integer out of bounds for array dtypes") + res = np.asarray(obj, dtype=dtype) + return Array._new(res) + +def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.arange `. + + See its docstring for more information. + """ + from ._array_object import Array + + _check_valid_dtype(dtype) + if device not in ['cpu', None]: + raise ValueError(f"Unsupported device {device!r}") + return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype)) + +def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.empty `. + + See its docstring for more information. + """ + from ._array_object import Array + + _check_valid_dtype(dtype) + if device not in ['cpu', None]: + raise ValueError(f"Unsupported device {device!r}") + return Array._new(np.empty(shape, dtype=dtype)) + +def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.empty_like `. + + See its docstring for more information. + """ + from ._array_object import Array + + _check_valid_dtype(dtype) + if device not in ['cpu', None]: + raise ValueError(f"Unsupported device {device!r}") + return Array._new(np.empty_like(x._array, dtype=dtype)) + +def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: Optional[int] = 0, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.eye `. + + See its docstring for more information. + """ + from ._array_object import Array + + _check_valid_dtype(dtype) + if device not in ['cpu', None]: + raise ValueError(f"Unsupported device {device!r}") + return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype)) + +def from_dlpack(x: object, /) -> Array: + # Note: dlpack support is not yet implemented on Array + raise NotImplementedError("DLPack support is not yet implemented") + +def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.full `. + + See its docstring for more information. + """ + from ._array_object import Array + + _check_valid_dtype(dtype) + if device not in ['cpu', None]: + raise ValueError(f"Unsupported device {device!r}") + if isinstance(fill_value, Array) and fill_value.ndim == 0: + fill_value = fill_value._array + res = np.full(shape, fill_value, dtype=dtype) + if res.dtype not in _all_dtypes: + # This will happen if the fill value is not something that NumPy + # coerces to one of the acceptable dtypes. + raise TypeError("Invalid input to full") + return Array._new(res) + +def full_like(x: Array, /, fill_value: Union[int, float], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.full_like `. + + See its docstring for more information. + """ + from ._array_object import Array + + _check_valid_dtype(dtype) + if device not in ['cpu', None]: + raise ValueError(f"Unsupported device {device!r}") + res = np.full_like(x._array, fill_value, dtype=dtype) + if res.dtype not in _all_dtypes: + # This will happen if the fill value is not something that NumPy + # coerces to one of the acceptable dtypes. + raise TypeError("Invalid input to full_like") + return Array._new(res) + +def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, endpoint: bool = True) -> Array: + """ + Array API compatible wrapper for :py:func:`np.linspace `. + + See its docstring for more information. + """ + from ._array_object import Array + + _check_valid_dtype(dtype) + if device not in ['cpu', None]: + raise ValueError(f"Unsupported device {device!r}") + return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)) + +def meshgrid(*arrays: Sequence[Array], indexing: str = 'xy') -> List[Array, ...]: + """ + Array API compatible wrapper for :py:func:`np.meshgrid `. + + See its docstring for more information. + """ + from ._array_object import Array + return [Array._new(array) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)] + +def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.ones `. + + See its docstring for more information. + """ + from ._array_object import Array + + _check_valid_dtype(dtype) + if device not in ['cpu', None]: + raise ValueError(f"Unsupported device {device!r}") + return Array._new(np.ones(shape, dtype=dtype)) + +def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.ones_like `. + + See its docstring for more information. + """ + from ._array_object import Array + + _check_valid_dtype(dtype) + if device not in ['cpu', None]: + raise ValueError(f"Unsupported device {device!r}") + return Array._new(np.ones_like(x._array, dtype=dtype)) + +def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.zeros `. + + See its docstring for more information. + """ + from ._array_object import Array + + _check_valid_dtype(dtype) + if device not in ['cpu', None]: + raise ValueError(f"Unsupported device {device!r}") + return Array._new(np.zeros(shape, dtype=dtype)) + +def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.zeros_like `. + + See its docstring for more information. + """ + from ._array_object import Array + + _check_valid_dtype(dtype) + if device not in ['cpu', None]: + raise ValueError(f"Unsupported device {device!r}") + return Array._new(np.zeros_like(x._array, dtype=dtype)) diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py new file mode 100644 index 000000000..17a00cc6d --- /dev/null +++ b/numpy/array_api/_data_type_functions.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from ._array_object import Array +from ._dtypes import _all_dtypes, _result_type + +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Tuple, Union +if TYPE_CHECKING: + from ._typing import Dtype + from collections.abc import Sequence + +import numpy as np + +def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]: + """ + Array API compatible wrapper for :py:func:`np.broadcast_arrays `. + + See its docstring for more information. + """ + from ._array_object import Array + return [Array._new(array) for array in np.broadcast_arrays(*[a._array for a in arrays])] + +def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: + """ + Array API compatible wrapper for :py:func:`np.broadcast_to `. + + See its docstring for more information. + """ + from ._array_object import Array + return Array._new(np.broadcast_to(x._array, shape)) + +def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: + """ + Array API compatible wrapper for :py:func:`np.can_cast `. + + See its docstring for more information. + """ + from ._array_object import Array + if isinstance(from_, Array): + from_ = from_._array + return np.can_cast(from_, to) + +# These are internal objects for the return types of finfo and iinfo, since +# the NumPy versions contain extra data that isn't part of the spec. +@dataclass +class finfo_object: + bits: int + # Note: The types of the float data here are float, whereas in NumPy they + # are scalars of the corresponding float dtype. + eps: float + max: float + min: float + # Note: smallest_normal is part of the array API spec, but cannot be used + # until https://github.com/numpy/numpy/pull/18536 is merged. + + # smallest_normal: float + +@dataclass +class iinfo_object: + bits: int + max: int + min: int + +def finfo(type: Union[Dtype, Array], /) -> finfo_object: + """ + Array API compatible wrapper for :py:func:`np.finfo `. + + See its docstring for more information. + """ + fi = np.finfo(type) + # Note: The types of the float data here are float, whereas in NumPy they + # are scalars of the corresponding float dtype. + return finfo_object( + fi.bits, + float(fi.eps), + float(fi.max), + float(fi.min), + # TODO: Uncomment this when #18536 is merged. + # float(fi.smallest_normal), + ) + +def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: + """ + Array API compatible wrapper for :py:func:`np.iinfo `. + + See its docstring for more information. + """ + ii = np.iinfo(type) + return iinfo_object(ii.bits, ii.max, ii.min) + +def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype: + """ + Array API compatible wrapper for :py:func:`np.result_type `. + + See its docstring for more information. + """ + # Note: we use a custom implementation that gives only the type promotions + # required by the spec rather than using np.result_type. NumPy implements + # too many extra type promotions like int64 + uint64 -> float64, and does + # value-based casting on scalar arrays. + A = [] + for a in arrays_and_dtypes: + if isinstance(a, Array): + a = a.dtype + elif isinstance(a, np.ndarray) or a not in _all_dtypes: + raise TypeError("result_type() inputs must be array_api arrays or dtypes") + A.append(a) + + if len(A) == 0: + raise ValueError("at least one array or dtype is required") + elif len(A) == 1: + return A[0] + else: + t = A[0] + for t2 in A[1:]: + t = _result_type(t, t2) + return t diff --git a/numpy/array_api/_dtypes.py b/numpy/array_api/_dtypes.py new file mode 100644 index 000000000..fcdb562da --- /dev/null +++ b/numpy/array_api/_dtypes.py @@ -0,0 +1,100 @@ +import numpy as np + +# Note: we use dtype objects instead of dtype classes. The spec does not +# require any behavior on dtypes other than equality. +int8 = np.dtype('int8') +int16 = np.dtype('int16') +int32 = np.dtype('int32') +int64 = np.dtype('int64') +uint8 = np.dtype('uint8') +uint16 = np.dtype('uint16') +uint32 = np.dtype('uint32') +uint64 = np.dtype('uint64') +float32 = np.dtype('float32') +float64 = np.dtype('float64') +# Note: This name is changed +bool = np.dtype('bool') + +_all_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64, + float32, float64, bool) +_boolean_dtypes = (bool,) +_floating_dtypes = (float32, float64) +_integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) +_integer_or_boolean_dtypes = (bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64) +_numeric_dtypes = (float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64) + +# Note: the spec defines a restricted type promotion table compared to NumPy. +# In particular, cross-kind promotions like integer + float or boolean + +# integer are not allowed, even for functions that accept both kinds. +# Additionally, NumPy promotes signed integer + uint64 to float64, but this +# promotion is not allowed here. To be clear, Python scalar int objects are +# allowed to promote to floating-point dtypes, but only in array operators +# (see Array._promote_scalar) method in _array_object.py. +_promotion_table = { + (int8, int8): int8, + (int8, int16): int16, + (int8, int32): int32, + (int8, int64): int64, + (int16, int8): int16, + (int16, int16): int16, + (int16, int32): int32, + (int16, int64): int64, + (int32, int8): int32, + (int32, int16): int32, + (int32, int32): int32, + (int32, int64): int64, + (int64, int8): int64, + (int64, int16): int64, + (int64, int32): int64, + (int64, int64): int64, + (uint8, uint8): uint8, + (uint8, uint16): uint16, + (uint8, uint32): uint32, + (uint8, uint64): uint64, + (uint16, uint8): uint16, + (uint16, uint16): uint16, + (uint16, uint32): uint32, + (uint16, uint64): uint64, + (uint32, uint8): uint32, + (uint32, uint16): uint32, + (uint32, uint32): uint32, + (uint32, uint64): uint64, + (uint64, uint8): uint64, + (uint64, uint16): uint64, + (uint64, uint32): uint64, + (uint64, uint64): uint64, + (int8, uint8): int16, + (int8, uint16): int32, + (int8, uint32): int64, + (int16, uint8): int16, + (int16, uint16): int32, + (int16, uint32): int64, + (int32, uint8): int32, + (int32, uint16): int32, + (int32, uint32): int64, + (int64, uint8): int64, + (int64, uint16): int64, + (int64, uint32): int64, + (uint8, int8): int16, + (uint16, int8): int32, + (uint32, int8): int64, + (uint8, int16): int16, + (uint16, int16): int32, + (uint32, int16): int64, + (uint8, int32): int32, + (uint16, int32): int32, + (uint32, int32): int64, + (uint8, int64): int64, + (uint16, int64): int64, + (uint32, int64): int64, + (float32, float32): float32, + (float32, float64): float64, + (float64, float32): float64, + (float64, float64): float64, + (bool, bool): bool, +} + +def _result_type(type1, type2): + if (type1, type2) in _promotion_table: + return _promotion_table[type1, type2] + raise TypeError(f"{type1} and {type2} cannot be type promoted together") diff --git a/numpy/array_api/_elementwise_functions.py b/numpy/array_api/_elementwise_functions.py new file mode 100644 index 000000000..7833ebe54 --- /dev/null +++ b/numpy/array_api/_elementwise_functions.py @@ -0,0 +1,659 @@ +from __future__ import annotations + +from ._dtypes import (_boolean_dtypes, _floating_dtypes, + _integer_dtypes, _integer_or_boolean_dtypes, + _numeric_dtypes, _result_type) +from ._array_object import Array + +import numpy as np + +def abs(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.abs `. + + See its docstring for more information. + """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in abs') + return Array._new(np.abs(x._array)) + +# Note: the function name is different here +def acos(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.arccos `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in acos') + return Array._new(np.arccos(x._array)) + +# Note: the function name is different here +def acosh(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.arccosh `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in acosh') + return Array._new(np.arccosh(x._array)) + +def add(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.add `. + + See its docstring for more information. + """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in add') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.add(x1._array, x2._array)) + +# Note: the function name is different here +def asin(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.arcsin `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in asin') + return Array._new(np.arcsin(x._array)) + +# Note: the function name is different here +def asinh(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.arcsinh `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in asinh') + return Array._new(np.arcsinh(x._array)) + +# Note: the function name is different here +def atan(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.arctan `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in atan') + return Array._new(np.arctan(x._array)) + +# Note: the function name is different here +def atan2(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.arctan2 `. + + See its docstring for more information. + """ + if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in atan2') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.arctan2(x1._array, x2._array)) + +# Note: the function name is different here +def atanh(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.arctanh `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in atanh') + return Array._new(np.arctanh(x._array)) + +def bitwise_and(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.bitwise_and `. + + See its docstring for more information. + """ + if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in bitwise_and') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.bitwise_and(x1._array, x2._array)) + +# Note: the function name is different here +def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.left_shift `. + + See its docstring for more information. + """ + if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: + raise TypeError('Only integer dtypes are allowed in bitwise_left_shift') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + # Note: bitwise_left_shift is only defined for x2 nonnegative. + if np.any(x2._array < 0): + raise ValueError('bitwise_left_shift(x1, x2) is only defined for x2 >= 0') + return Array._new(np.left_shift(x1._array, x2._array)) + +# Note: the function name is different here +def bitwise_invert(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.invert `. + + See its docstring for more information. + """ + if x.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in bitwise_invert') + return Array._new(np.invert(x._array)) + +def bitwise_or(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.bitwise_or `. + + See its docstring for more information. + """ + if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in bitwise_or') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.bitwise_or(x1._array, x2._array)) + +# Note: the function name is different here +def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.right_shift `. + + See its docstring for more information. + """ + if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: + raise TypeError('Only integer dtypes are allowed in bitwise_right_shift') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + # Note: bitwise_right_shift is only defined for x2 nonnegative. + if np.any(x2._array < 0): + raise ValueError('bitwise_right_shift(x1, x2) is only defined for x2 >= 0') + return Array._new(np.right_shift(x1._array, x2._array)) + +def bitwise_xor(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.bitwise_xor `. + + See its docstring for more information. + """ + if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in bitwise_xor') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.bitwise_xor(x1._array, x2._array)) + +def ceil(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.ceil `. + + See its docstring for more information. + """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in ceil') + if x.dtype in _integer_dtypes: + # Note: The return dtype of ceil is the same as the input + return x + return Array._new(np.ceil(x._array)) + +def cos(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.cos `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in cos') + return Array._new(np.cos(x._array)) + +def cosh(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.cosh `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in cosh') + return Array._new(np.cosh(x._array)) + +def divide(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.divide `. + + See its docstring for more information. + """ + if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in divide') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.divide(x1._array, x2._array)) + +def equal(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.equal `. + + See its docstring for more information. + """ + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.equal(x1._array, x2._array)) + +def exp(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.exp `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in exp') + return Array._new(np.exp(x._array)) + +def expm1(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.expm1 `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in expm1') + return Array._new(np.expm1(x._array)) + +def floor(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.floor `. + + See its docstring for more information. + """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in floor') + if x.dtype in _integer_dtypes: + # Note: The return dtype of floor is the same as the input + return x + return Array._new(np.floor(x._array)) + +def floor_divide(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.floor_divide `. + + See its docstring for more information. + """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in floor_divide') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.floor_divide(x1._array, x2._array)) + +def greater(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.greater `. + + See its docstring for more information. + """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in greater') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.greater(x1._array, x2._array)) + +def greater_equal(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.greater_equal `. + + See its docstring for more information. + """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in greater_equal') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.greater_equal(x1._array, x2._array)) + +def isfinite(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.isfinite `. + + See its docstring for more information. + """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in isfinite') + return Array._new(np.isfinite(x._array)) + +def isinf(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.isinf `. + + See its docstring for more information. + """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in isinf') + return Array._new(np.isinf(x._array)) + +def isnan(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.isnan `. + + See its docstring for more information. + """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in isnan') + return Array._new(np.isnan(x._array)) + +def less(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.less `. + + See its docstring for more information. + """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in less') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.less(x1._array, x2._array)) + +def less_equal(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.less_equal `. + + See its docstring for more information. + """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in less_equal') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.less_equal(x1._array, x2._array)) + +def log(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.log `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in log') + return Array._new(np.log(x._array)) + +def log1p(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.log1p `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in log1p') + return Array._new(np.log1p(x._array)) + +def log2(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.log2 `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in log2') + return Array._new(np.log2(x._array)) + +def log10(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.log10 `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in log10') + return Array._new(np.log10(x._array)) + +def logaddexp(x1: Array, x2: Array) -> Array: + """ + Array API compatible wrapper for :py:func:`np.logaddexp `. + + See its docstring for more information. + """ + if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in logaddexp') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.logaddexp(x1._array, x2._array)) + +def logical_and(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.logical_and `. + + See its docstring for more information. + """ + if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: + raise TypeError('Only boolean dtypes are allowed in logical_and') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.logical_and(x1._array, x2._array)) + +def logical_not(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.logical_not `. + + See its docstring for more information. + """ + if x.dtype not in _boolean_dtypes: + raise TypeError('Only boolean dtypes are allowed in logical_not') + return Array._new(np.logical_not(x._array)) + +def logical_or(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.logical_or `. + + See its docstring for more information. + """ + if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: + raise TypeError('Only boolean dtypes are allowed in logical_or') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.logical_or(x1._array, x2._array)) + +def logical_xor(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.logical_xor `. + + See its docstring for more information. + """ + if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: + raise TypeError('Only boolean dtypes are allowed in logical_xor') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.logical_xor(x1._array, x2._array)) + +def multiply(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.multiply `. + + See its docstring for more information. + """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in multiply') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.multiply(x1._array, x2._array)) + +def negative(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.negative `. + + See its docstring for more information. + """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in negative') + return Array._new(np.negative(x._array)) + +def not_equal(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.not_equal `. + + See its docstring for more information. + """ + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.not_equal(x1._array, x2._array)) + +def positive(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.positive `. + + See its docstring for more information. + """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in positive') + return Array._new(np.positive(x._array)) + +# Note: the function name is different here +def pow(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.power `. + + See its docstring for more information. + """ + if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in pow') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.power(x1._array, x2._array)) + +def remainder(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.remainder `. + + See its docstring for more information. + """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in remainder') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.remainder(x1._array, x2._array)) + +def round(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.round `. + + See its docstring for more information. + """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in round') + return Array._new(np.round(x._array)) + +def sign(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.sign `. + + See its docstring for more information. + """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in sign') + return Array._new(np.sign(x._array)) + +def sin(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.sin `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in sin') + return Array._new(np.sin(x._array)) + +def sinh(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.sinh `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in sinh') + return Array._new(np.sinh(x._array)) + +def square(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.square `. + + See its docstring for more information. + """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in square') + return Array._new(np.square(x._array)) + +def sqrt(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.sqrt `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in sqrt') + return Array._new(np.sqrt(x._array)) + +def subtract(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.subtract `. + + See its docstring for more information. + """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in subtract') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.subtract(x1._array, x2._array)) + +def tan(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.tan `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in tan') + return Array._new(np.tan(x._array)) + +def tanh(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.tanh `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in tanh') + return Array._new(np.tanh(x._array)) + +def trunc(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.trunc `. + + See its docstring for more information. + """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in trunc') + if x.dtype in _integer_dtypes: + # Note: The return dtype of trunc is the same as the input + return x + return Array._new(np.trunc(x._array)) diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py new file mode 100644 index 000000000..f13f9c541 --- /dev/null +++ b/numpy/array_api/_linear_algebra_functions.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from ._array_object import Array +from ._dtypes import _numeric_dtypes, _result_type + +from typing import Optional, Sequence, Tuple, Union + +import numpy as np + +# einsum is not yet implemented in the array API spec. + +# def einsum(): +# """ +# Array API compatible wrapper for :py:func:`np.einsum `. +# +# See its docstring for more information. +# """ +# return np.einsum() + +def matmul(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.matmul `. + + See its docstring for more information. + """ + # Note: the restriction to numeric dtypes only is different from + # np.matmul. + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in matmul') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + + return Array._new(np.matmul(x1._array, x2._array)) + +# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. +def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: + # Note: the restriction to numeric dtypes only is different from + # np.tensordot. + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in tensordot') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + + return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) + +def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.transpose `. + + See its docstring for more information. + """ + return Array._new(np.transpose(x._array, axes=axes)) + +# Note: vecdot is not in NumPy +def vecdot(x1: Array, x2: Array, /, *, axis: Optional[int] = None) -> Array: + if axis is None: + axis = -1 + return tensordot(x1, x2, axes=((axis,), (axis,))) diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py new file mode 100644 index 000000000..fa6344beb --- /dev/null +++ b/numpy/array_api/_manipulation_functions.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from ._array_object import Array +from ._data_type_functions import result_type + +from typing import List, Optional, Tuple, Union + +import numpy as np + +# Note: the function name is different here +def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array: + """ + Array API compatible wrapper for :py:func:`np.concatenate `. + + See its docstring for more information. + """ + arrays = tuple(a._array for a in arrays) + # Call result type here just to raise on disallowed type combinations + result_type(*arrays) + return Array._new(np.concatenate(arrays, axis=axis)) + +def expand_dims(x: Array, /, *, axis: int) -> Array: + """ + Array API compatible wrapper for :py:func:`np.expand_dims `. + + See its docstring for more information. + """ + return Array._new(np.expand_dims(x._array, axis)) + +def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.flip `. + + See its docstring for more information. + """ + return Array._new(np.flip(x._array, axis=axis)) + +def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array: + """ + Array API compatible wrapper for :py:func:`np.reshape `. + + See its docstring for more information. + """ + return Array._new(np.reshape(x._array, shape)) + +def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.roll `. + + See its docstring for more information. + """ + return Array._new(np.roll(x._array, shift, axis=axis)) + +def squeeze(x: Array, /, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.squeeze `. + + See its docstring for more information. + """ + return Array._new(np.squeeze(x._array, axis=axis)) + +def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array: + """ + Array API compatible wrapper for :py:func:`np.stack `. + + See its docstring for more information. + """ + arrays = tuple(a._array for a in arrays) + # Call result type here just to raise on disallowed type combinations + result_type(*arrays) + return Array._new(np.stack(arrays, axis=axis)) diff --git a/numpy/array_api/_searching_functions.py b/numpy/array_api/_searching_functions.py new file mode 100644 index 000000000..d80720850 --- /dev/null +++ b/numpy/array_api/_searching_functions.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from ._array_object import Array +from ._dtypes import _result_type + +from typing import Optional, Tuple + +import numpy as np + +def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: + """ + Array API compatible wrapper for :py:func:`np.argmax `. + + See its docstring for more information. + """ + # Note: this currently fails as np.argmax does not implement keepdims + return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims))) + +def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: + """ + Array API compatible wrapper for :py:func:`np.argmin `. + + See its docstring for more information. + """ + # Note: this currently fails as np.argmin does not implement keepdims + return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) + +def nonzero(x: Array, /) -> Tuple[Array, ...]: + """ + Array API compatible wrapper for :py:func:`np.nonzero `. + + See its docstring for more information. + """ + return Array._new(np.nonzero(x._array)) + +def where(condition: Array, x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.where `. + + See its docstring for more information. + """ + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + return Array._new(np.where(condition._array, x1._array, x2._array)) diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py new file mode 100644 index 000000000..f28c2ee72 --- /dev/null +++ b/numpy/array_api/_set_functions.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from ._array_object import Array + +from typing import Tuple, Union + +import numpy as np + +def unique(x: Array, /, *, return_counts: bool = False, return_index: bool = False, return_inverse: bool = False) -> Union[Array, Tuple[Array, ...]]: + """ + Array API compatible wrapper for :py:func:`np.unique `. + + See its docstring for more information. + """ + return Array._new(np.unique(x._array, return_counts=return_counts, return_index=return_index, return_inverse=return_inverse)) diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py new file mode 100644 index 000000000..a125e0718 --- /dev/null +++ b/numpy/array_api/_sorting_functions.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from ._array_object import Array + +import numpy as np + +def argsort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array: + """ + Array API compatible wrapper for :py:func:`np.argsort `. + + See its docstring for more information. + """ + # Note: this keyword argument is different, and the default is different. + kind = 'stable' if stable else 'quicksort' + res = np.argsort(x._array, axis=axis, kind=kind) + if descending: + res = np.flip(res, axis=axis) + return Array._new(res) + +def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array: + """ + Array API compatible wrapper for :py:func:`np.sort `. + + See its docstring for more information. + """ + # Note: this keyword argument is different, and the default is different. + kind = 'stable' if stable else 'quicksort' + res = np.sort(x._array, axis=axis, kind=kind) + if descending: + res = np.flip(res, axis=axis) + return Array._new(res) diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py new file mode 100644 index 000000000..61fc60c46 --- /dev/null +++ b/numpy/array_api/_statistical_functions.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from ._array_object import Array + +from typing import Optional, Tuple, Union + +import numpy as np + +def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + return Array._new(np.max(x._array, axis=axis, keepdims=keepdims)) + +def mean(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + return Array._new(np.asarray(np.mean(x._array, axis=axis, keepdims=keepdims))) + +def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + return Array._new(np.min(x._array, axis=axis, keepdims=keepdims)) + +def prod(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + return Array._new(np.asarray(np.prod(x._array, axis=axis, keepdims=keepdims))) + +def std(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> Array: + # Note: the keyword argument correction is different here + return Array._new(np.asarray(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))) + +def sum(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + return Array._new(np.asarray(np.sum(x._array, axis=axis, keepdims=keepdims))) + +def var(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> Array: + # Note: the keyword argument correction is different here + return Array._new(np.asarray(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))) diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py new file mode 100644 index 000000000..4ff718205 --- /dev/null +++ b/numpy/array_api/_typing.py @@ -0,0 +1,26 @@ +""" +This file defines the types for type annotations. + +These names aren't part of the module namespace, but they are used in the +annotations in the function signatures. The functions in the module are only +valid for inputs that match the given type annotations. +""" + +__all__ = ['Array', 'Device', 'Dtype', 'SupportsDLPack', + 'SupportsBufferProtocol', 'PyCapsule'] + +from typing import Any, Sequence, Type, Union + +from . import (Array, int8, int16, int32, int64, uint8, uint16, uint32, + uint64, float32, float64) + +# This should really be recursive, but that isn't supported yet. See the +# similar comment in numpy/typing/_array_like.py +NestedSequence = Sequence[Sequence[Any]] + +Device = Any +Dtype = Type[Union[[int8, int16, int32, int64, uint8, uint16, + uint32, uint64, float32, float64]]] +SupportsDLPack = Any +SupportsBufferProtocol = Any +PyCapsule = Any diff --git a/numpy/array_api/_utility_functions.py b/numpy/array_api/_utility_functions.py new file mode 100644 index 000000000..f243bfe68 --- /dev/null +++ b/numpy/array_api/_utility_functions.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from ._array_object import Array + +from typing import Optional, Tuple, Union + +import numpy as np + +def all(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + """ + Array API compatible wrapper for :py:func:`np.all `. + + See its docstring for more information. + """ + return Array._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims))) + +def any(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + """ + Array API compatible wrapper for :py:func:`np.any `. + + See its docstring for more information. + """ + return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims))) diff --git a/numpy/array_api/tests/__init__.py b/numpy/array_api/tests/__init__.py new file mode 100644 index 000000000..536062e38 --- /dev/null +++ b/numpy/array_api/tests/__init__.py @@ -0,0 +1,7 @@ +""" +Tests for the array API namespace. + +Note, full compliance with the array API can be tested with the official array API test +suite https://github.com/data-apis/array-api-tests. This test suite primarily +focuses on those things that are not tested by the official test suite. +""" diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py new file mode 100644 index 000000000..22078bbee --- /dev/null +++ b/numpy/array_api/tests/test_array_object.py @@ -0,0 +1,250 @@ +from numpy.testing import assert_raises +import numpy as np + +from .. import ones, asarray, result_type +from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, + _integer_dtypes, _integer_or_boolean_dtypes, + _numeric_dtypes, int8, int16, int32, int64, uint64) + +def test_validate_index(): + # The indexing tests in the official array API test suite test that the + # array object correctly handles the subset of indices that are required + # by the spec. But the NumPy array API implementation specifically + # disallows any index not required by the spec, via Array._validate_index. + # This test focuses on testing that non-valid indices are correctly + # rejected. See + # https://data-apis.org/array-api/latest/API_specification/indexing.html + # and the docstring of Array._validate_index for the exact indexing + # behavior that should be allowed. This does not test indices that are + # already invalid in NumPy itself because Array will generally just pass + # such indices directly to the underlying np.ndarray. + + a = ones((3, 4)) + + # Out of bounds slices are not allowed + assert_raises(IndexError, lambda: a[:4]) + assert_raises(IndexError, lambda: a[:-4]) + assert_raises(IndexError, lambda: a[:3:-1]) + assert_raises(IndexError, lambda: a[:-5:-1]) + assert_raises(IndexError, lambda: a[3:]) + assert_raises(IndexError, lambda: a[-4:]) + assert_raises(IndexError, lambda: a[3::-1]) + assert_raises(IndexError, lambda: a[-4::-1]) + + assert_raises(IndexError, lambda: a[...,:5]) + assert_raises(IndexError, lambda: a[...,:-5]) + assert_raises(IndexError, lambda: a[...,:4:-1]) + assert_raises(IndexError, lambda: a[...,:-6:-1]) + assert_raises(IndexError, lambda: a[...,4:]) + assert_raises(IndexError, lambda: a[...,-5:]) + assert_raises(IndexError, lambda: a[...,4::-1]) + assert_raises(IndexError, lambda: a[...,-5::-1]) + + # Boolean indices cannot be part of a larger tuple index + assert_raises(IndexError, lambda: a[a[:,0]==1,0]) + assert_raises(IndexError, lambda: a[a[:,0]==1,...]) + assert_raises(IndexError, lambda: a[..., a[0]==1]) + assert_raises(IndexError, lambda: a[[True, True, True]]) + assert_raises(IndexError, lambda: a[(True, True, True),]) + + # Integer array indices are not allowed (except for 0-D) + idx = asarray([[0, 1]]) + assert_raises(IndexError, lambda: a[idx]) + assert_raises(IndexError, lambda: a[idx,]) + assert_raises(IndexError, lambda: a[[0, 1]]) + assert_raises(IndexError, lambda: a[(0, 1), (0, 1)]) + assert_raises(IndexError, lambda: a[[0, 1]]) + assert_raises(IndexError, lambda: a[np.array([[0, 1]])]) + + # np.newaxis is not allowed + assert_raises(IndexError, lambda: a[None]) + assert_raises(IndexError, lambda: a[None, ...]) + assert_raises(IndexError, lambda: a[..., None]) + +def test_operators(): + # For every operator, we test that it works for the required type + # combinations and raises TypeError otherwise + binary_op_dtypes ={ + '__add__': 'numeric', + '__and__': 'integer_or_boolean', + '__eq__': 'all', + '__floordiv__': 'numeric', + '__ge__': 'numeric', + '__gt__': 'numeric', + '__le__': 'numeric', + '__lshift__': 'integer', + '__lt__': 'numeric', + '__mod__': 'numeric', + '__mul__': 'numeric', + '__ne__': 'all', + '__or__': 'integer_or_boolean', + '__pow__': 'floating', + '__rshift__': 'integer', + '__sub__': 'numeric', + '__truediv__': 'floating', + '__xor__': 'integer_or_boolean', + } + + # Recompute each time because of in-place ops + def _array_vals(): + for d in _integer_dtypes: + yield asarray(1, dtype=d) + for d in _boolean_dtypes: + yield asarray(False, dtype=d) + for d in _floating_dtypes: + yield asarray(1., dtype=d) + + for op, dtypes in binary_op_dtypes.items(): + ops = [op] + if op not in ['__eq__', '__ne__', '__le__', '__ge__', '__lt__', '__gt__']: + rop = '__r' + op[2:] + iop = '__i' + op[2:] + ops += [rop, iop] + for s in [1, 1., False]: + for _op in ops: + for a in _array_vals(): + # Test array op scalar. From the spec, the following combinations + # are supported: + + # - Python bool for a bool array dtype, + # - a Python int within the bounds of the given dtype for integer array dtypes, + # - a Python int or float for floating-point array dtypes + + # We do not do bounds checking for int scalars, but rather use the default + # NumPy behavior for casting in that case. + + if ((dtypes == "all" + or dtypes == "numeric" and a.dtype in _numeric_dtypes + or dtypes == "integer" and a.dtype in _integer_dtypes + or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes + or dtypes == "boolean" and a.dtype in _boolean_dtypes + or dtypes == "floating" and a.dtype in _floating_dtypes + ) + # bool is a subtype of int, which is why we avoid + # isinstance here. + and (a.dtype in _boolean_dtypes and type(s) == bool + or a.dtype in _integer_dtypes and type(s) == int + or a.dtype in _floating_dtypes and type(s) in [float, int] + )): + # Only test for no error + getattr(a, _op)(s) + else: + assert_raises(TypeError, lambda: getattr(a, _op)(s)) + + # Test array op array. + for _op in ops: + for x in _array_vals(): + for y in _array_vals(): + # See the promotion table in NEP 47 or the array + # API spec page on type promotion. Mixed kind + # promotion is not defined. + if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes + or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + ): + assert_raises(TypeError, lambda: getattr(x, _op)(y)) + # Ensure in-place operators only promote to the same dtype as the left operand. + elif _op.startswith('__i') and result_type(x.dtype, y.dtype) != x.dtype: + assert_raises(TypeError, lambda: getattr(x, _op)(y)) + # Ensure only those dtypes that are required for every operator are allowed. + elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes + or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) + or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) + or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _numeric_dtypes + or dtypes == "integer_or_boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes) + or dtypes == "boolean" and x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes + or dtypes == "floating" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes + ): + getattr(x, _op)(y) + else: + assert_raises(TypeError, lambda: getattr(x, _op)(y)) + + unary_op_dtypes ={ + '__abs__': 'numeric', + '__invert__': 'integer_or_boolean', + '__neg__': 'numeric', + '__pos__': 'numeric', + } + for op, dtypes in unary_op_dtypes.items(): + for a in _array_vals(): + if (dtypes == "numeric" and a.dtype in _numeric_dtypes + or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes + ): + # Only test for no error + getattr(a, op)() + else: + assert_raises(TypeError, lambda: getattr(a, op)()) + + # Finally, matmul() must be tested separately, because it works a bit + # different from the other operations. + def _matmul_array_vals(): + for a in _array_vals(): + yield a + for d in _all_dtypes: + yield ones((3, 4), dtype=d) + yield ones((4, 2), dtype=d) + yield ones((4, 4), dtype=d) + + # Scalars always error + for _op in ['__matmul__', '__rmatmul__', '__imatmul__']: + for s in [1, 1., False]: + for a in _matmul_array_vals(): + if (type(s) in [float, int] and a.dtype in _floating_dtypes + or type(s) == int and a.dtype in _integer_dtypes): + # Type promotion is valid, but @ is not allowed on 0-D + # inputs, so the error is a ValueError + assert_raises(ValueError, lambda: getattr(a, _op)(s)) + else: + assert_raises(TypeError, lambda: getattr(a, _op)(s)) + + for x in _matmul_array_vals(): + for y in _matmul_array_vals(): + if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + or x.dtype in _boolean_dtypes + or y.dtype in _boolean_dtypes + ): + assert_raises(TypeError, lambda: x.__matmul__(y)) + assert_raises(TypeError, lambda: y.__rmatmul__(x)) + assert_raises(TypeError, lambda: x.__imatmul__(y)) + elif x.shape == () or y.shape == () or x.shape[1] != y.shape[0]: + assert_raises(ValueError, lambda: x.__matmul__(y)) + assert_raises(ValueError, lambda: y.__rmatmul__(x)) + if result_type(x.dtype, y.dtype) != x.dtype: + assert_raises(TypeError, lambda: x.__imatmul__(y)) + else: + assert_raises(ValueError, lambda: x.__imatmul__(y)) + else: + x.__matmul__(y) + y.__rmatmul__(x) + if result_type(x.dtype, y.dtype) != x.dtype: + assert_raises(TypeError, lambda: x.__imatmul__(y)) + elif y.shape[0] != y.shape[1]: + # This one fails because x @ y has a different shape from x + assert_raises(ValueError, lambda: x.__imatmul__(y)) + else: + x.__imatmul__(y) + +def test_python_scalar_construtors(): + a = asarray(False) + b = asarray(0) + c = asarray(0.) + + assert bool(a) == bool(b) == bool(c) == False + assert int(a) == int(b) == int(c) == 0 + assert float(a) == float(b) == float(c) == 0. + + # bool/int/float should only be allowed on 0-D arrays. + assert_raises(TypeError, lambda: bool(asarray([False]))) + assert_raises(TypeError, lambda: int(asarray([0]))) + assert_raises(TypeError, lambda: float(asarray([0.]))) diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py new file mode 100644 index 000000000..654f1d9b3 --- /dev/null +++ b/numpy/array_api/tests/test_creation_functions.py @@ -0,0 +1,103 @@ +from numpy.testing import assert_raises +import numpy as np + +from .. import all +from .._creation_functions import (asarray, arange, empty, empty_like, eye, from_dlpack, full, full_like, linspace, meshgrid, ones, ones_like, zeros, zeros_like) +from .._array_object import Array +from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, + _integer_dtypes, _integer_or_boolean_dtypes, + _numeric_dtypes, int8, int16, int32, int64, uint64) + +def test_asarray_errors(): + # Test various protections against incorrect usage + assert_raises(TypeError, lambda: Array([1])) + assert_raises(TypeError, lambda: asarray(['a'])) + assert_raises(ValueError, lambda: asarray([1.], dtype=np.float16)) + assert_raises(OverflowError, lambda: asarray(2**100)) + # Preferably this would be OverflowError + # assert_raises(OverflowError, lambda: asarray([2**100])) + assert_raises(TypeError, lambda: asarray([2**100])) + asarray([1], device='cpu') # Doesn't error + assert_raises(ValueError, lambda: asarray([1], device='gpu')) + + assert_raises(ValueError, lambda: asarray([1], dtype=int)) + assert_raises(ValueError, lambda: asarray([1], dtype='i')) + +def test_asarray_copy(): + a = asarray([1]) + b = asarray(a, copy=True) + a[0] = 0 + assert all(b[0] == 1) + assert all(a[0] == 0) + # Once copy=False is implemented, replace this with + # a = asarray([1]) + # b = asarray(a, copy=False) + # a[0] = 0 + # assert all(b[0] == 0) + assert_raises(NotImplementedError, lambda: asarray(a, copy=False)) + +def test_arange_errors(): + arange(1, device='cpu') # Doesn't error + assert_raises(ValueError, lambda: arange(1, device='gpu')) + assert_raises(ValueError, lambda: arange(1, dtype=int)) + assert_raises(ValueError, lambda: arange(1, dtype='i')) + +def test_empty_errors(): + empty((1,), device='cpu') # Doesn't error + assert_raises(ValueError, lambda: empty((1,), device='gpu')) + assert_raises(ValueError, lambda: empty((1,), dtype=int)) + assert_raises(ValueError, lambda: empty((1,), dtype='i')) + +def test_empty_like_errors(): + empty_like(asarray(1), device='cpu') # Doesn't error + assert_raises(ValueError, lambda: empty_like(asarray(1), device='gpu')) + assert_raises(ValueError, lambda: empty_like(asarray(1), dtype=int)) + assert_raises(ValueError, lambda: empty_like(asarray(1), dtype='i')) + +def test_eye_errors(): + eye(1, device='cpu') # Doesn't error + assert_raises(ValueError, lambda: eye(1, device='gpu')) + assert_raises(ValueError, lambda: eye(1, dtype=int)) + assert_raises(ValueError, lambda: eye(1, dtype='i')) + +def test_full_errors(): + full((1,), 0, device='cpu') # Doesn't error + assert_raises(ValueError, lambda: full((1,), 0, device='gpu')) + assert_raises(ValueError, lambda: full((1,), 0, dtype=int)) + assert_raises(ValueError, lambda: full((1,), 0, dtype='i')) + +def test_full_like_errors(): + full_like(asarray(1), 0, device='cpu') # Doesn't error + assert_raises(ValueError, lambda: full_like(asarray(1), 0, device='gpu')) + assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype=int)) + assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype='i')) + +def test_linspace_errors(): + linspace(0, 1, 10, device='cpu') # Doesn't error + assert_raises(ValueError, lambda: linspace(0, 1, 10, device='gpu')) + assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype=float)) + assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype='f')) + +def test_ones_errors(): + ones((1,), device='cpu') # Doesn't error + assert_raises(ValueError, lambda: ones((1,), device='gpu')) + assert_raises(ValueError, lambda: ones((1,), dtype=int)) + assert_raises(ValueError, lambda: ones((1,), dtype='i')) + +def test_ones_like_errors(): + ones_like(asarray(1), device='cpu') # Doesn't error + assert_raises(ValueError, lambda: ones_like(asarray(1), device='gpu')) + assert_raises(ValueError, lambda: ones_like(asarray(1), dtype=int)) + assert_raises(ValueError, lambda: ones_like(asarray(1), dtype='i')) + +def test_zeros_errors(): + zeros((1,), device='cpu') # Doesn't error + assert_raises(ValueError, lambda: zeros((1,), device='gpu')) + assert_raises(ValueError, lambda: zeros((1,), dtype=int)) + assert_raises(ValueError, lambda: zeros((1,), dtype='i')) + +def test_zeros_like_errors(): + zeros_like(asarray(1), device='cpu') # Doesn't error + assert_raises(ValueError, lambda: zeros_like(asarray(1), device='gpu')) + assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int)) + assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype='i')) diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py new file mode 100644 index 000000000..994cb0bf0 --- /dev/null +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -0,0 +1,110 @@ +from inspect import getfullargspec + +from numpy.testing import assert_raises + +from .. import asarray, _elementwise_functions +from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift +from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, + _integer_dtypes, _integer_or_boolean_dtypes, + _numeric_dtypes) + +def nargs(func): + return len(getfullargspec(func).args) + +def test_function_types(): + # Test that every function accepts only the required input types. We only + # test the negative cases here (error). The positive cases are tested in + # the array API test suite. + + elementwise_function_input_types = { + 'abs': 'numeric', + 'acos': 'floating', + 'acosh': 'floating', + 'add': 'numeric', + 'asin': 'floating', + 'asinh': 'floating', + 'atan': 'floating', + 'atan2': 'floating', + 'atanh': 'floating', + 'bitwise_and': 'integer_or_boolean', + 'bitwise_invert': 'integer_or_boolean', + 'bitwise_left_shift': 'integer', + 'bitwise_or': 'integer_or_boolean', + 'bitwise_right_shift': 'integer', + 'bitwise_xor': 'integer_or_boolean', + 'ceil': 'numeric', + 'cos': 'floating', + 'cosh': 'floating', + 'divide': 'floating', + 'equal': 'all', + 'exp': 'floating', + 'expm1': 'floating', + 'floor': 'numeric', + 'floor_divide': 'numeric', + 'greater': 'numeric', + 'greater_equal': 'numeric', + 'isfinite': 'numeric', + 'isinf': 'numeric', + 'isnan': 'numeric', + 'less': 'numeric', + 'less_equal': 'numeric', + 'log': 'floating', + 'logaddexp': 'floating', + 'log10': 'floating', + 'log1p': 'floating', + 'log2': 'floating', + 'logical_and': 'boolean', + 'logical_not': 'boolean', + 'logical_or': 'boolean', + 'logical_xor': 'boolean', + 'multiply': 'numeric', + 'negative': 'numeric', + 'not_equal': 'all', + 'positive': 'numeric', + 'pow': 'floating', + 'remainder': 'numeric', + 'round': 'numeric', + 'sign': 'numeric', + 'sin': 'floating', + 'sinh': 'floating', + 'sqrt': 'floating', + 'square': 'numeric', + 'subtract': 'numeric', + 'tan': 'floating', + 'tanh': 'floating', + 'trunc': 'numeric', + } + + _dtypes = { + 'all': _all_dtypes, + 'numeric': _numeric_dtypes, + 'integer': _integer_dtypes, + 'integer_or_boolean': _integer_or_boolean_dtypes, + 'boolean': _boolean_dtypes, + 'floating': _floating_dtypes, + } + + def _array_vals(): + for d in _integer_dtypes: + yield asarray(1, dtype=d) + for d in _boolean_dtypes: + yield asarray(False, dtype=d) + for d in _floating_dtypes: + yield asarray(1., dtype=d) + + for x in _array_vals(): + for func_name, types in elementwise_function_input_types.items(): + dtypes = _dtypes[types] + func = getattr(_elementwise_functions, func_name) + if nargs(func) == 2: + for y in _array_vals(): + if x.dtype not in dtypes or y.dtype not in dtypes: + assert_raises(TypeError, lambda: func(x, y)) + else: + if x.dtype not in dtypes: + assert_raises(TypeError, lambda: func(x)) + +def test_bitwise_shift_error(): + # bitwise shift functions should raise when the second argument is negative + assert_raises(ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1]))) + assert_raises(ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))) -- cgit v1.2.1 From ee852b432371e144456e012ec316c117170a7340 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 4 Aug 2021 16:55:51 -0600 Subject: Print a warning when importing the numpy.array_api submodule --- numpy/array_api/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 4650e3db8..1e9790a14 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -115,6 +115,10 @@ Still TODO in this module are: """ +import warnings +warnings.warn("The numpy.array_api submodule is still experimental. See NEP 47.", + stacklevel=2) + __all__ = [] from ._constants import e, inf, nan, pi -- cgit v1.2.1 From 7e6a026f4dff0ebe49913a119f2555562e4e93be Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 4 Aug 2021 19:46:21 -0600 Subject: Remove no longer comment about the keepdims argument to argmin --- numpy/array_api/_searching_functions.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_searching_functions.py b/numpy/array_api/_searching_functions.py index d80720850..de5f43f3d 100644 --- a/numpy/array_api/_searching_functions.py +++ b/numpy/array_api/_searching_functions.py @@ -13,7 +13,6 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - See its docstring for more information. """ - # Note: this currently fails as np.argmax does not implement keepdims return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims))) def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: @@ -22,7 +21,6 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - See its docstring for more information. """ - # Note: this currently fails as np.argmin does not implement keepdims return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) def nonzero(x: Array, /) -> Tuple[Array, ...]: -- cgit v1.2.1 From c23abdc57b2e6c0fa4f939085374c01c1c4452a9 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 4 Aug 2021 19:57:46 -0600 Subject: Remove asarray() calls from the array API statistical functions asarray() is already called in Array._new. --- numpy/array_api/_statistical_functions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py index 61fc60c46..a606203bc 100644 --- a/numpy/array_api/_statistical_functions.py +++ b/numpy/array_api/_statistical_functions.py @@ -10,21 +10,21 @@ def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep return Array._new(np.max(x._array, axis=axis, keepdims=keepdims)) def mean(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: - return Array._new(np.asarray(np.mean(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims)) def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: return Array._new(np.min(x._array, axis=axis, keepdims=keepdims)) def prod(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: - return Array._new(np.asarray(np.prod(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims)) def std(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> Array: # Note: the keyword argument correction is different here - return Array._new(np.asarray(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))) + return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims)) def sum(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: - return Array._new(np.asarray(np.sum(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims)) def var(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> Array: # Note: the keyword argument correction is different here - return Array._new(np.asarray(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))) + return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims)) -- cgit v1.2.1 From 5605d687019dc55e594d4e227747c72bebb71a3c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 4 Aug 2021 19:59:47 -0600 Subject: Remove unused import --- numpy/array_api/_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 24957fde6..af70058e6 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -21,7 +21,7 @@ from ._creation_functions import asarray from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes, _integer_or_boolean_dtypes, _floating_dtypes, _numeric_dtypes) -from typing import TYPE_CHECKING, Any, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union if TYPE_CHECKING: from ._typing import PyCapsule, Device, Dtype -- cgit v1.2.1 From bc20d334b575f897157b1cf3eecda77f3e40e049 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 4 Aug 2021 20:01:11 -0600 Subject: Move the array API dtype categories into the top level They are not an official part of the spec but are useful for various parts of the implementation. --- numpy/array_api/_array_object.py | 17 ++++------------- numpy/array_api/_dtypes.py | 10 ++++++++++ numpy/array_api/tests/test_elementwise_functions.py | 16 +++------------- 3 files changed, 17 insertions(+), 26 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index af70058e6..50906642d 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -98,23 +98,14 @@ class Array: if other is NotImplemented: return other """ - from ._dtypes import _result_type - - _dtypes = { - 'all': _all_dtypes, - 'numeric': _numeric_dtypes, - 'integer': _integer_dtypes, - 'integer or boolean': _integer_or_boolean_dtypes, - 'boolean': _boolean_dtypes, - 'floating-point': _floating_dtypes, - } - - if self.dtype not in _dtypes[dtype_category]: + from ._dtypes import _result_type, _dtype_categories + + if self.dtype not in _dtype_categories[dtype_category]: raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) elif isinstance(other, Array): - if other.dtype not in _dtypes[dtype_category]: + if other.dtype not in _dtype_categories[dtype_category]: raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') else: return NotImplemented diff --git a/numpy/array_api/_dtypes.py b/numpy/array_api/_dtypes.py index fcdb562da..07be267da 100644 --- a/numpy/array_api/_dtypes.py +++ b/numpy/array_api/_dtypes.py @@ -23,6 +23,16 @@ _integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) _integer_or_boolean_dtypes = (bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64) _numeric_dtypes = (float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64) +_dtype_categories = { + 'all': _all_dtypes, + 'numeric': _numeric_dtypes, + 'integer': _integer_dtypes, + 'integer or boolean': _integer_or_boolean_dtypes, + 'boolean': _boolean_dtypes, + 'floating-point': _floating_dtypes, +} + + # Note: the spec defines a restricted type promotion table compared to NumPy. # In particular, cross-kind promotions like integer + float or boolean + # integer are not allowed, even for functions that accept both kinds. diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py index 994cb0bf0..2a5ddbc87 100644 --- a/numpy/array_api/tests/test_elementwise_functions.py +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -4,9 +4,8 @@ from numpy.testing import assert_raises from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift -from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, - _integer_dtypes, _integer_or_boolean_dtypes, - _numeric_dtypes) +from .._dtypes import (_dtype_categories, _boolean_dtypes, _floating_dtypes, + _integer_dtypes) def nargs(func): return len(getfullargspec(func).args) @@ -75,15 +74,6 @@ def test_function_types(): 'trunc': 'numeric', } - _dtypes = { - 'all': _all_dtypes, - 'numeric': _numeric_dtypes, - 'integer': _integer_dtypes, - 'integer_or_boolean': _integer_or_boolean_dtypes, - 'boolean': _boolean_dtypes, - 'floating': _floating_dtypes, - } - def _array_vals(): for d in _integer_dtypes: yield asarray(1, dtype=d) @@ -94,7 +84,7 @@ def test_function_types(): for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): - dtypes = _dtypes[types] + dtypes = _dtype_categories[types] func = getattr(_elementwise_functions, func_name) if nargs(func) == 2: for y in _array_vals(): -- cgit v1.2.1 From 6789a74312cda391b81ca803d38919555213a38f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 4 Aug 2021 20:05:46 -0600 Subject: Move some imports out of functions to the top of the file Some of the imports in the array API module have to be inside functions to avoid circular imports, but these ones did not. --- numpy/array_api/_array_object.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 50906642d..364b88f89 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -19,7 +19,8 @@ import operator from enum import IntEnum from ._creation_functions import asarray from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes, - _integer_or_boolean_dtypes, _floating_dtypes, _numeric_dtypes) + _integer_or_boolean_dtypes, _floating_dtypes, + _numeric_dtypes, _result_type, _dtype_categories) from typing import TYPE_CHECKING, Optional, Tuple, Union if TYPE_CHECKING: @@ -27,6 +28,8 @@ if TYPE_CHECKING: import numpy as np +from numpy import array_api + class Array: """ n-d array object for the array API namespace. @@ -98,7 +101,6 @@ class Array: if other is NotImplemented: return other """ - from ._dtypes import _result_type, _dtype_categories if self.dtype not in _dtype_categories[dtype_category]: raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') @@ -338,7 +340,6 @@ class Array: def __array_namespace__(self: Array, /, *, api_version: Optional[str] = None) -> object: if api_version is not None and not api_version.startswith('2021.'): raise ValueError(f"Unrecognized array API version: {api_version!r}") - from numpy import array_api return array_api def __bool__(self: Array, /) -> bool: -- cgit v1.2.1 From 310929d12967cb0e8e6615466ff9b9f62fc899b6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 4 Aug 2021 20:15:06 -0600 Subject: Fix casting for the array API concat() and stack() --- numpy/array_api/_manipulation_functions.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py index fa6344beb..e68dc6fcf 100644 --- a/numpy/array_api/_manipulation_functions.py +++ b/numpy/array_api/_manipulation_functions.py @@ -14,10 +14,11 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[i See its docstring for more information. """ + # Note: Casting rules here are different from the np.concatenate default + # (no for scalars with axis=None, no cross-kind casting) + dtype = result_type(*arrays) arrays = tuple(a._array for a in arrays) - # Call result type here just to raise on disallowed type combinations - result_type(*arrays) - return Array._new(np.concatenate(arrays, axis=axis)) + return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype)) def expand_dims(x: Array, /, *, axis: int) -> Array: """ @@ -65,7 +66,7 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> See its docstring for more information. """ - arrays = tuple(a._array for a in arrays) # Call result type here just to raise on disallowed type combinations result_type(*arrays) + arrays = tuple(a._array for a in arrays) return Array._new(np.stack(arrays, axis=axis)) -- cgit v1.2.1 From 3730fc06cd821ec0d7794a6ae058141400921dba Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 5 Aug 2021 17:31:47 -0600 Subject: Give a better error when numpy.array_api is imported in Python 3.7 --- numpy/array_api/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 1e9790a14..08d19744f 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -115,6 +115,12 @@ Still TODO in this module are: """ +import sys +# numpy.array_api is 3.8+ because it makes extensive use of positional-only +# arguments. +if sys.version_info < (3, 8): + raise ImportError("The numpy.array_api submodule requires Python 3.8 or greater.") + import warnings warnings.warn("The numpy.array_api submodule is still experimental. See NEP 47.", stacklevel=2) -- cgit v1.2.1 From d74a7d0b3297e10194bd3899a0d17a63610e49a1 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 5 Aug 2021 20:01:51 -0600 Subject: Add a setup.py to the array_api submodule --- numpy/array_api/setup.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 numpy/array_api/setup.py (limited to 'numpy/array_api') diff --git a/numpy/array_api/setup.py b/numpy/array_api/setup.py new file mode 100644 index 000000000..da2350c8f --- /dev/null +++ b/numpy/array_api/setup.py @@ -0,0 +1,10 @@ +def configuration(parent_package='', top_path=None): + from numpy.distutils.misc_util import Configuration + config = Configuration('array_api', parent_package, top_path) + config.add_subpackage('tests') + return config + + +if __name__ == '__main__': + from numpy.distutils.core import setup + setup(configuration=configuration) -- cgit v1.2.1 From fcdadee7815cbb72a1036c0ef144d73e916eae6d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 6 Aug 2021 16:09:23 -0600 Subject: Fix some dictionary key mismatches in the array API tests --- .../array_api/tests/test_elementwise_functions.py | 54 +++++++++++----------- 1 file changed, 27 insertions(+), 27 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py index 2a5ddbc87..ec76cb7a7 100644 --- a/numpy/array_api/tests/test_elementwise_functions.py +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -17,27 +17,27 @@ def test_function_types(): elementwise_function_input_types = { 'abs': 'numeric', - 'acos': 'floating', - 'acosh': 'floating', + 'acos': 'floating-point', + 'acosh': 'floating-point', 'add': 'numeric', - 'asin': 'floating', - 'asinh': 'floating', - 'atan': 'floating', - 'atan2': 'floating', - 'atanh': 'floating', - 'bitwise_and': 'integer_or_boolean', - 'bitwise_invert': 'integer_or_boolean', + 'asin': 'floating-point', + 'asinh': 'floating-point', + 'atan': 'floating-point', + 'atan2': 'floating-point', + 'atanh': 'floating-point', + 'bitwise_and': 'integer or boolean', + 'bitwise_invert': 'integer or boolean', 'bitwise_left_shift': 'integer', - 'bitwise_or': 'integer_or_boolean', + 'bitwise_or': 'integer or boolean', 'bitwise_right_shift': 'integer', - 'bitwise_xor': 'integer_or_boolean', + 'bitwise_xor': 'integer or boolean', 'ceil': 'numeric', - 'cos': 'floating', - 'cosh': 'floating', - 'divide': 'floating', + 'cos': 'floating-point', + 'cosh': 'floating-point', + 'divide': 'floating-point', 'equal': 'all', - 'exp': 'floating', - 'expm1': 'floating', + 'exp': 'floating-point', + 'expm1': 'floating-point', 'floor': 'numeric', 'floor_divide': 'numeric', 'greater': 'numeric', @@ -47,11 +47,11 @@ def test_function_types(): 'isnan': 'numeric', 'less': 'numeric', 'less_equal': 'numeric', - 'log': 'floating', - 'logaddexp': 'floating', - 'log10': 'floating', - 'log1p': 'floating', - 'log2': 'floating', + 'log': 'floating-point', + 'logaddexp': 'floating-point', + 'log10': 'floating-point', + 'log1p': 'floating-point', + 'log2': 'floating-point', 'logical_and': 'boolean', 'logical_not': 'boolean', 'logical_or': 'boolean', @@ -60,17 +60,17 @@ def test_function_types(): 'negative': 'numeric', 'not_equal': 'all', 'positive': 'numeric', - 'pow': 'floating', + 'pow': 'floating-point', 'remainder': 'numeric', 'round': 'numeric', 'sign': 'numeric', - 'sin': 'floating', - 'sinh': 'floating', - 'sqrt': 'floating', + 'sin': 'floating-point', + 'sinh': 'floating-point', + 'sqrt': 'floating-point', 'square': 'numeric', 'subtract': 'numeric', - 'tan': 'floating', - 'tanh': 'floating', + 'tan': 'floating-point', + 'tanh': 'floating-point', 'trunc': 'numeric', } -- cgit v1.2.1 From 2fe8643cce651fa2ada5619f85e3cc16524d4076 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 6 Aug 2021 16:48:16 -0600 Subject: Fix the array API __len__ method --- numpy/array_api/_array_object.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 364b88f89..00f50eade 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -468,8 +468,7 @@ class Array: """ Performs the operation __len__. """ - res = self._array.__len__() - return self.__class__._new(res) + return self._array.__len__() def __lshift__(self: Array, other: Union[int, Array], /) -> Array: """ -- cgit v1.2.1 From 4063752757a97c444b8913947a0890f2c2387bca Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 6 Aug 2021 16:57:10 -0600 Subject: Fix the array API unique() function --- numpy/array_api/_set_functions.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py index f28c2ee72..acd59f597 100644 --- a/numpy/array_api/_set_functions.py +++ b/numpy/array_api/_set_functions.py @@ -12,4 +12,8 @@ def unique(x: Array, /, *, return_counts: bool = False, return_index: bool = Fal See its docstring for more information. """ - return Array._new(np.unique(x._array, return_counts=return_counts, return_index=return_index, return_inverse=return_inverse)) + res = np.unique(x._array, return_counts=return_counts, + return_index=return_index, return_inverse=return_inverse) + if isinstance(res, tuple): + return tuple(Array._new(i) for i in res) + return Array._new(res) -- cgit v1.2.1 From 1ae808401951bf8c4cbff97a30505f08741d811f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 6 Aug 2021 17:12:54 -0600 Subject: Make the axis argument to squeeze() in the array_api module positional-only See data-apis/array-api#100. --- numpy/array_api/_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py index e68dc6fcf..33f5d5a28 100644 --- a/numpy/array_api/_manipulation_functions.py +++ b/numpy/array_api/_manipulation_functions.py @@ -52,7 +52,7 @@ def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Unio """ return Array._new(np.roll(x._array, shift, axis=axis)) -def squeeze(x: Array, /, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: +def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: """ Array API compatible wrapper for :py:func:`np.squeeze `. -- cgit v1.2.1 From f13f08f6c00ff5debf918dd50546b3215e39a5b8 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 6 Aug 2021 17:15:39 -0600 Subject: Fix the array API nonzero() function --- numpy/array_api/_searching_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_searching_functions.py b/numpy/array_api/_searching_functions.py index de5f43f3d..9dcc76b2d 100644 --- a/numpy/array_api/_searching_functions.py +++ b/numpy/array_api/_searching_functions.py @@ -29,7 +29,7 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: See its docstring for more information. """ - return Array._new(np.nonzero(x._array)) + return tuple(Array._new(i) for i in np.nonzero(x._array)) def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ -- cgit v1.2.1 From 21923a5fa71bfadf7dee0bb5b110cc2a5719eaac Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 6 Aug 2021 18:07:46 -0600 Subject: Update the docstring of numpy.array_api --- numpy/array_api/__init__.py | 79 +++++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 38 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 08d19744f..4dc931732 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -1,7 +1,8 @@ """ A NumPy sub-namespace that conforms to the Python array API standard. -This submodule accompanies NEP 47, which proposes its inclusion in NumPy. +This submodule accompanies NEP 47, which proposes its inclusion in NumPy. It +is still considered experimental, and will issue a warning when imported. This is a proof-of-concept namespace that wraps the corresponding NumPy functions to give a conforming implementation of the Python array API standard @@ -38,35 +39,29 @@ A few notes about the current state of this submodule: in progress, but the existing tests pass on this module, with a few exceptions: - - Device support is not yet implemented in NumPy - (https://data-apis.github.io/array-api/latest/design_topics/device_support.html). - As a result, the `device` attribute of the array object is missing, and - array creation functions that take the `device` keyword argument will fail - with NotImplementedError. - - DLPack support (see https://github.com/data-apis/array-api/pull/106) is not included here, as it requires a full implementation in NumPy proper first. - - The linear algebra extension in the spec will be added in a future pull -request. - The test suite is not yet complete, and even the tests that exist are not - guaranteed to give a comprehensive coverage of the spec. Therefore, those - reviewing this submodule should refer to the standard documents themselves. - -- There is a custom array object, numpy.array_api.Array, which is returned - by all functions in this module. All functions in the array API namespace + guaranteed to give a comprehensive coverage of the spec. Therefore, when + reviewing and using this submodule, you should refer to the standard + documents themselves. There are some tests in numpy.array_api.tests, but + they primarily focus on things that are not tested by the official array API + test suite. + +- There is a custom array object, numpy.array_api.Array, which is returned by + all functions in this module. All functions in the array API namespace implicitly assume that they will only receive this object as input. The only way to create instances of this object is to use one of the array creation functions. It does not have a public constructor on the object itself. The - object is a small wrapper Python class around numpy.ndarray. The main - purpose of it is to restrict the namespace of the array object to only those - dtypes and only those methods that are required by the spec, as well as to - limit/change certain behavior that differs in the spec. In particular: + object is a small wrapper class around numpy.ndarray. The main purpose of it + is to restrict the namespace of the array object to only those dtypes and + only those methods that are required by the spec, as well as to limit/change + certain behavior that differs in the spec. In particular: - - The array API namespace does not have scalar objects, only 0-d arrays. - Operations in on Array that would create a scalar in NumPy create a 0-d + - The array API namespace does not have scalar objects, only 0-D arrays. + Operations on Array that would create a scalar in NumPy create a 0-D array. - Indexing: Only a subset of indices supported by NumPy are required by the @@ -76,12 +71,15 @@ request. information. - Type promotion: Some type promotion rules are different in the spec. In - particular, the spec does not have any value-based casting. The - Array._promote_scalar method promotes Python scalars to arrays, - disallowing cross-type promotions like int -> float64 that are not allowed - in the spec. Array._normalize_two_args works around some type promotion - quirks in NumPy, particularly, value-based casting that occurs when one - argument of an operation is a 0-d array. + particular, the spec does not have any value-based casting. The spec also + does not require cross-kind casting, like integer -> floating-point. Only + those promotions that are explicitly required by the array API + specification are allowed in this module. See NEP 47 for more info. + + - Functions do not automatically call asarray() on their input, and will not + work if the input type is not Array. The exception is array creation + functions, and Python operators on the Array object, which accept Python + scalars of the same type as the array dtype. - All functions include type annotations, corresponding to those given in the spec (see _typing.py for definitions of some custom types). These do not @@ -93,26 +91,31 @@ request. equality, but it was considered too much extra complexity to create custom objects to represent dtypes. -- The wrapper functions in this module do not do any type checking for things - that would be impossible without leaving the array_api namespace. For - example, since the array API dtype objects are just the NumPy dtype objects, - one could pass in a non-spec NumPy dtype into a function. - - All places where the implementations in this submodule are known to deviate - from their corresponding functions in NumPy are marked with "# Note" - comments. Reviewers should make note of these comments. + from their corresponding functions in NumPy are marked with "# Note:" + comments. Still TODO in this module are: -- Device support and DLPack support are not yet implemented. These require - support in NumPy itself first. +- DLPack support for numpy.ndarray is still in progress. See + https://github.com/numpy/numpy/pull/19083. -- The a non-default value for the `copy` keyword argument is not yet - implemented on asarray. This requires support in numpy.asarray() first. +- The copy=False keyword argument to asarray() is not yet implemented. This + requires support in numpy.asarray() first. - Some functions are not yet fully tested in the array API test suite, and may require updates that are not yet known until the tests are written. +- The spec is still in an RFC phase and may still have minor updates, which + will need to be reflected here. + +- The linear algebra extension in the spec will be added in a future pull + request. + +- Complex number support in array API spec is planned but not yet finalized, + as are the fft extension and certain linear algebra functions such as eig + that require complex dtypes. + """ import sys -- cgit v1.2.1 From 8f7d00ed447174d9398af3365709222b529c1cad Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 6 Aug 2021 18:22:00 -0600 Subject: Run (selective) black on the array_api submodule I've omitted a few changes from black that messed up the readability of some complicated if statements that were organized logically line-by-line, and some changes that use unnecessary operator spacing. --- numpy/array_api/__init__.py | 247 ++++++++++++++++++--- numpy/array_api/_array_object.py | 205 ++++++++++------- numpy/array_api/_creation_functions.py | 158 ++++++++++--- numpy/array_api/_data_type_functions.py | 16 +- numpy/array_api/_dtypes.py | 75 +++++-- numpy/array_api/_elementwise_functions.py | 194 ++++++++++------ numpy/array_api/_linear_algebra_functions.py | 16 +- numpy/array_api/_manipulation_functions.py | 18 +- numpy/array_api/_searching_functions.py | 4 + numpy/array_api/_set_functions.py | 18 +- numpy/array_api/_sorting_functions.py | 14 +- numpy/array_api/_statistical_functions.py | 65 +++++- numpy/array_api/_typing.py | 30 ++- numpy/array_api/_utility_functions.py | 18 +- numpy/array_api/setup.py | 10 +- numpy/array_api/tests/test_array_object.py | 101 +++++---- numpy/array_api/tests/test_creation_functions.py | 122 ++++++---- .../array_api/tests/test_elementwise_functions.py | 133 ++++++----- 18 files changed, 1054 insertions(+), 390 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 4dc931732..53c1f3850 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -119,36 +119,221 @@ Still TODO in this module are: """ import sys + # numpy.array_api is 3.8+ because it makes extensive use of positional-only # arguments. if sys.version_info < (3, 8): raise ImportError("The numpy.array_api submodule requires Python 3.8 or greater.") import warnings -warnings.warn("The numpy.array_api submodule is still experimental. See NEP 47.", - stacklevel=2) + +warnings.warn( + "The numpy.array_api submodule is still experimental. See NEP 47.", stacklevel=2 +) __all__ = [] from ._constants import e, inf, nan, pi -__all__ += ['e', 'inf', 'nan', 'pi'] - -from ._creation_functions import asarray, arange, empty, empty_like, eye, from_dlpack, full, full_like, linspace, meshgrid, ones, ones_like, zeros, zeros_like - -__all__ += ['asarray', 'arange', 'empty', 'empty_like', 'eye', 'from_dlpack', 'full', 'full_like', 'linspace', 'meshgrid', 'ones', 'ones_like', 'zeros', 'zeros_like'] - -from ._data_type_functions import broadcast_arrays, broadcast_to, can_cast, finfo, iinfo, result_type - -__all__ += ['broadcast_arrays', 'broadcast_to', 'can_cast', 'finfo', 'iinfo', 'result_type'] - -from ._dtypes import int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, bool - -__all__ += ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', 'float32', 'float64', 'bool'] - -from ._elementwise_functions import abs, acos, acosh, add, asin, asinh, atan, atan2, atanh, bitwise_and, bitwise_left_shift, bitwise_invert, bitwise_or, bitwise_right_shift, bitwise_xor, ceil, cos, cosh, divide, equal, exp, expm1, floor, floor_divide, greater, greater_equal, isfinite, isinf, isnan, less, less_equal, log, log1p, log2, log10, logaddexp, logical_and, logical_not, logical_or, logical_xor, multiply, negative, not_equal, positive, pow, remainder, round, sign, sin, sinh, square, sqrt, subtract, tan, tanh, trunc - -__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logaddexp', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc'] +__all__ += ["e", "inf", "nan", "pi"] + +from ._creation_functions import ( + asarray, + arange, + empty, + empty_like, + eye, + from_dlpack, + full, + full_like, + linspace, + meshgrid, + ones, + ones_like, + zeros, + zeros_like, +) + +__all__ += [ + "asarray", + "arange", + "empty", + "empty_like", + "eye", + "from_dlpack", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "zeros", + "zeros_like", +] + +from ._data_type_functions import ( + broadcast_arrays, + broadcast_to, + can_cast, + finfo, + iinfo, + result_type, +) + +__all__ += [ + "broadcast_arrays", + "broadcast_to", + "can_cast", + "finfo", + "iinfo", + "result_type", +] + +from ._dtypes import ( + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + bool, +) + +__all__ += [ + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "bool", +] + +from ._elementwise_functions import ( + abs, + acos, + acosh, + add, + asin, + asinh, + atan, + atan2, + atanh, + bitwise_and, + bitwise_left_shift, + bitwise_invert, + bitwise_or, + bitwise_right_shift, + bitwise_xor, + ceil, + cos, + cosh, + divide, + equal, + exp, + expm1, + floor, + floor_divide, + greater, + greater_equal, + isfinite, + isinf, + isnan, + less, + less_equal, + log, + log1p, + log2, + log10, + logaddexp, + logical_and, + logical_not, + logical_or, + logical_xor, + multiply, + negative, + not_equal, + positive, + pow, + remainder, + round, + sign, + sin, + sinh, + square, + sqrt, + subtract, + tan, + tanh, + trunc, +) + +__all__ += [ + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "multiply", + "negative", + "not_equal", + "positive", + "pow", + "remainder", + "round", + "sign", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", +] # einsum is not yet implemented in the array API spec. @@ -157,28 +342,36 @@ __all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'at from ._linear_algebra_functions import matmul, tensordot, transpose, vecdot -__all__ += ['matmul', 'tensordot', 'transpose', 'vecdot'] +__all__ += ["matmul", "tensordot", "transpose", "vecdot"] -from ._manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack +from ._manipulation_functions import ( + concat, + expand_dims, + flip, + reshape, + roll, + squeeze, + stack, +) -__all__ += ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack'] +__all__ += ["concat", "expand_dims", "flip", "reshape", "roll", "squeeze", "stack"] from ._searching_functions import argmax, argmin, nonzero, where -__all__ += ['argmax', 'argmin', 'nonzero', 'where'] +__all__ += ["argmax", "argmin", "nonzero", "where"] from ._set_functions import unique -__all__ += ['unique'] +__all__ += ["unique"] from ._sorting_functions import argsort, sort -__all__ += ['argsort', 'sort'] +__all__ += ["argsort", "sort"] from ._statistical_functions import max, mean, min, prod, std, sum, var -__all__ += ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var'] +__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"] from ._utility_functions import all, any -__all__ += ['all', 'any'] +__all__ += ["all", "any"] diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 00f50eade..0f511a577 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -18,11 +18,19 @@ from __future__ import annotations import operator from enum import IntEnum from ._creation_functions import asarray -from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes, - _integer_or_boolean_dtypes, _floating_dtypes, - _numeric_dtypes, _result_type, _dtype_categories) +from ._dtypes import ( + _all_dtypes, + _boolean_dtypes, + _integer_dtypes, + _integer_or_boolean_dtypes, + _floating_dtypes, + _numeric_dtypes, + _result_type, + _dtype_categories, +) from typing import TYPE_CHECKING, Optional, Tuple, Union + if TYPE_CHECKING: from ._typing import PyCapsule, Device, Dtype @@ -30,6 +38,7 @@ import numpy as np from numpy import array_api + class Array: """ n-d array object for the array API namespace. @@ -45,6 +54,7 @@ class Array: functions, such as asarray(). """ + # Use a custom constructor instead of __init__, as manually initializing # this class is not supported API. @classmethod @@ -64,13 +74,17 @@ class Array: # Convert the array scalar to a 0-D array x = np.asarray(x) if x.dtype not in _all_dtypes: - raise TypeError(f"The array_api namespace does not support the dtype '{x.dtype}'") + raise TypeError( + f"The array_api namespace does not support the dtype '{x.dtype}'" + ) obj._array = x return obj # Prevent Array() from working def __new__(cls, *args, **kwargs): - raise TypeError("The array_api Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead.") + raise TypeError( + "The array_api Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead." + ) # These functions are not required by the spec, but are implemented for # the sake of usability. @@ -79,7 +93,7 @@ class Array: """ Performs the operation __str__. """ - return self._array.__str__().replace('array', 'Array') + return self._array.__str__().replace("array", "Array") def __repr__(self: Array, /) -> str: """ @@ -103,12 +117,12 @@ class Array: """ if self.dtype not in _dtype_categories[dtype_category]: - raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) elif isinstance(other, Array): if other.dtype not in _dtype_categories[dtype_category]: - raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") else: return NotImplemented @@ -116,7 +130,7 @@ class Array: # to promote in the spec (even if the NumPy array operator would # promote them). res_dtype = _result_type(self.dtype, other.dtype) - if op.startswith('__i'): + if op.startswith("__i"): # Note: NumPy will allow in-place operators in some cases where # the type promoted operator does not match the left-hand side # operand. For example, @@ -126,7 +140,9 @@ class Array: # The spec explicitly disallows this. if res_dtype != self.dtype: - raise TypeError(f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}") + raise TypeError( + f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}" + ) return other @@ -142,13 +158,19 @@ class Array: """ if isinstance(scalar, bool): if self.dtype not in _boolean_dtypes: - raise TypeError("Python bool scalars can only be promoted with bool arrays") + raise TypeError( + "Python bool scalars can only be promoted with bool arrays" + ) elif isinstance(scalar, int): if self.dtype in _boolean_dtypes: - raise TypeError("Python int scalars cannot be promoted with bool arrays") + raise TypeError( + "Python int scalars cannot be promoted with bool arrays" + ) elif isinstance(scalar, float): if self.dtype not in _floating_dtypes: - raise TypeError("Python float scalars can only be promoted with floating-point arrays.") + raise TypeError( + "Python float scalars can only be promoted with floating-point arrays." + ) else: raise TypeError("'scalar' must be a Python scalar") @@ -253,7 +275,9 @@ class Array: except TypeError: return key if not (-size <= key.start <= max(0, size - 1)): - raise IndexError("Slices with out-of-bounds start are not allowed in the array API namespace") + raise IndexError( + "Slices with out-of-bounds start are not allowed in the array API namespace" + ) if key.stop is not None: try: operator.index(key.stop) @@ -269,12 +293,20 @@ class Array: key = tuple(Array._validate_index(idx, None) for idx in key) for idx in key: - if isinstance(idx, np.ndarray) and idx.dtype in _boolean_dtypes or isinstance(idx, (bool, np.bool_)): + if ( + isinstance(idx, np.ndarray) + and idx.dtype in _boolean_dtypes + or isinstance(idx, (bool, np.bool_)) + ): if len(key) == 1: return key - raise IndexError("Boolean array indices combined with other indices are not allowed in the array API namespace") + raise IndexError( + "Boolean array indices combined with other indices are not allowed in the array API namespace" + ) if isinstance(idx, tuple): - raise IndexError("Nested tuple indices are not allowed in the array API namespace") + raise IndexError( + "Nested tuple indices are not allowed in the array API namespace" + ) if shape is None: return key @@ -283,7 +315,9 @@ class Array: return key ellipsis_i = key.index(...) if n_ellipsis else len(key) - for idx, size in list(zip(key[:ellipsis_i], shape)) + list(zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])): + for idx, size in list(zip(key[:ellipsis_i], shape)) + list( + zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1]) + ): Array._validate_index(idx, (size,)) return key elif isinstance(key, bool): @@ -291,18 +325,24 @@ class Array: elif isinstance(key, Array): if key.dtype in _integer_dtypes: if key.ndim != 0: - raise IndexError("Non-zero dimensional integer array indices are not allowed in the array API namespace") + raise IndexError( + "Non-zero dimensional integer array indices are not allowed in the array API namespace" + ) return key._array elif key is Ellipsis: return key elif key is None: - raise IndexError("newaxis indices are not allowed in the array API namespace") + raise IndexError( + "newaxis indices are not allowed in the array API namespace" + ) try: return operator.index(key) except TypeError: # Note: This also omits boolean arrays that are not already in # Array() form, like a list of booleans. - raise IndexError("Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace") + raise IndexError( + "Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace" + ) # Everything below this line is required by the spec. @@ -311,7 +351,7 @@ class Array: Performs the operation __abs__. """ if self.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in __abs__') + raise TypeError("Only numeric dtypes are allowed in __abs__") res = self._array.__abs__() return self.__class__._new(res) @@ -319,7 +359,7 @@ class Array: """ Performs the operation __add__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__add__') + other = self._check_allowed_dtypes(other, "numeric", "__add__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -330,15 +370,17 @@ class Array: """ Performs the operation __and__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__and__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__and__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__and__(other._array) return self.__class__._new(res) - def __array_namespace__(self: Array, /, *, api_version: Optional[str] = None) -> object: - if api_version is not None and not api_version.startswith('2021.'): + def __array_namespace__( + self: Array, /, *, api_version: Optional[str] = None + ) -> object: + if api_version is not None and not api_version.startswith("2021."): raise ValueError(f"Unrecognized array API version: {api_version!r}") return array_api @@ -373,7 +415,7 @@ class Array: """ # Even though "all" dtypes are allowed, we still require them to be # promotable with each other. - other = self._check_allowed_dtypes(other, 'all', '__eq__') + other = self._check_allowed_dtypes(other, "all", "__eq__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -394,7 +436,7 @@ class Array: """ Performs the operation __floordiv__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__floordiv__') + other = self._check_allowed_dtypes(other, "numeric", "__floordiv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -405,14 +447,20 @@ class Array: """ Performs the operation __ge__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__ge__') + other = self._check_allowed_dtypes(other, "numeric", "__ge__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__ge__(other._array) return self.__class__._new(res) - def __getitem__(self: Array, key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], /) -> Array: + def __getitem__( + self: Array, + key: Union[ + int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array + ], + /, + ) -> Array: """ Performs the operation __getitem__. """ @@ -426,7 +474,7 @@ class Array: """ Performs the operation __gt__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__gt__') + other = self._check_allowed_dtypes(other, "numeric", "__gt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -448,7 +496,7 @@ class Array: Performs the operation __invert__. """ if self.dtype not in _integer_or_boolean_dtypes: - raise TypeError('Only integer or boolean dtypes are allowed in __invert__') + raise TypeError("Only integer or boolean dtypes are allowed in __invert__") res = self._array.__invert__() return self.__class__._new(res) @@ -456,7 +504,7 @@ class Array: """ Performs the operation __le__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__le__') + other = self._check_allowed_dtypes(other, "numeric", "__le__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -474,7 +522,7 @@ class Array: """ Performs the operation __lshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__lshift__') + other = self._check_allowed_dtypes(other, "integer", "__lshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -485,7 +533,7 @@ class Array: """ Performs the operation __lt__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__lt__') + other = self._check_allowed_dtypes(other, "numeric", "__lt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -498,7 +546,7 @@ class Array: """ # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. - other = self._check_allowed_dtypes(other, 'numeric', '__matmul__') + other = self._check_allowed_dtypes(other, "numeric", "__matmul__") if other is NotImplemented: return other res = self._array.__matmul__(other._array) @@ -508,7 +556,7 @@ class Array: """ Performs the operation __mod__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__mod__') + other = self._check_allowed_dtypes(other, "numeric", "__mod__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -519,7 +567,7 @@ class Array: """ Performs the operation __mul__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__mul__') + other = self._check_allowed_dtypes(other, "numeric", "__mul__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -530,7 +578,7 @@ class Array: """ Performs the operation __ne__. """ - other = self._check_allowed_dtypes(other, 'all', '__ne__') + other = self._check_allowed_dtypes(other, "all", "__ne__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -542,7 +590,7 @@ class Array: Performs the operation __neg__. """ if self.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in __neg__') + raise TypeError("Only numeric dtypes are allowed in __neg__") res = self._array.__neg__() return self.__class__._new(res) @@ -550,7 +598,7 @@ class Array: """ Performs the operation __or__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__or__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__or__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -562,7 +610,7 @@ class Array: Performs the operation __pos__. """ if self.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in __pos__') + raise TypeError("Only numeric dtypes are allowed in __pos__") res = self._array.__pos__() return self.__class__._new(res) @@ -574,7 +622,7 @@ class Array: """ from ._elementwise_functions import pow - other = self._check_allowed_dtypes(other, 'floating-point', '__pow__') + other = self._check_allowed_dtypes(other, "floating-point", "__pow__") if other is NotImplemented: return other # Note: NumPy's __pow__ does not follow type promotion rules for 0-d @@ -585,14 +633,21 @@ class Array: """ Performs the operation __rshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__rshift__') + other = self._check_allowed_dtypes(other, "integer", "__rshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__rshift__(other._array) return self.__class__._new(res) - def __setitem__(self, key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], value: Union[int, float, bool, Array], /) -> Array: + def __setitem__( + self, + key: Union[ + int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array + ], + value: Union[int, float, bool, Array], + /, + ) -> Array: """ Performs the operation __setitem__. """ @@ -605,7 +660,7 @@ class Array: """ Performs the operation __sub__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__sub__') + other = self._check_allowed_dtypes(other, "numeric", "__sub__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -618,7 +673,7 @@ class Array: """ Performs the operation __truediv__. """ - other = self._check_allowed_dtypes(other, 'floating-point', '__truediv__') + other = self._check_allowed_dtypes(other, "floating-point", "__truediv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -629,7 +684,7 @@ class Array: """ Performs the operation __xor__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__xor__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -640,7 +695,7 @@ class Array: """ Performs the operation __iadd__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__iadd__') + other = self._check_allowed_dtypes(other, "numeric", "__iadd__") if other is NotImplemented: return other self._array.__iadd__(other._array) @@ -650,7 +705,7 @@ class Array: """ Performs the operation __radd__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__radd__') + other = self._check_allowed_dtypes(other, "numeric", "__radd__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -661,7 +716,7 @@ class Array: """ Performs the operation __iand__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__iand__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__") if other is NotImplemented: return other self._array.__iand__(other._array) @@ -671,7 +726,7 @@ class Array: """ Performs the operation __rand__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__rand__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -682,7 +737,7 @@ class Array: """ Performs the operation __ifloordiv__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__ifloordiv__') + other = self._check_allowed_dtypes(other, "numeric", "__ifloordiv__") if other is NotImplemented: return other self._array.__ifloordiv__(other._array) @@ -692,7 +747,7 @@ class Array: """ Performs the operation __rfloordiv__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__rfloordiv__') + other = self._check_allowed_dtypes(other, "numeric", "__rfloordiv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -703,7 +758,7 @@ class Array: """ Performs the operation __ilshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__ilshift__') + other = self._check_allowed_dtypes(other, "integer", "__ilshift__") if other is NotImplemented: return other self._array.__ilshift__(other._array) @@ -713,7 +768,7 @@ class Array: """ Performs the operation __rlshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__rlshift__') + other = self._check_allowed_dtypes(other, "integer", "__rlshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -728,7 +783,7 @@ class Array: # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. - other = self._check_allowed_dtypes(other, 'numeric', '__imatmul__') + other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") if other is NotImplemented: return other @@ -748,7 +803,7 @@ class Array: """ # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. - other = self._check_allowed_dtypes(other, 'numeric', '__rmatmul__') + other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__") if other is NotImplemented: return other res = self._array.__rmatmul__(other._array) @@ -758,7 +813,7 @@ class Array: """ Performs the operation __imod__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__imod__') + other = self._check_allowed_dtypes(other, "numeric", "__imod__") if other is NotImplemented: return other self._array.__imod__(other._array) @@ -768,7 +823,7 @@ class Array: """ Performs the operation __rmod__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__rmod__') + other = self._check_allowed_dtypes(other, "numeric", "__rmod__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -779,7 +834,7 @@ class Array: """ Performs the operation __imul__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__imul__') + other = self._check_allowed_dtypes(other, "numeric", "__imul__") if other is NotImplemented: return other self._array.__imul__(other._array) @@ -789,7 +844,7 @@ class Array: """ Performs the operation __rmul__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__rmul__') + other = self._check_allowed_dtypes(other, "numeric", "__rmul__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -800,7 +855,7 @@ class Array: """ Performs the operation __ior__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__ior__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__") if other is NotImplemented: return other self._array.__ior__(other._array) @@ -810,7 +865,7 @@ class Array: """ Performs the operation __ror__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__ror__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -821,7 +876,7 @@ class Array: """ Performs the operation __ipow__. """ - other = self._check_allowed_dtypes(other, 'floating-point', '__ipow__') + other = self._check_allowed_dtypes(other, "floating-point", "__ipow__") if other is NotImplemented: return other self._array.__ipow__(other._array) @@ -833,7 +888,7 @@ class Array: """ from ._elementwise_functions import pow - other = self._check_allowed_dtypes(other, 'floating-point', '__rpow__') + other = self._check_allowed_dtypes(other, "floating-point", "__rpow__") if other is NotImplemented: return other # Note: NumPy's __pow__ does not follow the spec type promotion rules @@ -844,7 +899,7 @@ class Array: """ Performs the operation __irshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__irshift__') + other = self._check_allowed_dtypes(other, "integer", "__irshift__") if other is NotImplemented: return other self._array.__irshift__(other._array) @@ -854,7 +909,7 @@ class Array: """ Performs the operation __rrshift__. """ - other = self._check_allowed_dtypes(other, 'integer', '__rrshift__') + other = self._check_allowed_dtypes(other, "integer", "__rrshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -865,7 +920,7 @@ class Array: """ Performs the operation __isub__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__isub__') + other = self._check_allowed_dtypes(other, "numeric", "__isub__") if other is NotImplemented: return other self._array.__isub__(other._array) @@ -875,7 +930,7 @@ class Array: """ Performs the operation __rsub__. """ - other = self._check_allowed_dtypes(other, 'numeric', '__rsub__') + other = self._check_allowed_dtypes(other, "numeric", "__rsub__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -886,7 +941,7 @@ class Array: """ Performs the operation __itruediv__. """ - other = self._check_allowed_dtypes(other, 'floating-point', '__itruediv__') + other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__") if other is NotImplemented: return other self._array.__itruediv__(other._array) @@ -896,7 +951,7 @@ class Array: """ Performs the operation __rtruediv__. """ - other = self._check_allowed_dtypes(other, 'floating-point', '__rtruediv__') + other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -907,7 +962,7 @@ class Array: """ Performs the operation __ixor__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__ixor__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__") if other is NotImplemented: return other self._array.__ixor__(other._array) @@ -917,7 +972,7 @@ class Array: """ Performs the operation __rxor__. """ - other = self._check_allowed_dtypes(other, 'integer or boolean', '__rxor__') + other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -935,7 +990,7 @@ class Array: @property def device(self) -> Device: - return 'cpu' + return "cpu" @property def ndim(self) -> int: diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index acf78056a..e9c01e7e6 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -2,14 +2,22 @@ from __future__ import annotations from typing import TYPE_CHECKING, List, Optional, Tuple, Union + if TYPE_CHECKING: - from ._typing import (Array, Device, Dtype, NestedSequence, - SupportsDLPack, SupportsBufferProtocol) + from ._typing import ( + Array, + Device, + Dtype, + NestedSequence, + SupportsDLPack, + SupportsBufferProtocol, + ) from collections.abc import Sequence from ._dtypes import _all_dtypes import numpy as np + def _check_valid_dtype(dtype): # Note: Only spelling dtypes as the dtype objects is supported. @@ -20,7 +28,23 @@ def _check_valid_dtype(dtype): return raise ValueError("dtype must be one of the supported dtypes") -def asarray(obj: Union[Array, bool, int, float, NestedSequence[bool|int|float], SupportsDLPack, SupportsBufferProtocol], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: Optional[bool] = None) -> Array: + +def asarray( + obj: Union[ + Array, + bool, + int, + float, + NestedSequence[bool | int | float], + SupportsDLPack, + SupportsBufferProtocol, + ], + /, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + copy: Optional[bool] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.asarray `. @@ -31,7 +55,7 @@ def asarray(obj: Union[Array, bool, int, float, NestedSequence[bool|int|float], from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") if copy is False: # Note: copy=False is not yet implemented in np.asarray @@ -40,14 +64,23 @@ def asarray(obj: Union[Array, bool, int, float, NestedSequence[bool|int|float], if copy is True: return Array._new(np.array(obj._array, copy=True, dtype=dtype)) return obj - if dtype is None and isinstance(obj, int) and (obj > 2**64 or obj < -2**63): + if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)): # Give a better error message in this case. NumPy would convert this # to an object array. TODO: This won't handle large integers in lists. raise OverflowError("Integer out of bounds for array dtypes") res = np.asarray(obj, dtype=dtype) return Array._new(res) -def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def arange( + start: Union[int, float], + /, + stop: Optional[Union[int, float]] = None, + step: Union[int, float] = 1, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.arange `. @@ -56,11 +89,17 @@ def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype)) -def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def empty( + shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.empty `. @@ -69,11 +108,14 @@ def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.empty(shape, dtype=dtype)) -def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def empty_like( + x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> Array: """ Array API compatible wrapper for :py:func:`np.empty_like `. @@ -82,11 +124,20 @@ def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[D from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.empty_like(x._array, dtype=dtype)) -def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: Optional[int] = 0, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def eye( + n_rows: int, + n_cols: Optional[int] = None, + /, + *, + k: Optional[int] = 0, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.eye `. @@ -95,15 +146,23 @@ def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: Optional[int] = 0, d from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype)) + def from_dlpack(x: object, /) -> Array: # Note: dlpack support is not yet implemented on Array raise NotImplementedError("DLPack support is not yet implemented") -def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def full( + shape: Union[int, Tuple[int, ...]], + fill_value: Union[int, float], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.full `. @@ -112,7 +171,7 @@ def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, d from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") if isinstance(fill_value, Array) and fill_value.ndim == 0: fill_value = fill_value._array @@ -123,7 +182,15 @@ def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, d raise TypeError("Invalid input to full") return Array._new(res) -def full_like(x: Array, /, fill_value: Union[int, float], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def full_like( + x: Array, + /, + fill_value: Union[int, float], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.full_like `. @@ -132,7 +199,7 @@ def full_like(x: Array, /, fill_value: Union[int, float], *, dtype: Optional[Dty from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = np.full_like(x._array, fill_value, dtype=dtype) if res.dtype not in _all_dtypes: @@ -141,7 +208,17 @@ def full_like(x: Array, /, fill_value: Union[int, float], *, dtype: Optional[Dty raise TypeError("Invalid input to full_like") return Array._new(res) -def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, endpoint: bool = True) -> Array: + +def linspace( + start: Union[int, float], + stop: Union[int, float], + /, + num: int, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + endpoint: bool = True, +) -> Array: """ Array API compatible wrapper for :py:func:`np.linspace `. @@ -150,20 +227,31 @@ def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *, from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)) -def meshgrid(*arrays: Sequence[Array], indexing: str = 'xy') -> List[Array, ...]: + +def meshgrid(*arrays: Sequence[Array], indexing: str = "xy") -> List[Array, ...]: """ Array API compatible wrapper for :py:func:`np.meshgrid `. See its docstring for more information. """ from ._array_object import Array - return [Array._new(array) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)] -def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + return [ + Array._new(array) + for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing) + ] + + +def ones( + shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.ones `. @@ -172,11 +260,14 @@ def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, d from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.ones(shape, dtype=dtype)) -def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def ones_like( + x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> Array: """ Array API compatible wrapper for :py:func:`np.ones_like `. @@ -185,11 +276,17 @@ def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[De from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.ones_like(x._array, dtype=dtype)) -def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def zeros( + shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.zeros `. @@ -198,11 +295,14 @@ def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.zeros(shape, dtype=dtype)) -def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + +def zeros_like( + x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None +) -> Array: """ Array API compatible wrapper for :py:func:`np.zeros_like `. @@ -211,6 +311,6 @@ def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[D from ._array_object import Array _check_valid_dtype(dtype) - if device not in ['cpu', None]: + if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.zeros_like(x._array, dtype=dtype)) diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index 17a00cc6d..e6121a8a4 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -5,12 +5,14 @@ from ._dtypes import _all_dtypes, _result_type from dataclasses import dataclass from typing import TYPE_CHECKING, List, Tuple, Union + if TYPE_CHECKING: from ._typing import Dtype from collections.abc import Sequence import numpy as np + def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]: """ Array API compatible wrapper for :py:func:`np.broadcast_arrays `. @@ -18,7 +20,11 @@ def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]: See its docstring for more information. """ from ._array_object import Array - return [Array._new(array) for array in np.broadcast_arrays(*[a._array for a in arrays])] + + return [ + Array._new(array) for array in np.broadcast_arrays(*[a._array for a in arrays]) + ] + def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: """ @@ -27,8 +33,10 @@ def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: See its docstring for more information. """ from ._array_object import Array + return Array._new(np.broadcast_to(x._array, shape)) + def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: """ Array API compatible wrapper for :py:func:`np.can_cast `. @@ -36,10 +44,12 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: See its docstring for more information. """ from ._array_object import Array + if isinstance(from_, Array): from_ = from_._array return np.can_cast(from_, to) + # These are internal objects for the return types of finfo and iinfo, since # the NumPy versions contain extra data that isn't part of the spec. @dataclass @@ -55,12 +65,14 @@ class finfo_object: # smallest_normal: float + @dataclass class iinfo_object: bits: int max: int min: int + def finfo(type: Union[Dtype, Array], /) -> finfo_object: """ Array API compatible wrapper for :py:func:`np.finfo `. @@ -79,6 +91,7 @@ def finfo(type: Union[Dtype, Array], /) -> finfo_object: # float(fi.smallest_normal), ) + def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: """ Array API compatible wrapper for :py:func:`np.iinfo `. @@ -88,6 +101,7 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: ii = np.iinfo(type) return iinfo_object(ii.bits, ii.max, ii.min) + def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype: """ Array API compatible wrapper for :py:func:`np.result_type `. diff --git a/numpy/array_api/_dtypes.py b/numpy/array_api/_dtypes.py index 07be267da..476d619fe 100644 --- a/numpy/array_api/_dtypes.py +++ b/numpy/array_api/_dtypes.py @@ -2,34 +2,66 @@ import numpy as np # Note: we use dtype objects instead of dtype classes. The spec does not # require any behavior on dtypes other than equality. -int8 = np.dtype('int8') -int16 = np.dtype('int16') -int32 = np.dtype('int32') -int64 = np.dtype('int64') -uint8 = np.dtype('uint8') -uint16 = np.dtype('uint16') -uint32 = np.dtype('uint32') -uint64 = np.dtype('uint64') -float32 = np.dtype('float32') -float64 = np.dtype('float64') +int8 = np.dtype("int8") +int16 = np.dtype("int16") +int32 = np.dtype("int32") +int64 = np.dtype("int64") +uint8 = np.dtype("uint8") +uint16 = np.dtype("uint16") +uint32 = np.dtype("uint32") +uint64 = np.dtype("uint64") +float32 = np.dtype("float32") +float64 = np.dtype("float64") # Note: This name is changed -bool = np.dtype('bool') +bool = np.dtype("bool") -_all_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64, - float32, float64, bool) +_all_dtypes = ( + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + bool, +) _boolean_dtypes = (bool,) _floating_dtypes = (float32, float64) _integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) -_integer_or_boolean_dtypes = (bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64) -_numeric_dtypes = (float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64) +_integer_or_boolean_dtypes = ( + bool, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) +_numeric_dtypes = ( + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) _dtype_categories = { - 'all': _all_dtypes, - 'numeric': _numeric_dtypes, - 'integer': _integer_dtypes, - 'integer or boolean': _integer_or_boolean_dtypes, - 'boolean': _boolean_dtypes, - 'floating-point': _floating_dtypes, + "all": _all_dtypes, + "numeric": _numeric_dtypes, + "integer": _integer_dtypes, + "integer or boolean": _integer_or_boolean_dtypes, + "boolean": _boolean_dtypes, + "floating-point": _floating_dtypes, } @@ -104,6 +136,7 @@ _promotion_table = { (bool, bool): bool, } + def _result_type(type1, type2): if (type1, type2) in _promotion_table: return _promotion_table[type1, type2] diff --git a/numpy/array_api/_elementwise_functions.py b/numpy/array_api/_elementwise_functions.py index 7833ebe54..4408fe833 100644 --- a/numpy/array_api/_elementwise_functions.py +++ b/numpy/array_api/_elementwise_functions.py @@ -1,12 +1,18 @@ from __future__ import annotations -from ._dtypes import (_boolean_dtypes, _floating_dtypes, - _integer_dtypes, _integer_or_boolean_dtypes, - _numeric_dtypes, _result_type) +from ._dtypes import ( + _boolean_dtypes, + _floating_dtypes, + _integer_dtypes, + _integer_or_boolean_dtypes, + _numeric_dtypes, + _result_type, +) from ._array_object import Array import numpy as np + def abs(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.abs `. @@ -14,9 +20,10 @@ def abs(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in abs') + raise TypeError("Only numeric dtypes are allowed in abs") return Array._new(np.abs(x._array)) + # Note: the function name is different here def acos(x: Array, /) -> Array: """ @@ -25,9 +32,10 @@ def acos(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in acos') + raise TypeError("Only floating-point dtypes are allowed in acos") return Array._new(np.arccos(x._array)) + # Note: the function name is different here def acosh(x: Array, /) -> Array: """ @@ -36,9 +44,10 @@ def acosh(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in acosh') + raise TypeError("Only floating-point dtypes are allowed in acosh") return Array._new(np.arccosh(x._array)) + def add(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.add `. @@ -46,12 +55,13 @@ def add(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in add') + raise TypeError("Only numeric dtypes are allowed in add") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.add(x1._array, x2._array)) + # Note: the function name is different here def asin(x: Array, /) -> Array: """ @@ -60,9 +70,10 @@ def asin(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in asin') + raise TypeError("Only floating-point dtypes are allowed in asin") return Array._new(np.arcsin(x._array)) + # Note: the function name is different here def asinh(x: Array, /) -> Array: """ @@ -71,9 +82,10 @@ def asinh(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in asinh') + raise TypeError("Only floating-point dtypes are allowed in asinh") return Array._new(np.arcsinh(x._array)) + # Note: the function name is different here def atan(x: Array, /) -> Array: """ @@ -82,9 +94,10 @@ def atan(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in atan') + raise TypeError("Only floating-point dtypes are allowed in atan") return Array._new(np.arctan(x._array)) + # Note: the function name is different here def atan2(x1: Array, x2: Array, /) -> Array: """ @@ -93,12 +106,13 @@ def atan2(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in atan2') + raise TypeError("Only floating-point dtypes are allowed in atan2") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.arctan2(x1._array, x2._array)) + # Note: the function name is different here def atanh(x: Array, /) -> Array: """ @@ -107,22 +121,27 @@ def atanh(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in atanh') + raise TypeError("Only floating-point dtypes are allowed in atanh") return Array._new(np.arctanh(x._array)) + def bitwise_and(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.bitwise_and `. See its docstring for more information. """ - if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: - raise TypeError('Only integer or boolean dtypes are allowed in bitwise_and') + if ( + x1.dtype not in _integer_or_boolean_dtypes + or x2.dtype not in _integer_or_boolean_dtypes + ): + raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.bitwise_and(x1._array, x2._array)) + # Note: the function name is different here def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: """ @@ -131,15 +150,16 @@ def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: - raise TypeError('Only integer dtypes are allowed in bitwise_left_shift') + raise TypeError("Only integer dtypes are allowed in bitwise_left_shift") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) # Note: bitwise_left_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): - raise ValueError('bitwise_left_shift(x1, x2) is only defined for x2 >= 0') + raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") return Array._new(np.left_shift(x1._array, x2._array)) + # Note: the function name is different here def bitwise_invert(x: Array, /) -> Array: """ @@ -148,22 +168,27 @@ def bitwise_invert(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _integer_or_boolean_dtypes: - raise TypeError('Only integer or boolean dtypes are allowed in bitwise_invert') + raise TypeError("Only integer or boolean dtypes are allowed in bitwise_invert") return Array._new(np.invert(x._array)) + def bitwise_or(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.bitwise_or `. See its docstring for more information. """ - if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: - raise TypeError('Only integer or boolean dtypes are allowed in bitwise_or') + if ( + x1.dtype not in _integer_or_boolean_dtypes + or x2.dtype not in _integer_or_boolean_dtypes + ): + raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.bitwise_or(x1._array, x2._array)) + # Note: the function name is different here def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: """ @@ -172,28 +197,33 @@ def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: - raise TypeError('Only integer dtypes are allowed in bitwise_right_shift') + raise TypeError("Only integer dtypes are allowed in bitwise_right_shift") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) # Note: bitwise_right_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): - raise ValueError('bitwise_right_shift(x1, x2) is only defined for x2 >= 0') + raise ValueError("bitwise_right_shift(x1, x2) is only defined for x2 >= 0") return Array._new(np.right_shift(x1._array, x2._array)) + def bitwise_xor(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.bitwise_xor `. See its docstring for more information. """ - if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: - raise TypeError('Only integer or boolean dtypes are allowed in bitwise_xor') + if ( + x1.dtype not in _integer_or_boolean_dtypes + or x2.dtype not in _integer_or_boolean_dtypes + ): + raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.bitwise_xor(x1._array, x2._array)) + def ceil(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.ceil `. @@ -201,12 +231,13 @@ def ceil(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in ceil') + raise TypeError("Only numeric dtypes are allowed in ceil") if x.dtype in _integer_dtypes: # Note: The return dtype of ceil is the same as the input return x return Array._new(np.ceil(x._array)) + def cos(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.cos `. @@ -214,9 +245,10 @@ def cos(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in cos') + raise TypeError("Only floating-point dtypes are allowed in cos") return Array._new(np.cos(x._array)) + def cosh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.cosh `. @@ -224,9 +256,10 @@ def cosh(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in cosh') + raise TypeError("Only floating-point dtypes are allowed in cosh") return Array._new(np.cosh(x._array)) + def divide(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.divide `. @@ -234,12 +267,13 @@ def divide(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in divide') + raise TypeError("Only floating-point dtypes are allowed in divide") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.divide(x1._array, x2._array)) + def equal(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.equal `. @@ -251,6 +285,7 @@ def equal(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.equal(x1._array, x2._array)) + def exp(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.exp `. @@ -258,9 +293,10 @@ def exp(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in exp') + raise TypeError("Only floating-point dtypes are allowed in exp") return Array._new(np.exp(x._array)) + def expm1(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.expm1 `. @@ -268,9 +304,10 @@ def expm1(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in expm1') + raise TypeError("Only floating-point dtypes are allowed in expm1") return Array._new(np.expm1(x._array)) + def floor(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.floor `. @@ -278,12 +315,13 @@ def floor(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in floor') + raise TypeError("Only numeric dtypes are allowed in floor") if x.dtype in _integer_dtypes: # Note: The return dtype of floor is the same as the input return x return Array._new(np.floor(x._array)) + def floor_divide(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.floor_divide `. @@ -291,12 +329,13 @@ def floor_divide(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in floor_divide') + raise TypeError("Only numeric dtypes are allowed in floor_divide") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.floor_divide(x1._array, x2._array)) + def greater(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.greater `. @@ -304,12 +343,13 @@ def greater(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in greater') + raise TypeError("Only numeric dtypes are allowed in greater") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater(x1._array, x2._array)) + def greater_equal(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.greater_equal `. @@ -317,12 +357,13 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in greater_equal') + raise TypeError("Only numeric dtypes are allowed in greater_equal") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater_equal(x1._array, x2._array)) + def isfinite(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.isfinite `. @@ -330,9 +371,10 @@ def isfinite(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in isfinite') + raise TypeError("Only numeric dtypes are allowed in isfinite") return Array._new(np.isfinite(x._array)) + def isinf(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.isinf `. @@ -340,9 +382,10 @@ def isinf(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in isinf') + raise TypeError("Only numeric dtypes are allowed in isinf") return Array._new(np.isinf(x._array)) + def isnan(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.isnan `. @@ -350,9 +393,10 @@ def isnan(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in isnan') + raise TypeError("Only numeric dtypes are allowed in isnan") return Array._new(np.isnan(x._array)) + def less(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.less `. @@ -360,12 +404,13 @@ def less(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in less') + raise TypeError("Only numeric dtypes are allowed in less") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less(x1._array, x2._array)) + def less_equal(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.less_equal `. @@ -373,12 +418,13 @@ def less_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in less_equal') + raise TypeError("Only numeric dtypes are allowed in less_equal") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less_equal(x1._array, x2._array)) + def log(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.log `. @@ -386,9 +432,10 @@ def log(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in log') + raise TypeError("Only floating-point dtypes are allowed in log") return Array._new(np.log(x._array)) + def log1p(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.log1p `. @@ -396,9 +443,10 @@ def log1p(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in log1p') + raise TypeError("Only floating-point dtypes are allowed in log1p") return Array._new(np.log1p(x._array)) + def log2(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.log2 `. @@ -406,9 +454,10 @@ def log2(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in log2') + raise TypeError("Only floating-point dtypes are allowed in log2") return Array._new(np.log2(x._array)) + def log10(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.log10 `. @@ -416,9 +465,10 @@ def log10(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in log10') + raise TypeError("Only floating-point dtypes are allowed in log10") return Array._new(np.log10(x._array)) + def logaddexp(x1: Array, x2: Array) -> Array: """ Array API compatible wrapper for :py:func:`np.logaddexp `. @@ -426,12 +476,13 @@ def logaddexp(x1: Array, x2: Array) -> Array: See its docstring for more information. """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in logaddexp') + raise TypeError("Only floating-point dtypes are allowed in logaddexp") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logaddexp(x1._array, x2._array)) + def logical_and(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.logical_and `. @@ -439,12 +490,13 @@ def logical_and(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError('Only boolean dtypes are allowed in logical_and') + raise TypeError("Only boolean dtypes are allowed in logical_and") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logical_and(x1._array, x2._array)) + def logical_not(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.logical_not `. @@ -452,9 +504,10 @@ def logical_not(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _boolean_dtypes: - raise TypeError('Only boolean dtypes are allowed in logical_not') + raise TypeError("Only boolean dtypes are allowed in logical_not") return Array._new(np.logical_not(x._array)) + def logical_or(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.logical_or `. @@ -462,12 +515,13 @@ def logical_or(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError('Only boolean dtypes are allowed in logical_or') + raise TypeError("Only boolean dtypes are allowed in logical_or") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logical_or(x1._array, x2._array)) + def logical_xor(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.logical_xor `. @@ -475,12 +529,13 @@ def logical_xor(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError('Only boolean dtypes are allowed in logical_xor') + raise TypeError("Only boolean dtypes are allowed in logical_xor") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logical_xor(x1._array, x2._array)) + def multiply(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.multiply `. @@ -488,12 +543,13 @@ def multiply(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in multiply') + raise TypeError("Only numeric dtypes are allowed in multiply") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.multiply(x1._array, x2._array)) + def negative(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.negative `. @@ -501,9 +557,10 @@ def negative(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in negative') + raise TypeError("Only numeric dtypes are allowed in negative") return Array._new(np.negative(x._array)) + def not_equal(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.not_equal `. @@ -515,6 +572,7 @@ def not_equal(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.not_equal(x1._array, x2._array)) + def positive(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.positive `. @@ -522,9 +580,10 @@ def positive(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in positive') + raise TypeError("Only numeric dtypes are allowed in positive") return Array._new(np.positive(x._array)) + # Note: the function name is different here def pow(x1: Array, x2: Array, /) -> Array: """ @@ -533,12 +592,13 @@ def pow(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in pow') + raise TypeError("Only floating-point dtypes are allowed in pow") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.power(x1._array, x2._array)) + def remainder(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.remainder `. @@ -546,12 +606,13 @@ def remainder(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in remainder') + raise TypeError("Only numeric dtypes are allowed in remainder") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.remainder(x1._array, x2._array)) + def round(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.round `. @@ -559,9 +620,10 @@ def round(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in round') + raise TypeError("Only numeric dtypes are allowed in round") return Array._new(np.round(x._array)) + def sign(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.sign `. @@ -569,9 +631,10 @@ def sign(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in sign') + raise TypeError("Only numeric dtypes are allowed in sign") return Array._new(np.sign(x._array)) + def sin(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.sin `. @@ -579,9 +642,10 @@ def sin(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in sin') + raise TypeError("Only floating-point dtypes are allowed in sin") return Array._new(np.sin(x._array)) + def sinh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.sinh `. @@ -589,9 +653,10 @@ def sinh(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in sinh') + raise TypeError("Only floating-point dtypes are allowed in sinh") return Array._new(np.sinh(x._array)) + def square(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.square `. @@ -599,9 +664,10 @@ def square(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in square') + raise TypeError("Only numeric dtypes are allowed in square") return Array._new(np.square(x._array)) + def sqrt(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.sqrt `. @@ -609,9 +675,10 @@ def sqrt(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in sqrt') + raise TypeError("Only floating-point dtypes are allowed in sqrt") return Array._new(np.sqrt(x._array)) + def subtract(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.subtract `. @@ -619,12 +686,13 @@ def subtract(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in subtract') + raise TypeError("Only numeric dtypes are allowed in subtract") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.subtract(x1._array, x2._array)) + def tan(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.tan `. @@ -632,9 +700,10 @@ def tan(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in tan') + raise TypeError("Only floating-point dtypes are allowed in tan") return Array._new(np.tan(x._array)) + def tanh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.tanh `. @@ -642,9 +711,10 @@ def tanh(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in tanh') + raise TypeError("Only floating-point dtypes are allowed in tanh") return Array._new(np.tanh(x._array)) + def trunc(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.trunc `. @@ -652,7 +722,7 @@ def trunc(x: Array, /) -> Array: See its docstring for more information. """ if x.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in trunc') + raise TypeError("Only numeric dtypes are allowed in trunc") if x.dtype in _integer_dtypes: # Note: The return dtype of trunc is the same as the input return x diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py index f13f9c541..089081725 100644 --- a/numpy/array_api/_linear_algebra_functions.py +++ b/numpy/array_api/_linear_algebra_functions.py @@ -17,6 +17,7 @@ import numpy as np # """ # return np.einsum() + def matmul(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.matmul `. @@ -26,23 +27,31 @@ def matmul(x1: Array, x2: Array, /) -> Array: # Note: the restriction to numeric dtypes only is different from # np.matmul. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in matmul') + raise TypeError("Only numeric dtypes are allowed in matmul") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) return Array._new(np.matmul(x1._array, x2._array)) + # Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. -def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: +def tensordot( + x1: Array, + x2: Array, + /, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, +) -> Array: # Note: the restriction to numeric dtypes only is different from # np.tensordot. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in tensordot') + raise TypeError("Only numeric dtypes are allowed in tensordot") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) + def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.transpose `. @@ -51,6 +60,7 @@ def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array: """ return Array._new(np.transpose(x._array, axes=axes)) + # Note: vecdot is not in NumPy def vecdot(x1: Array, x2: Array, /, *, axis: Optional[int] = None) -> Array: if axis is None: diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py index 33f5d5a28..c11866261 100644 --- a/numpy/array_api/_manipulation_functions.py +++ b/numpy/array_api/_manipulation_functions.py @@ -8,7 +8,9 @@ from typing import List, Optional, Tuple, Union import numpy as np # Note: the function name is different here -def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array: +def concat( + arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0 +) -> Array: """ Array API compatible wrapper for :py:func:`np.concatenate `. @@ -20,6 +22,7 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[i arrays = tuple(a._array for a in arrays) return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype)) + def expand_dims(x: Array, /, *, axis: int) -> Array: """ Array API compatible wrapper for :py:func:`np.expand_dims `. @@ -28,6 +31,7 @@ def expand_dims(x: Array, /, *, axis: int) -> Array: """ return Array._new(np.expand_dims(x._array, axis)) + def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.flip `. @@ -36,6 +40,7 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> """ return Array._new(np.flip(x._array, axis=axis)) + def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array: """ Array API compatible wrapper for :py:func:`np.reshape `. @@ -44,7 +49,14 @@ def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array: """ return Array._new(np.reshape(x._array, shape)) -def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + +def roll( + x: Array, + /, + shift: Union[int, Tuple[int, ...]], + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, +) -> Array: """ Array API compatible wrapper for :py:func:`np.roll `. @@ -52,6 +64,7 @@ def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Unio """ return Array._new(np.roll(x._array, shift, axis=axis)) + def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: """ Array API compatible wrapper for :py:func:`np.squeeze `. @@ -60,6 +73,7 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: """ return Array._new(np.squeeze(x._array, axis=axis)) + def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array: """ Array API compatible wrapper for :py:func:`np.stack `. diff --git a/numpy/array_api/_searching_functions.py b/numpy/array_api/_searching_functions.py index 9dcc76b2d..3dcef61c3 100644 --- a/numpy/array_api/_searching_functions.py +++ b/numpy/array_api/_searching_functions.py @@ -7,6 +7,7 @@ from typing import Optional, Tuple import numpy as np + def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: """ Array API compatible wrapper for :py:func:`np.argmax `. @@ -15,6 +16,7 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - """ return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims))) + def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: """ Array API compatible wrapper for :py:func:`np.argmin `. @@ -23,6 +25,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - """ return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) + def nonzero(x: Array, /) -> Tuple[Array, ...]: """ Array API compatible wrapper for :py:func:`np.nonzero `. @@ -31,6 +34,7 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: """ return tuple(Array._new(i) for i in np.nonzero(x._array)) + def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.where `. diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py index acd59f597..357f238f5 100644 --- a/numpy/array_api/_set_functions.py +++ b/numpy/array_api/_set_functions.py @@ -6,14 +6,26 @@ from typing import Tuple, Union import numpy as np -def unique(x: Array, /, *, return_counts: bool = False, return_index: bool = False, return_inverse: bool = False) -> Union[Array, Tuple[Array, ...]]: + +def unique( + x: Array, + /, + *, + return_counts: bool = False, + return_index: bool = False, + return_inverse: bool = False, +) -> Union[Array, Tuple[Array, ...]]: """ Array API compatible wrapper for :py:func:`np.unique `. See its docstring for more information. """ - res = np.unique(x._array, return_counts=return_counts, - return_index=return_index, return_inverse=return_inverse) + res = np.unique( + x._array, + return_counts=return_counts, + return_index=return_index, + return_inverse=return_inverse, + ) if isinstance(res, tuple): return tuple(Array._new(i) for i in res) return Array._new(res) diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py index a125e0718..9cd49786c 100644 --- a/numpy/array_api/_sorting_functions.py +++ b/numpy/array_api/_sorting_functions.py @@ -4,27 +4,33 @@ from ._array_object import Array import numpy as np -def argsort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array: + +def argsort( + x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True +) -> Array: """ Array API compatible wrapper for :py:func:`np.argsort `. See its docstring for more information. """ # Note: this keyword argument is different, and the default is different. - kind = 'stable' if stable else 'quicksort' + kind = "stable" if stable else "quicksort" res = np.argsort(x._array, axis=axis, kind=kind) if descending: res = np.flip(res, axis=axis) return Array._new(res) -def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array: + +def sort( + x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True +) -> Array: """ Array API compatible wrapper for :py:func:`np.sort `. See its docstring for more information. """ # Note: this keyword argument is different, and the default is different. - kind = 'stable' if stable else 'quicksort' + kind = "stable" if stable else "quicksort" res = np.sort(x._array, axis=axis, kind=kind) if descending: res = np.flip(res, axis=axis) diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py index a606203bc..63790b447 100644 --- a/numpy/array_api/_statistical_functions.py +++ b/numpy/array_api/_statistical_functions.py @@ -6,25 +6,76 @@ from typing import Optional, Tuple, Union import numpy as np -def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def max( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: return Array._new(np.max(x._array, axis=axis, keepdims=keepdims)) -def mean(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def mean( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims)) -def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def min( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: return Array._new(np.min(x._array, axis=axis, keepdims=keepdims)) -def prod(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def prod( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims)) -def std(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> Array: + +def std( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, + keepdims: bool = False, +) -> Array: # Note: the keyword argument correction is different here return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims)) -def sum(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def sum( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims)) -def var(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> Array: + +def var( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, + keepdims: bool = False, +) -> Array: # Note: the keyword argument correction is different here return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims)) diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 4ff718205..d530a91ae 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -6,21 +6,39 @@ annotations in the function signatures. The functions in the module are only valid for inputs that match the given type annotations. """ -__all__ = ['Array', 'Device', 'Dtype', 'SupportsDLPack', - 'SupportsBufferProtocol', 'PyCapsule'] +__all__ = [ + "Array", + "Device", + "Dtype", + "SupportsDLPack", + "SupportsBufferProtocol", + "PyCapsule", +] from typing import Any, Sequence, Type, Union -from . import (Array, int8, int16, int32, int64, uint8, uint16, uint32, - uint64, float32, float64) +from . import ( + Array, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, +) # This should really be recursive, but that isn't supported yet. See the # similar comment in numpy/typing/_array_like.py NestedSequence = Sequence[Sequence[Any]] Device = Any -Dtype = Type[Union[[int8, int16, int32, int64, uint8, uint16, - uint32, uint64, float32, float64]]] +Dtype = Type[ + Union[[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64]] +] SupportsDLPack = Any SupportsBufferProtocol = Any PyCapsule = Any diff --git a/numpy/array_api/_utility_functions.py b/numpy/array_api/_utility_functions.py index f243bfe68..5ecb4bd9f 100644 --- a/numpy/array_api/_utility_functions.py +++ b/numpy/array_api/_utility_functions.py @@ -6,7 +6,14 @@ from typing import Optional, Tuple, Union import numpy as np -def all(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def all( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: """ Array API compatible wrapper for :py:func:`np.all `. @@ -14,7 +21,14 @@ def all(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep """ return Array._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims))) -def any(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: + +def any( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: """ Array API compatible wrapper for :py:func:`np.any `. diff --git a/numpy/array_api/setup.py b/numpy/array_api/setup.py index da2350c8f..c8bc29102 100644 --- a/numpy/array_api/setup.py +++ b/numpy/array_api/setup.py @@ -1,10 +1,12 @@ -def configuration(parent_package='', top_path=None): +def configuration(parent_package="", top_path=None): from numpy.distutils.misc_util import Configuration - config = Configuration('array_api', parent_package, top_path) - config.add_subpackage('tests') + + config = Configuration("array_api", parent_package, top_path) + config.add_subpackage("tests") return config -if __name__ == '__main__': +if __name__ == "__main__": from numpy.distutils.core import setup + setup(configuration=configuration) diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index 22078bbee..088e09b9f 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -2,9 +2,20 @@ from numpy.testing import assert_raises import numpy as np from .. import ones, asarray, result_type -from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, - _integer_dtypes, _integer_or_boolean_dtypes, - _numeric_dtypes, int8, int16, int32, int64, uint64) +from .._dtypes import ( + _all_dtypes, + _boolean_dtypes, + _floating_dtypes, + _integer_dtypes, + _integer_or_boolean_dtypes, + _numeric_dtypes, + int8, + int16, + int32, + int64, + uint64, +) + def test_validate_index(): # The indexing tests in the official array API test suite test that the @@ -61,28 +72,29 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[None, ...]) assert_raises(IndexError, lambda: a[..., None]) + def test_operators(): # For every operator, we test that it works for the required type # combinations and raises TypeError otherwise - binary_op_dtypes ={ - '__add__': 'numeric', - '__and__': 'integer_or_boolean', - '__eq__': 'all', - '__floordiv__': 'numeric', - '__ge__': 'numeric', - '__gt__': 'numeric', - '__le__': 'numeric', - '__lshift__': 'integer', - '__lt__': 'numeric', - '__mod__': 'numeric', - '__mul__': 'numeric', - '__ne__': 'all', - '__or__': 'integer_or_boolean', - '__pow__': 'floating', - '__rshift__': 'integer', - '__sub__': 'numeric', - '__truediv__': 'floating', - '__xor__': 'integer_or_boolean', + binary_op_dtypes = { + "__add__": "numeric", + "__and__": "integer_or_boolean", + "__eq__": "all", + "__floordiv__": "numeric", + "__ge__": "numeric", + "__gt__": "numeric", + "__le__": "numeric", + "__lshift__": "integer", + "__lt__": "numeric", + "__mod__": "numeric", + "__mul__": "numeric", + "__ne__": "all", + "__or__": "integer_or_boolean", + "__pow__": "floating", + "__rshift__": "integer", + "__sub__": "numeric", + "__truediv__": "floating", + "__xor__": "integer_or_boolean", } # Recompute each time because of in-place ops @@ -92,15 +104,15 @@ def test_operators(): for d in _boolean_dtypes: yield asarray(False, dtype=d) for d in _floating_dtypes: - yield asarray(1., dtype=d) + yield asarray(1.0, dtype=d) for op, dtypes in binary_op_dtypes.items(): ops = [op] - if op not in ['__eq__', '__ne__', '__le__', '__ge__', '__lt__', '__gt__']: - rop = '__r' + op[2:] - iop = '__i' + op[2:] + if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]: + rop = "__r" + op[2:] + iop = "__i" + op[2:] ops += [rop, iop] - for s in [1, 1., False]: + for s in [1, 1.0, False]: for _op in ops: for a in _array_vals(): # Test array op scalar. From the spec, the following combinations @@ -149,7 +161,10 @@ def test_operators(): ): assert_raises(TypeError, lambda: getattr(x, _op)(y)) # Ensure in-place operators only promote to the same dtype as the left operand. - elif _op.startswith('__i') and result_type(x.dtype, y.dtype) != x.dtype: + elif ( + _op.startswith("__i") + and result_type(x.dtype, y.dtype) != x.dtype + ): assert_raises(TypeError, lambda: getattr(x, _op)(y)) # Ensure only those dtypes that are required for every operator are allowed. elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes @@ -165,17 +180,20 @@ def test_operators(): else: assert_raises(TypeError, lambda: getattr(x, _op)(y)) - unary_op_dtypes ={ - '__abs__': 'numeric', - '__invert__': 'integer_or_boolean', - '__neg__': 'numeric', - '__pos__': 'numeric', + unary_op_dtypes = { + "__abs__": "numeric", + "__invert__": "integer_or_boolean", + "__neg__": "numeric", + "__pos__": "numeric", } for op, dtypes in unary_op_dtypes.items(): for a in _array_vals(): - if (dtypes == "numeric" and a.dtype in _numeric_dtypes - or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes - ): + if ( + dtypes == "numeric" + and a.dtype in _numeric_dtypes + or dtypes == "integer_or_boolean" + and a.dtype in _integer_or_boolean_dtypes + ): # Only test for no error getattr(a, op)() else: @@ -192,8 +210,8 @@ def test_operators(): yield ones((4, 4), dtype=d) # Scalars always error - for _op in ['__matmul__', '__rmatmul__', '__imatmul__']: - for s in [1, 1., False]: + for _op in ["__matmul__", "__rmatmul__", "__imatmul__"]: + for s in [1, 1.0, False]: for a in _matmul_array_vals(): if (type(s) in [float, int] and a.dtype in _floating_dtypes or type(s) == int and a.dtype in _integer_dtypes): @@ -235,16 +253,17 @@ def test_operators(): else: x.__imatmul__(y) + def test_python_scalar_construtors(): a = asarray(False) b = asarray(0) - c = asarray(0.) + c = asarray(0.0) assert bool(a) == bool(b) == bool(c) == False assert int(a) == int(b) == int(c) == 0 - assert float(a) == float(b) == float(c) == 0. + assert float(a) == float(b) == float(c) == 0.0 # bool/int/float should only be allowed on 0-D arrays. assert_raises(TypeError, lambda: bool(asarray([False]))) assert_raises(TypeError, lambda: int(asarray([0]))) - assert_raises(TypeError, lambda: float(asarray([0.]))) + assert_raises(TypeError, lambda: float(asarray([0.0]))) diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index 654f1d9b3..3cb8865cd 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -2,26 +2,53 @@ from numpy.testing import assert_raises import numpy as np from .. import all -from .._creation_functions import (asarray, arange, empty, empty_like, eye, from_dlpack, full, full_like, linspace, meshgrid, ones, ones_like, zeros, zeros_like) +from .._creation_functions import ( + asarray, + arange, + empty, + empty_like, + eye, + from_dlpack, + full, + full_like, + linspace, + meshgrid, + ones, + ones_like, + zeros, + zeros_like, +) from .._array_object import Array -from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, - _integer_dtypes, _integer_or_boolean_dtypes, - _numeric_dtypes, int8, int16, int32, int64, uint64) +from .._dtypes import ( + _all_dtypes, + _boolean_dtypes, + _floating_dtypes, + _integer_dtypes, + _integer_or_boolean_dtypes, + _numeric_dtypes, + int8, + int16, + int32, + int64, + uint64, +) + def test_asarray_errors(): # Test various protections against incorrect usage assert_raises(TypeError, lambda: Array([1])) - assert_raises(TypeError, lambda: asarray(['a'])) - assert_raises(ValueError, lambda: asarray([1.], dtype=np.float16)) + assert_raises(TypeError, lambda: asarray(["a"])) + assert_raises(ValueError, lambda: asarray([1.0], dtype=np.float16)) assert_raises(OverflowError, lambda: asarray(2**100)) # Preferably this would be OverflowError # assert_raises(OverflowError, lambda: asarray([2**100])) assert_raises(TypeError, lambda: asarray([2**100])) - asarray([1], device='cpu') # Doesn't error - assert_raises(ValueError, lambda: asarray([1], device='gpu')) + asarray([1], device="cpu") # Doesn't error + assert_raises(ValueError, lambda: asarray([1], device="gpu")) assert_raises(ValueError, lambda: asarray([1], dtype=int)) - assert_raises(ValueError, lambda: asarray([1], dtype='i')) + assert_raises(ValueError, lambda: asarray([1], dtype="i")) + def test_asarray_copy(): a = asarray([1]) @@ -36,68 +63,79 @@ def test_asarray_copy(): # assert all(b[0] == 0) assert_raises(NotImplementedError, lambda: asarray(a, copy=False)) + def test_arange_errors(): - arange(1, device='cpu') # Doesn't error - assert_raises(ValueError, lambda: arange(1, device='gpu')) + arange(1, device="cpu") # Doesn't error + assert_raises(ValueError, lambda: arange(1, device="gpu")) assert_raises(ValueError, lambda: arange(1, dtype=int)) - assert_raises(ValueError, lambda: arange(1, dtype='i')) + assert_raises(ValueError, lambda: arange(1, dtype="i")) + def test_empty_errors(): - empty((1,), device='cpu') # Doesn't error - assert_raises(ValueError, lambda: empty((1,), device='gpu')) + empty((1,), device="cpu") # Doesn't error + assert_raises(ValueError, lambda: empty((1,), device="gpu")) assert_raises(ValueError, lambda: empty((1,), dtype=int)) - assert_raises(ValueError, lambda: empty((1,), dtype='i')) + assert_raises(ValueError, lambda: empty((1,), dtype="i")) + def test_empty_like_errors(): - empty_like(asarray(1), device='cpu') # Doesn't error - assert_raises(ValueError, lambda: empty_like(asarray(1), device='gpu')) + empty_like(asarray(1), device="cpu") # Doesn't error + assert_raises(ValueError, lambda: empty_like(asarray(1), device="gpu")) assert_raises(ValueError, lambda: empty_like(asarray(1), dtype=int)) - assert_raises(ValueError, lambda: empty_like(asarray(1), dtype='i')) + assert_raises(ValueError, lambda: empty_like(asarray(1), dtype="i")) + def test_eye_errors(): - eye(1, device='cpu') # Doesn't error - assert_raises(ValueError, lambda: eye(1, device='gpu')) + eye(1, device="cpu") # Doesn't error + assert_raises(ValueError, lambda: eye(1, device="gpu")) assert_raises(ValueError, lambda: eye(1, dtype=int)) - assert_raises(ValueError, lambda: eye(1, dtype='i')) + assert_raises(ValueError, lambda: eye(1, dtype="i")) + def test_full_errors(): - full((1,), 0, device='cpu') # Doesn't error - assert_raises(ValueError, lambda: full((1,), 0, device='gpu')) + full((1,), 0, device="cpu") # Doesn't error + assert_raises(ValueError, lambda: full((1,), 0, device="gpu")) assert_raises(ValueError, lambda: full((1,), 0, dtype=int)) - assert_raises(ValueError, lambda: full((1,), 0, dtype='i')) + assert_raises(ValueError, lambda: full((1,), 0, dtype="i")) + def test_full_like_errors(): - full_like(asarray(1), 0, device='cpu') # Doesn't error - assert_raises(ValueError, lambda: full_like(asarray(1), 0, device='gpu')) + full_like(asarray(1), 0, device="cpu") # Doesn't error + assert_raises(ValueError, lambda: full_like(asarray(1), 0, device="gpu")) assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype=int)) - assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype='i')) + assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype="i")) + def test_linspace_errors(): - linspace(0, 1, 10, device='cpu') # Doesn't error - assert_raises(ValueError, lambda: linspace(0, 1, 10, device='gpu')) + linspace(0, 1, 10, device="cpu") # Doesn't error + assert_raises(ValueError, lambda: linspace(0, 1, 10, device="gpu")) assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype=float)) - assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype='f')) + assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype="f")) + def test_ones_errors(): - ones((1,), device='cpu') # Doesn't error - assert_raises(ValueError, lambda: ones((1,), device='gpu')) + ones((1,), device="cpu") # Doesn't error + assert_raises(ValueError, lambda: ones((1,), device="gpu")) assert_raises(ValueError, lambda: ones((1,), dtype=int)) - assert_raises(ValueError, lambda: ones((1,), dtype='i')) + assert_raises(ValueError, lambda: ones((1,), dtype="i")) + def test_ones_like_errors(): - ones_like(asarray(1), device='cpu') # Doesn't error - assert_raises(ValueError, lambda: ones_like(asarray(1), device='gpu')) + ones_like(asarray(1), device="cpu") # Doesn't error + assert_raises(ValueError, lambda: ones_like(asarray(1), device="gpu")) assert_raises(ValueError, lambda: ones_like(asarray(1), dtype=int)) - assert_raises(ValueError, lambda: ones_like(asarray(1), dtype='i')) + assert_raises(ValueError, lambda: ones_like(asarray(1), dtype="i")) + def test_zeros_errors(): - zeros((1,), device='cpu') # Doesn't error - assert_raises(ValueError, lambda: zeros((1,), device='gpu')) + zeros((1,), device="cpu") # Doesn't error + assert_raises(ValueError, lambda: zeros((1,), device="gpu")) assert_raises(ValueError, lambda: zeros((1,), dtype=int)) - assert_raises(ValueError, lambda: zeros((1,), dtype='i')) + assert_raises(ValueError, lambda: zeros((1,), dtype="i")) + def test_zeros_like_errors(): - zeros_like(asarray(1), device='cpu') # Doesn't error - assert_raises(ValueError, lambda: zeros_like(asarray(1), device='gpu')) + zeros_like(asarray(1), device="cpu") # Doesn't error + assert_raises(ValueError, lambda: zeros_like(asarray(1), device="gpu")) assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int)) - assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype='i')) + assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype="i")) diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py index ec76cb7a7..a9274aec9 100644 --- a/numpy/array_api/tests/test_elementwise_functions.py +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -4,74 +4,80 @@ from numpy.testing import assert_raises from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift -from .._dtypes import (_dtype_categories, _boolean_dtypes, _floating_dtypes, - _integer_dtypes) +from .._dtypes import ( + _dtype_categories, + _boolean_dtypes, + _floating_dtypes, + _integer_dtypes, +) + def nargs(func): return len(getfullargspec(func).args) + def test_function_types(): # Test that every function accepts only the required input types. We only # test the negative cases here (error). The positive cases are tested in # the array API test suite. elementwise_function_input_types = { - 'abs': 'numeric', - 'acos': 'floating-point', - 'acosh': 'floating-point', - 'add': 'numeric', - 'asin': 'floating-point', - 'asinh': 'floating-point', - 'atan': 'floating-point', - 'atan2': 'floating-point', - 'atanh': 'floating-point', - 'bitwise_and': 'integer or boolean', - 'bitwise_invert': 'integer or boolean', - 'bitwise_left_shift': 'integer', - 'bitwise_or': 'integer or boolean', - 'bitwise_right_shift': 'integer', - 'bitwise_xor': 'integer or boolean', - 'ceil': 'numeric', - 'cos': 'floating-point', - 'cosh': 'floating-point', - 'divide': 'floating-point', - 'equal': 'all', - 'exp': 'floating-point', - 'expm1': 'floating-point', - 'floor': 'numeric', - 'floor_divide': 'numeric', - 'greater': 'numeric', - 'greater_equal': 'numeric', - 'isfinite': 'numeric', - 'isinf': 'numeric', - 'isnan': 'numeric', - 'less': 'numeric', - 'less_equal': 'numeric', - 'log': 'floating-point', - 'logaddexp': 'floating-point', - 'log10': 'floating-point', - 'log1p': 'floating-point', - 'log2': 'floating-point', - 'logical_and': 'boolean', - 'logical_not': 'boolean', - 'logical_or': 'boolean', - 'logical_xor': 'boolean', - 'multiply': 'numeric', - 'negative': 'numeric', - 'not_equal': 'all', - 'positive': 'numeric', - 'pow': 'floating-point', - 'remainder': 'numeric', - 'round': 'numeric', - 'sign': 'numeric', - 'sin': 'floating-point', - 'sinh': 'floating-point', - 'sqrt': 'floating-point', - 'square': 'numeric', - 'subtract': 'numeric', - 'tan': 'floating-point', - 'tanh': 'floating-point', - 'trunc': 'numeric', + "abs": "numeric", + "acos": "floating-point", + "acosh": "floating-point", + "add": "numeric", + "asin": "floating-point", + "asinh": "floating-point", + "atan": "floating-point", + "atan2": "floating-point", + "atanh": "floating-point", + "bitwise_and": "integer or boolean", + "bitwise_invert": "integer or boolean", + "bitwise_left_shift": "integer", + "bitwise_or": "integer or boolean", + "bitwise_right_shift": "integer", + "bitwise_xor": "integer or boolean", + "ceil": "numeric", + "cos": "floating-point", + "cosh": "floating-point", + "divide": "floating-point", + "equal": "all", + "exp": "floating-point", + "expm1": "floating-point", + "floor": "numeric", + "floor_divide": "numeric", + "greater": "numeric", + "greater_equal": "numeric", + "isfinite": "numeric", + "isinf": "numeric", + "isnan": "numeric", + "less": "numeric", + "less_equal": "numeric", + "log": "floating-point", + "logaddexp": "floating-point", + "log10": "floating-point", + "log1p": "floating-point", + "log2": "floating-point", + "logical_and": "boolean", + "logical_not": "boolean", + "logical_or": "boolean", + "logical_xor": "boolean", + "multiply": "numeric", + "negative": "numeric", + "not_equal": "all", + "positive": "numeric", + "pow": "floating-point", + "remainder": "numeric", + "round": "numeric", + "sign": "numeric", + "sin": "floating-point", + "sinh": "floating-point", + "sqrt": "floating-point", + "square": "numeric", + "subtract": "numeric", + "tan": "floating-point", + "tanh": "floating-point", + "trunc": "numeric", } def _array_vals(): @@ -80,7 +86,7 @@ def test_function_types(): for d in _boolean_dtypes: yield asarray(False, dtype=d) for d in _floating_dtypes: - yield asarray(1., dtype=d) + yield asarray(1.0, dtype=d) for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): @@ -94,7 +100,12 @@ def test_function_types(): if x.dtype not in dtypes: assert_raises(TypeError, lambda: func(x)) + def test_bitwise_shift_error(): # bitwise shift functions should raise when the second argument is negative - assert_raises(ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1]))) - assert_raises(ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))) + assert_raises( + ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1])) + ) + assert_raises( + ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1])) + ) -- cgit v1.2.1 From d5956c170b07cf26b05c921d810dc387d7e819da Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 12 Aug 2021 15:22:58 -0600 Subject: Fix the return annotation for numpy.array_api.Array.__setitem__ --- numpy/array_api/_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 0f511a577..2d746e78b 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -647,7 +647,7 @@ class Array: ], value: Union[int, float, bool, Array], /, - ) -> Array: + ) -> None: """ Performs the operation __setitem__. """ -- cgit v1.2.1 From 90537b5dac1d0c569baa794967b919ae4f6fdcca Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 12 Aug 2021 15:34:45 -0600 Subject: Add smallest_normal to the array API finfo This was blocked on #18536, which has been merged. --- numpy/array_api/_data_type_functions.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index e6121a8a4..fd92aa250 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -60,10 +60,7 @@ class finfo_object: eps: float max: float min: float - # Note: smallest_normal is part of the array API spec, but cannot be used - # until https://github.com/numpy/numpy/pull/18536 is merged. - - # smallest_normal: float + smallest_normal: float @dataclass @@ -87,8 +84,7 @@ def finfo(type: Union[Dtype, Array], /) -> finfo_object: float(fi.eps), float(fi.max), float(fi.min), - # TODO: Uncomment this when #18536 is merged. - # float(fi.smallest_normal), + float(fi.smallest_normal), ) -- cgit v1.2.1 From 9978cc5a1861533abf7c6ed7b0f4e1d732171094 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 23 Aug 2021 15:21:24 -0600 Subject: Remove Python 3.7 checks from the array API code NumPy has dropped Python 3.7, so these are no longer necessary (and they didn't completely work anyway). --- numpy/array_api/__init__.py | 5 ----- 1 file changed, 5 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 53c1f3850..1e1ff242f 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -120,11 +120,6 @@ Still TODO in this module are: import sys -# numpy.array_api is 3.8+ because it makes extensive use of positional-only -# arguments. -if sys.version_info < (3, 8): - raise ImportError("The numpy.array_api submodule requires Python 3.8 or greater.") - import warnings warnings.warn( -- cgit v1.2.1 From 06ec0ec8dadf9e0e9f7518c9817b95f14df9d7be Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 23 Aug 2021 17:46:13 -0600 Subject: Remove an unused import --- numpy/array_api/__init__.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 1e1ff242f..790157504 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -118,8 +118,6 @@ Still TODO in this module are: """ -import sys - import warnings warnings.warn( -- cgit v1.2.1 From e4a11cb29af12b0e67fda3f924585b23cee0b567 Mon Sep 17 00:00:00 2001 From: Sista Seetaram Date: Fri, 3 Sep 2021 00:25:57 +0530 Subject: fixed unhashable instance and potential exception as listed in LGTM#19077 --- numpy/array_api/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index d530a91ae..8b7832d74 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -37,7 +37,7 @@ NestedSequence = Sequence[Sequence[Any]] Device = Any Dtype = Type[ - Union[[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64]] + Union[(int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64)] ] SupportsDLPack = Any SupportsBufferProtocol = Any -- cgit v1.2.1 From e4c589054b342dc750cdd09686c0c01ab762679c Mon Sep 17 00:00:00 2001 From: Sista Seetaram Date: Fri, 3 Sep 2021 06:09:38 +0530 Subject: fix unhashable instance and potential exception identified by LGTM --- numpy/array_api/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 8b7832d74..831a108bc 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -37,7 +37,7 @@ NestedSequence = Sequence[Sequence[Any]] Device = Any Dtype = Type[ - Union[(int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64)] + Union[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64] ] SupportsDLPack = Any SupportsBufferProtocol = Any -- cgit v1.2.1 From 45dbdc9d8fd3fa7fbabbddf690f6c892cc24aa1d Mon Sep 17 00:00:00 2001 From: czgdp1807 Date: Fri, 3 Sep 2021 15:10:23 +0530 Subject: CopyMode added to np.array_api --- numpy/array_api/_creation_functions.py | 2 +- numpy/array_api/tests/test_creation_functions.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index e9c01e7e6..c15d54db1 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -43,7 +43,7 @@ def asarray( *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - copy: Optional[bool] = None, + copy: Optional[Union[bool | np._CopyMode]] = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.asarray `. diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index 3cb8865cd..b0c99cd45 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -56,12 +56,17 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 1) assert all(a[0] == 0) - # Once copy=False is implemented, replace this with - # a = asarray([1]) - # b = asarray(a, copy=False) - # a[0] = 0 - # assert all(b[0] == 0) + a = asarray([1]) + b = asarray(a, copy=np._CopyMode.ALWAYS) + a[0] = 0 + assert all(b[0] == 1) + assert all(a[0] == 0) + a = asarray([1]) + b = asarray(a, copy=np._CopyMode.NEVER) + a[0] = 0 + assert all(b[0] == 0) assert_raises(NotImplementedError, lambda: asarray(a, copy=False)) + assert_raises(NotImplementedError, lambda: asarray(a, copy=np._CopyMode.NEVER)) def test_arange_errors(): -- cgit v1.2.1 From a39312cf4eca9dc9573737a01f16fee24efb6c0a Mon Sep 17 00:00:00 2001 From: czgdp1807 Date: Fri, 3 Sep 2021 15:15:18 +0530 Subject: fixed linting issues --- numpy/array_api/tests/test_creation_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index b0c99cd45..9d2909e91 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -66,7 +66,8 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 0) assert_raises(NotImplementedError, lambda: asarray(a, copy=False)) - assert_raises(NotImplementedError, lambda: asarray(a, copy=np._CopyMode.NEVER)) + assert_raises(NotImplementedError, lambda: asarray(a, + copy=np._CopyMode.NEVER)) def test_arange_errors(): -- cgit v1.2.1 From c2acd5b25a04783fbbe3ba32426039e4dbe9207e Mon Sep 17 00:00:00 2001 From: czgdp1807 Date: Fri, 3 Sep 2021 15:17:58 +0530 Subject: fixed linting issues --- numpy/array_api/tests/test_creation_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index 9d2909e91..7182209dc 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -66,8 +66,8 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 0) assert_raises(NotImplementedError, lambda: asarray(a, copy=False)) - assert_raises(NotImplementedError, lambda: asarray(a, - copy=np._CopyMode.NEVER)) + assert_raises(NotImplementedError, + lambda: asarray(a, copy=np._CopyMode.NEVER)) def test_arange_errors(): -- cgit v1.2.1 From 56647dd47345a7fd24b4ee8d9d52025fcdc3b9ae Mon Sep 17 00:00:00 2001 From: czgdp1807 Date: Sat, 4 Sep 2021 22:33:52 +0530 Subject: Addressed reviews --- numpy/array_api/_creation_functions.py | 6 +++--- numpy/array_api/tests/test_creation_functions.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index c15d54db1..2d6cf4414 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -43,7 +43,7 @@ def asarray( *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - copy: Optional[Union[bool | np._CopyMode]] = None, + copy: Optional[Union[bool, np._CopyMode]] = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.asarray `. @@ -57,11 +57,11 @@ def asarray( _check_valid_dtype(dtype) if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") - if copy is False: + if copy in (False, np._CopyMode.IF_NEEDED): # Note: copy=False is not yet implemented in np.asarray raise NotImplementedError("copy=False is not yet implemented") if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype): - if copy is True: + if copy in (True, np._CopyMode.ALWAYS): return Array._new(np.array(obj._array, copy=True, dtype=dtype)) return obj if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)): diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index 7182209dc..2ee23a47b 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -67,7 +67,7 @@ def test_asarray_copy(): assert all(b[0] == 0) assert_raises(NotImplementedError, lambda: asarray(a, copy=False)) assert_raises(NotImplementedError, - lambda: asarray(a, copy=np._CopyMode.NEVER)) + lambda: asarray(a, copy=np._CopyMode.IF_NEEDED)) def test_arange_errors(): -- cgit v1.2.1 From 2d112a98ed7597c4120b31908384ae09b0304659 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Sat, 25 Sep 2021 17:34:22 -0500 Subject: ENH: Updates to numpy.array_api (#19937) * Add __index__ to array_api and update __int__, __bool__, and __float__ The spec specifies that they should only work on arrays with corresponding dtypes. __index__ is new in the spec since the initial PR, and works identically to np.array.__index__. * Add the to_device method to the array_api This method is new since #18585. It does nothing in NumPy since NumPy does not support non-CPU devices. * Update transpose methods in the array_api transpose() was renamed to matrix_transpose() and now operates on stacks of matrices. A function to permute dimensions will be added once it is finalized in the spec. The attribute mT was added and the T attribute was updated to only operate on 2-dimensional arrays as per the spec. * Restrict input dtypes in the array API statistical functions * Add the dtype parameter to the array API sum() and prod() * Add the function permute_dims() to the array_api namespace permute_dims() is the replacement for transpose(), which was split into permute_dims() and matrix_transpose(). * Add tril and triu to the array API namespace * Fix the array_api Array.__repr__ to indent the array properly * Make the Device type in the array_api just accept the string "cpu" --- numpy/array_api/__init__.py | 11 ++++++--- numpy/array_api/_array_object.py | 34 ++++++++++++++++++++++++++- numpy/array_api/_creation_functions.py | 30 +++++++++++++++++++++++- numpy/array_api/_linear_algebra_functions.py | 13 +++++------ numpy/array_api/_manipulation_functions.py | 11 +++++++++ numpy/array_api/_statistical_functions.py | 35 +++++++++++++++++++++++++++- numpy/array_api/_typing.py | 4 ++-- numpy/array_api/tests/test_array_object.py | 30 +++++++++++++++++++----- 8 files changed, 147 insertions(+), 21 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 790157504..d8b29057e 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -143,6 +143,8 @@ from ._creation_functions import ( meshgrid, ones, ones_like, + tril, + triu, zeros, zeros_like, ) @@ -160,6 +162,8 @@ __all__ += [ "meshgrid", "ones", "ones_like", + "tril", + "triu", "zeros", "zeros_like", ] @@ -333,21 +337,22 @@ __all__ += [ # from ._linear_algebra_functions import einsum # __all__ += ['einsum'] -from ._linear_algebra_functions import matmul, tensordot, transpose, vecdot +from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot -__all__ += ["matmul", "tensordot", "transpose", "vecdot"] +__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"] from ._manipulation_functions import ( concat, expand_dims, flip, + permute_dims, reshape, roll, squeeze, stack, ) -__all__ += ["concat", "expand_dims", "flip", "reshape", "roll", "squeeze", "stack"] +__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"] from ._searching_functions import argmax, argmin, nonzero, where diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 2d746e78b..830319e8c 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -99,7 +99,10 @@ class Array: """ Performs the operation __repr__. """ - return f"Array({np.array2string(self._array, separator=', ')}, dtype={self.dtype.name})" + prefix = "Array(" + suffix = f", dtype={self.dtype.name})" + mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix) + return prefix + mid + suffix # These are various helper functions to make the array behavior match the # spec in places where it either deviates from or is more strict than @@ -391,6 +394,8 @@ class Array: # Note: This is an error here. if self._array.ndim != 0: raise TypeError("bool is only allowed on arrays with 0 dimensions") + if self.dtype not in _boolean_dtypes: + raise ValueError("bool is only allowed on boolean arrays") res = self._array.__bool__() return res @@ -429,6 +434,8 @@ class Array: # Note: This is an error here. if self._array.ndim != 0: raise TypeError("float is only allowed on arrays with 0 dimensions") + if self.dtype not in _floating_dtypes: + raise ValueError("float is only allowed on floating-point arrays") res = self._array.__float__() return res @@ -488,9 +495,18 @@ class Array: # Note: This is an error here. if self._array.ndim != 0: raise TypeError("int is only allowed on arrays with 0 dimensions") + if self.dtype not in _integer_dtypes: + raise ValueError("int is only allowed on integer arrays") res = self._array.__int__() return res + def __index__(self: Array, /) -> int: + """ + Performs the operation __index__. + """ + res = self._array.__index__() + return res + def __invert__(self: Array, /) -> Array: """ Performs the operation __invert__. @@ -979,6 +995,11 @@ class Array: res = self._array.__rxor__(other._array) return self.__class__._new(res) + def to_device(self: Array, device: Device, /) -> Array: + if device == 'cpu': + return self + raise ValueError(f"Unsupported device {device!r}") + @property def dtype(self) -> Dtype: """ @@ -992,6 +1013,12 @@ class Array: def device(self) -> Device: return "cpu" + # Note: mT is new in array API spec (see matrix_transpose) + @property + def mT(self) -> Array: + from ._linear_algebra_functions import matrix_transpose + return matrix_transpose(self) + @property def ndim(self) -> int: """ @@ -1026,4 +1053,9 @@ class Array: See its docstring for more information. """ + # Note: T only works on 2-dimensional arrays. See the corresponding + # note in the specification: + # https://data-apis.org/array-api/latest/API_specification/array_object.html#t + if self.ndim != 2: + raise ValueError("x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions.") return self._array.T diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index e9c01e7e6..9f8136267 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -22,7 +22,7 @@ def _check_valid_dtype(dtype): # Note: Only spelling dtypes as the dtype objects is supported. # We use this instead of "dtype in _all_dtypes" because the dtype objects - # define equality with the sorts of things we want to disallw. + # define equality with the sorts of things we want to disallow. for d in (None,) + _all_dtypes: if dtype is d: return @@ -281,6 +281,34 @@ def ones_like( return Array._new(np.ones_like(x._array, dtype=dtype)) +def tril(x: Array, /, *, k: int = 0) -> Array: + """ + Array API compatible wrapper for :py:func:`np.tril `. + + See its docstring for more information. + """ + from ._array_object import Array + + if x.ndim < 2: + # Note: Unlike np.tril, x must be at least 2-D + raise ValueError("x must be at least 2-dimensional for tril") + return Array._new(np.tril(x._array, k=k)) + + +def triu(x: Array, /, *, k: int = 0) -> Array: + """ + Array API compatible wrapper for :py:func:`np.triu `. + + See its docstring for more information. + """ + from ._array_object import Array + + if x.ndim < 2: + # Note: Unlike np.triu, x must be at least 2-D + raise ValueError("x must be at least 2-dimensional for triu") + return Array._new(np.triu(x._array, k=k)) + + def zeros( shape: Union[int, Tuple[int, ...]], *, diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py index 089081725..7a6c9846c 100644 --- a/numpy/array_api/_linear_algebra_functions.py +++ b/numpy/array_api/_linear_algebra_functions.py @@ -52,13 +52,12 @@ def tensordot( return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) -def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array: - """ - Array API compatible wrapper for :py:func:`np.transpose `. - - See its docstring for more information. - """ - return Array._new(np.transpose(x._array, axes=axes)) +# Note: this function is new in the array API spec. Unlike transpose, it only +# transposes the last two axes. +def matrix_transpose(x: Array, /) -> Array: + if x.ndim < 2: + raise ValueError("x must be at least 2-dimensional for matrix_transpose") + return Array._new(np.swapaxes(x._array, -1, -2)) # Note: vecdot is not in NumPy diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py index c11866261..4f2114ff5 100644 --- a/numpy/array_api/_manipulation_functions.py +++ b/numpy/array_api/_manipulation_functions.py @@ -41,6 +41,17 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> return Array._new(np.flip(x._array, axis=axis)) +# Note: The function name is different here (see also matrix_transpose). +# Unlike transpose(), the axes argument is required. +def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: + """ + Array API compatible wrapper for :py:func:`np.transpose `. + + See its docstring for more information. + """ + return Array._new(np.transpose(x._array, axes)) + + def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array: """ Array API compatible wrapper for :py:func:`np.reshape `. diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py index 63790b447..c5abf9468 100644 --- a/numpy/array_api/_statistical_functions.py +++ b/numpy/array_api/_statistical_functions.py @@ -1,8 +1,17 @@ from __future__ import annotations +from ._dtypes import ( + _floating_dtypes, + _numeric_dtypes, +) from ._array_object import Array +from ._creation_functions import asarray +from ._dtypes import float32, float64 -from typing import Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union + +if TYPE_CHECKING: + from ._typing import Dtype import numpy as np @@ -14,6 +23,8 @@ def max( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in max") return Array._new(np.max(x._array, axis=axis, keepdims=keepdims)) @@ -24,6 +35,8 @@ def mean( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> Array: + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in mean") return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims)) @@ -34,6 +47,8 @@ def min( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in min") return Array._new(np.min(x._array, axis=axis, keepdims=keepdims)) @@ -42,8 +57,15 @@ def prod( /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, keepdims: bool = False, ) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in prod") + # Note: sum() and prod() always upcast float32 to float64 for dtype=None + # We need to do so here before computing the product to avoid overflow + if dtype is None and x.dtype == float32: + x = asarray(x, dtype=float64) return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims)) @@ -56,6 +78,8 @@ def std( keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in std") return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims)) @@ -64,8 +88,15 @@ def sum( /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, keepdims: bool = False, ) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in sum") + # Note: sum() and prod() always upcast float32 to float64 for dtype=None + # We need to do so here before summing to avoid overflow + if dtype is None and x.dtype == float32: + x = asarray(x, dtype=float64) return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims)) @@ -78,4 +109,6 @@ def var( keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in var") return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims)) diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 831a108bc..5f937a56c 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -15,7 +15,7 @@ __all__ = [ "PyCapsule", ] -from typing import Any, Sequence, Type, Union +from typing import Any, Literal, Sequence, Type, Union from . import ( Array, @@ -35,7 +35,7 @@ from . import ( # similar comment in numpy/typing/_array_like.py NestedSequence = Sequence[Sequence[Any]] -Device = Any +Device = Literal["cpu"] Dtype = Type[ Union[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64] ] diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index 088e09b9f..7959f92b4 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -1,3 +1,5 @@ +import operator + from numpy.testing import assert_raises import numpy as np @@ -255,15 +257,31 @@ def test_operators(): def test_python_scalar_construtors(): - a = asarray(False) - b = asarray(0) - c = asarray(0.0) + b = asarray(False) + i = asarray(0) + f = asarray(0.0) - assert bool(a) == bool(b) == bool(c) == False - assert int(a) == int(b) == int(c) == 0 - assert float(a) == float(b) == float(c) == 0.0 + assert bool(b) == False + assert int(i) == 0 + assert float(f) == 0.0 + assert operator.index(i) == 0 # bool/int/float should only be allowed on 0-D arrays. assert_raises(TypeError, lambda: bool(asarray([False]))) assert_raises(TypeError, lambda: int(asarray([0]))) assert_raises(TypeError, lambda: float(asarray([0.0]))) + assert_raises(TypeError, lambda: operator.index(asarray([0]))) + + # bool/int/float should only be allowed on arrays of the corresponding + # dtype + assert_raises(ValueError, lambda: bool(i)) + assert_raises(ValueError, lambda: bool(f)) + + assert_raises(ValueError, lambda: int(b)) + assert_raises(ValueError, lambda: int(f)) + + assert_raises(ValueError, lambda: float(b)) + assert_raises(ValueError, lambda: float(i)) + + assert_raises(TypeError, lambda: operator.index(b)) + assert_raises(TypeError, lambda: operator.index(f)) -- cgit v1.2.1 From a5beccfa3574f4fcb1b6030737b728e65803791f Mon Sep 17 00:00:00 2001 From: Bas van Beek <43369155+BvB93@users.noreply.github.com> Date: Mon, 27 Sep 2021 14:43:16 +0200 Subject: MAINT: Fix invalid parameter types used in `Dtype` --- numpy/array_api/_typing.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 5f937a56c..f66279fbc 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -15,10 +15,12 @@ __all__ = [ "PyCapsule", ] -from typing import Any, Literal, Sequence, Type, Union +import sys +from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING -from . import ( - Array, +from . import Array +from numpy import ( + dtype, int8, int16, int32, @@ -36,9 +38,22 @@ from . import ( NestedSequence = Sequence[Sequence[Any]] Device = Literal["cpu"] -Dtype = Type[ - Union[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64] -] +if TYPE_CHECKING or sys.version_info >= (3, 9): + Dtype = dtype[Union[ + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + ]] +else: + Dtype = dtype + SupportsDLPack = Any SupportsBufferProtocol = Any PyCapsule = Any -- cgit v1.2.1 From 4f7e991960c24fc9548f8f3d6d5f8967c2ece84a Mon Sep 17 00:00:00 2001 From: Bas van Beek <43369155+BvB93@users.noreply.github.com> Date: Mon, 27 Sep 2021 14:44:02 +0200 Subject: MAINT: Add a missing subscription slot to `NestedSequence` --- numpy/array_api/_typing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index f66279fbc..4785f5fe3 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -16,7 +16,7 @@ __all__ = [ ] import sys -from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING +from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar from . import Array from numpy import ( @@ -35,7 +35,8 @@ from numpy import ( # This should really be recursive, but that isn't supported yet. See the # similar comment in numpy/typing/_array_like.py -NestedSequence = Sequence[Sequence[Any]] +_T = TypeVar("_T") +NestedSequence = Sequence[Sequence[_T]] Device = Literal["cpu"] if TYPE_CHECKING or sys.version_info >= (3, 9): -- cgit v1.2.1 From dc69553ef179b6713d85ec747e6d030dd7087f05 Mon Sep 17 00:00:00 2001 From: Bas van Beek <43369155+BvB93@users.noreply.github.com> Date: Mon, 27 Sep 2021 14:44:46 +0200 Subject: MAINT: Remove the `Sequence` encapsulation from variadic arguments --- numpy/array_api/_creation_functions.py | 2 +- numpy/array_api/_data_type_functions.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index 9f8136267..8a23cd88e 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -232,7 +232,7 @@ def linspace( return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)) -def meshgrid(*arrays: Sequence[Array], indexing: str = "xy") -> List[Array, ...]: +def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: """ Array API compatible wrapper for :py:func:`np.meshgrid `. diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index fd92aa250..7ccbe9469 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: import numpy as np -def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]: +def broadcast_arrays(*arrays: Array) -> List[Array]: """ Array API compatible wrapper for :py:func:`np.broadcast_arrays `. @@ -98,7 +98,7 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: return iinfo_object(ii.bits, ii.max, ii.min) -def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype: +def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype: """ Array API compatible wrapper for :py:func:`np.result_type `. -- cgit v1.2.1 From 4f8f50d5b5e992afaf0ef08773bd88d696683bd3 Mon Sep 17 00:00:00 2001 From: Bas van Beek <43369155+BvB93@users.noreply.github.com> Date: Mon, 27 Sep 2021 14:47:22 +0200 Subject: MAINT: Import `Array` from the `_array_object` namespace Changed as `Array` does not live in the main `np.array_api` namespace --- numpy/array_api/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 4785f5fe3..519e8463c 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -18,7 +18,7 @@ __all__ = [ import sys from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar -from . import Array +from ._array_object import Array from numpy import ( dtype, int8, -- cgit v1.2.1 From 9c356d55d78620532ed92987af1721d7ca4034ec Mon Sep 17 00:00:00 2001 From: Bas van Beek <43369155+BvB93@users.noreply.github.com> Date: Wed, 29 Sep 2021 22:52:14 +0200 Subject: Disallow `k=None` for the `eye` function --- numpy/array_api/_creation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index 8a23cd88e..e36807468 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -134,7 +134,7 @@ def eye( n_cols: Optional[int] = None, /, *, - k: Optional[int] = 0, + k: int = 0, dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> Array: -- cgit v1.2.1 From 4d23ebeb068c8d6ba6edfc11d32ab2af8bb89c74 Mon Sep 17 00:00:00 2001 From: Alessia Marcolini <98marcolini@gmail.com> Date: Fri, 8 Oct 2021 09:49:11 +0000 Subject: MAINT: remove unused imports --- numpy/array_api/tests/test_creation_functions.py | 15 --------------- 1 file changed, 15 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index 3cb8865cd..7b633eaf1 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -8,30 +8,15 @@ from .._creation_functions import ( empty, empty_like, eye, - from_dlpack, full, full_like, linspace, - meshgrid, ones, ones_like, zeros, zeros_like, ) from .._array_object import Array -from .._dtypes import ( - _all_dtypes, - _boolean_dtypes, - _floating_dtypes, - _integer_dtypes, - _integer_or_boolean_dtypes, - _numeric_dtypes, - int8, - int16, - int32, - int64, - uint64, -) def test_asarray_errors(): -- cgit v1.2.1 From f931a434839222bb00282a432d6d6a0c2c52eb7d Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Mon, 18 Oct 2021 18:10:47 +0200 Subject: ENH: Replace `NestedSequence` with a proper nested sequence protocol --- numpy/array_api/_typing.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 519e8463c..5e980b16f 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -6,6 +6,8 @@ annotations in the function signatures. The functions in the module are only valid for inputs that match the given type annotations. """ +from __future__ import annotations + __all__ = [ "Array", "Device", @@ -16,7 +18,16 @@ __all__ = [ ] import sys -from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar +from typing import ( + Any, + Literal, + Sequence, + Type, + Union, + TYPE_CHECKING, + TypeVar, + Protocol, +) from ._array_object import Array from numpy import ( @@ -33,10 +44,11 @@ from numpy import ( float64, ) -# This should really be recursive, but that isn't supported yet. See the -# similar comment in numpy/typing/_array_like.py -_T = TypeVar("_T") -NestedSequence = Sequence[Sequence[_T]] +_T_co = TypeVar("_T_co", covariant=True) + +class NestedSequence(Protocol[_T_co]): + def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... + def __len__(self, /) -> int: ... Device = Literal["cpu"] if TYPE_CHECKING or sys.version_info >= (3, 9): -- cgit v1.2.1 From 3952e8f1390629078fdb229236b3b1ce40140c32 Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Mon, 18 Oct 2021 18:11:06 +0200 Subject: ENH: Change `SupportsDLPack` into a protocol --- numpy/array_api/_typing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 5e980b16f..dfa87b358 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -67,6 +67,8 @@ if TYPE_CHECKING or sys.version_info >= (3, 9): else: Dtype = dtype -SupportsDLPack = Any SupportsBufferProtocol = Any PyCapsule = Any + +class SupportsDLPack(Protocol): + def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ... -- cgit v1.2.1 From d74bea12d19dd92c9cf07cac35e94d45fb331832 Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Mon, 18 Oct 2021 18:13:39 +0200 Subject: MAINT: Replace the `__array_namespace__` return type with `Any` Replace `object` as it cannot be used for expressing the objects in the array namespace. --- numpy/array_api/_array_object.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 830319e8c..ef66c5efd 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -29,7 +29,7 @@ from ._dtypes import ( _dtype_categories, ) -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union, Any if TYPE_CHECKING: from ._typing import PyCapsule, Device, Dtype @@ -382,7 +382,7 @@ class Array: def __array_namespace__( self: Array, /, *, api_version: Optional[str] = None - ) -> object: + ) -> Any: if api_version is not None and not api_version.startswith("2021."): raise ValueError(f"Unrecognized array API version: {api_version!r}") return array_api -- cgit v1.2.1 From ba8fcbe2ba0a26cd52dfa9bf40dd2e945e5b298f Mon Sep 17 00:00:00 2001 From: mattip Date: Tue, 2 Nov 2021 11:30:32 +0200 Subject: change from_dlpack to _dlpack, remove unused header --- numpy/array_api/__init__.py | 4 ++-- numpy/array_api/_creation_functions.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index d8b29057e..89f5e9cba 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -136,7 +136,7 @@ from ._creation_functions import ( empty, empty_like, eye, - from_dlpack, + _from_dlpack, full, full_like, linspace, @@ -155,7 +155,7 @@ __all__ += [ "empty", "empty_like", "eye", - "from_dlpack", + "_from_dlpack", "full", "full_like", "linspace", diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index e36807468..c3644ac2c 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -151,7 +151,7 @@ def eye( return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype)) -def from_dlpack(x: object, /) -> Array: +def _from_dlpack(x: object, /) -> Array: # Note: dlpack support is not yet implemented on Array raise NotImplementedError("DLPack support is not yet implemented") -- cgit v1.2.1 From bc087bb2b3104547b097e03e669fb5fa77b16f05 Mon Sep 17 00:00:00 2001 From: Amit Kumar Date: Wed, 3 Nov 2021 11:09:28 +0000 Subject: TST: Add a test for device property in `array_api` namespace (#20271) --- numpy/array_api/tests/test_array_object.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index 7959f92b4..fb42cf621 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -285,3 +285,14 @@ def test_python_scalar_construtors(): assert_raises(TypeError, lambda: operator.index(b)) assert_raises(TypeError, lambda: operator.index(f)) + + +def test_device_property(): + a = ones((3, 4)) + assert a.device == 'cpu' + + assert np.array_equal(a.to_device('cpu'), a) + assert_raises(ValueError, lambda: a.to_device('gpu')) + + assert np.array_equal(asarray(a, device='cpu'), a) + assert_raises(ValueError, lambda: asarray(a, device='gpu')) -- cgit v1.2.1 From ff2e2a1e7eea29d925063b13922e096d14331222 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 12 Nov 2021 08:07:23 -0700 Subject: MAINT: A few updates to the array_api (#20066) * Allow casting in the array API asarray() * Restrict multidimensional indexing in the array API namespace The spec has recently been updated to only require multiaxis (i.e., tuple) indices in the case where every axis is indexed, meaning there are either as many indices as axes or the index has an ellipsis. * Fix type promotion for numpy.array_api.where where does value-based promotion for 0-dimensional arrays, so we use the same trick as in the Array operators to avoid this. * Print empty array_api arrays using empty() Printing behavior isn't required by the spec. This is just to make things easier to understand, especially with the array API test suite. * Fix an incorrect slice bounds guard in the array API * Disallow multiple different dtypes in the input to np.array_api.meshgrid * Remove DLPack support from numpy.array_api.asarray() from_dlpack() should be used to create arrays using DLPack. * Remove __len__ from the array API array object * Add astype() to numpy.array_api * Update the unique_* functions in numpy.array_api unique() in the array API was replaced with three separate functions, unique_all(), unique_inverse(), and unique_values(), in order to avoid polymorphic return types. Additionally, it should be noted that these functions to not currently conform to the spec with respect to NaN behavior. The spec requires multiple NaNs to be returned, but np.unique() returns a single NaN. Since this is currently an open issue in NumPy to possibly revert, I have not yet worked around this. See https://github.com/numpy/numpy/issues/20326. * Add the stream argument to the array API to_device method This does nothing in NumPy, and is just present so that the signature is valid according to the spec. * Use the NamedTuple classes for the type signatures * Add unique_counts to the array API namespace * Remove some unused imports * Update the array_api indexing restrictions The "multiaxis indexing must index every axis explicitly or use an ellipsis" was supposed to include any type of index, not just tuple indices. * Use a simpler type annotation for the array API to_device method * Fix a test failure in the array_api submodule The array_api cannot use the NumPy testing functions because array_api arrays do not mix with NumPy arrays, and also NumPy testing functions may use APIs that aren't supported in the array API. * Add dlpack support to the array_api submodule --- numpy/array_api/__init__.py | 10 +-- numpy/array_api/_array_object.py | 48 ++++++++----- numpy/array_api/_creation_functions.py | 19 +++-- numpy/array_api/_data_type_functions.py | 7 ++ numpy/array_api/_searching_functions.py | 1 + numpy/array_api/_set_functions.py | 89 ++++++++++++++++++++---- numpy/array_api/tests/test_array_object.py | 21 +++--- numpy/array_api/tests/test_creation_functions.py | 10 +++ 8 files changed, 156 insertions(+), 49 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 89f5e9cba..36e3f3ed5 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -136,7 +136,7 @@ from ._creation_functions import ( empty, empty_like, eye, - _from_dlpack, + from_dlpack, full, full_like, linspace, @@ -155,7 +155,7 @@ __all__ += [ "empty", "empty_like", "eye", - "_from_dlpack", + "from_dlpack", "full", "full_like", "linspace", @@ -169,6 +169,7 @@ __all__ += [ ] from ._data_type_functions import ( + astype, broadcast_arrays, broadcast_to, can_cast, @@ -178,6 +179,7 @@ from ._data_type_functions import ( ) __all__ += [ + "astype", "broadcast_arrays", "broadcast_to", "can_cast", @@ -358,9 +360,9 @@ from ._searching_functions import argmax, argmin, nonzero, where __all__ += ["argmax", "argmin", "nonzero", "where"] -from ._set_functions import unique +from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values -__all__ += ["unique"] +__all__ += ["unique_all", "unique_counts", "unique_inverse", "unique_values"] from ._sorting_functions import argsort, sort diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index ef66c5efd..dc74bb8c5 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -32,7 +32,7 @@ from ._dtypes import ( from typing import TYPE_CHECKING, Optional, Tuple, Union, Any if TYPE_CHECKING: - from ._typing import PyCapsule, Device, Dtype + from ._typing import Any, PyCapsule, Device, Dtype import numpy as np @@ -99,9 +99,13 @@ class Array: """ Performs the operation __repr__. """ - prefix = "Array(" suffix = f", dtype={self.dtype.name})" - mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix) + if 0 in self.shape: + prefix = "empty(" + mid = str(self.shape) + else: + prefix = "Array(" + mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix) return prefix + mid + suffix # These are various helper functions to make the array behavior match the @@ -244,6 +248,10 @@ class Array: The following cases are allowed by NumPy, but not specified by the array API specification: + - Indices to not include an implicit ellipsis at the end. That is, + every axis of an array must be explicitly indexed or an ellipsis + included. + - The start and stop of a slice may not be out of bounds. In particular, for a slice ``i:j:k`` on an axis of size ``n``, only the following are allowed: @@ -270,6 +278,10 @@ class Array: return key if shape == (): return key + if len(shape) > 1: + raise IndexError( + "Multidimensional arrays must include an index for every axis or use an ellipsis" + ) size = shape[0] # Ensure invalid slice entries are passed through. if key.start is not None: @@ -277,7 +289,7 @@ class Array: operator.index(key.start) except TypeError: return key - if not (-size <= key.start <= max(0, size - 1)): + if not (-size <= key.start <= size): raise IndexError( "Slices with out-of-bounds start are not allowed in the array API namespace" ) @@ -322,6 +334,10 @@ class Array: zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1]) ): Array._validate_index(idx, (size,)) + if n_ellipsis == 0 and len(key) < len(shape): + raise IndexError( + "Multidimensional arrays must include an index for every axis or use an ellipsis" + ) return key elif isinstance(key, bool): return key @@ -339,7 +355,12 @@ class Array: "newaxis indices are not allowed in the array API namespace" ) try: - return operator.index(key) + key = operator.index(key) + if shape is not None and len(shape) > 1: + raise IndexError( + "Multidimensional arrays must include an index for every axis or use an ellipsis" + ) + return key except TypeError: # Note: This also omits boolean arrays that are not already in # Array() form, like a list of booleans. @@ -403,16 +424,14 @@ class Array: """ Performs the operation __dlpack__. """ - res = self._array.__dlpack__(stream=stream) - return self.__class__._new(res) + return self._array.__dlpack__(stream=stream) def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: """ Performs the operation __dlpack_device__. """ # Note: device support is required for this - res = self._array.__dlpack_device__() - return self.__class__._new(res) + return self._array.__dlpack_device__() def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ @@ -527,13 +546,6 @@ class Array: res = self._array.__le__(other._array) return self.__class__._new(res) - # Note: __len__ may end up being removed from the array API spec. - def __len__(self, /) -> int: - """ - Performs the operation __len__. - """ - return self._array.__len__() - def __lshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __lshift__. @@ -995,7 +1007,9 @@ class Array: res = self._array.__rxor__(other._array) return self.__class__._new(res) - def to_device(self: Array, device: Device, /) -> Array: + def to_device(self: Array, device: Device, /, stream: None = None) -> Array: + if stream is not None: + raise ValueError("The stream argument to to_device() is not supported") if device == 'cpu': return self raise ValueError(f"Unsupported device {device!r}") diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index c3644ac2c..23beec444 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: Device, Dtype, NestedSequence, - SupportsDLPack, SupportsBufferProtocol, ) from collections.abc import Sequence @@ -36,7 +35,6 @@ def asarray( int, float, NestedSequence[bool | int | float], - SupportsDLPack, SupportsBufferProtocol, ], /, @@ -60,7 +58,9 @@ def asarray( if copy is False: # Note: copy=False is not yet implemented in np.asarray raise NotImplementedError("copy=False is not yet implemented") - if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype): + if isinstance(obj, Array): + if dtype is not None and obj.dtype != dtype: + copy = True if copy is True: return Array._new(np.array(obj._array, copy=True, dtype=dtype)) return obj @@ -151,9 +151,10 @@ def eye( return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype)) -def _from_dlpack(x: object, /) -> Array: - # Note: dlpack support is not yet implemented on Array - raise NotImplementedError("DLPack support is not yet implemented") +def from_dlpack(x: object, /) -> Array: + from ._array_object import Array + + return Array._new(np._from_dlpack(x)) def full( @@ -240,6 +241,12 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: """ from ._array_object import Array + # Note: unlike np.meshgrid, only inputs with all the same dtype are + # allowed + + if len({a.dtype for a in arrays}) > 1: + raise ValueError("meshgrid inputs must all have the same dtype") + return [ Array._new(array) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing) diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index 7ccbe9469..e4d6db61b 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -13,6 +13,13 @@ if TYPE_CHECKING: import numpy as np +# Note: astype is a function, not an array method as in NumPy. +def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array: + if not copy and dtype == x.dtype: + return x + return Array._new(x._array.astype(dtype=dtype, copy=copy)) + + def broadcast_arrays(*arrays: Array) -> List[Array]: """ Array API compatible wrapper for :py:func:`np.broadcast_arrays `. diff --git a/numpy/array_api/_searching_functions.py b/numpy/array_api/_searching_functions.py index 3dcef61c3..40f5a4d2e 100644 --- a/numpy/array_api/_searching_functions.py +++ b/numpy/array_api/_searching_functions.py @@ -43,4 +43,5 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.where(condition._array, x1._array, x2._array)) diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py index 357f238f5..05ee7e555 100644 --- a/numpy/array_api/_set_functions.py +++ b/numpy/array_api/_set_functions.py @@ -2,19 +2,82 @@ from __future__ import annotations from ._array_object import Array -from typing import Tuple, Union +from typing import NamedTuple import numpy as np +# Note: np.unique() is split into four functions in the array API: +# unique_all, unique_counts, unique_inverse, and unique_values (this is done +# to remove polymorphic return types). -def unique( - x: Array, - /, - *, - return_counts: bool = False, - return_index: bool = False, - return_inverse: bool = False, -) -> Union[Array, Tuple[Array, ...]]: +# Note: The various unique() functions are supposed to return multiple NaNs. +# This does not match the NumPy behavior, however, this is currently left as a +# TODO in this implementation as this behavior may be reverted in np.unique(). +# See https://github.com/numpy/numpy/issues/20326. + +# Note: The functions here return a namedtuple (np.unique() returns a normal +# tuple). + +class UniqueAllResult(NamedTuple): + values: Array + indices: Array + inverse_indices: Array + counts: Array + + +class UniqueCountsResult(NamedTuple): + values: Array + counts: Array + + +class UniqueInverseResult(NamedTuple): + values: Array + inverse_indices: Array + + +def unique_all(x: Array, /) -> UniqueAllResult: + """ + Array API compatible wrapper for :py:func:`np.unique `. + + See its docstring for more information. + """ + res = np.unique( + x._array, + return_counts=True, + return_index=True, + return_inverse=True, + ) + + return UniqueAllResult(*[Array._new(i) for i in res]) + + +def unique_counts(x: Array, /) -> UniqueCountsResult: + res = np.unique( + x._array, + return_counts=True, + return_index=False, + return_inverse=False, + ) + + return UniqueCountsResult(*[Array._new(i) for i in res]) + + +def unique_inverse(x: Array, /) -> UniqueInverseResult: + """ + Array API compatible wrapper for :py:func:`np.unique `. + + See its docstring for more information. + """ + res = np.unique( + x._array, + return_counts=False, + return_index=False, + return_inverse=True, + ) + return UniqueInverseResult(*[Array._new(i) for i in res]) + + +def unique_values(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.unique `. @@ -22,10 +85,8 @@ def unique( """ res = np.unique( x._array, - return_counts=return_counts, - return_index=return_index, - return_inverse=return_inverse, + return_counts=False, + return_index=False, + return_inverse=False, ) - if isinstance(res, tuple): - return tuple(Array._new(i) for i in res) return Array._new(res) diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index fb42cf621..12479d765 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -3,7 +3,7 @@ import operator from numpy.testing import assert_raises import numpy as np -from .. import ones, asarray, result_type +from .. import ones, asarray, result_type, all, equal from .._dtypes import ( _all_dtypes, _boolean_dtypes, @@ -39,18 +39,18 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[:-4]) assert_raises(IndexError, lambda: a[:3:-1]) assert_raises(IndexError, lambda: a[:-5:-1]) - assert_raises(IndexError, lambda: a[3:]) + assert_raises(IndexError, lambda: a[4:]) assert_raises(IndexError, lambda: a[-4:]) - assert_raises(IndexError, lambda: a[3::-1]) + assert_raises(IndexError, lambda: a[4::-1]) assert_raises(IndexError, lambda: a[-4::-1]) assert_raises(IndexError, lambda: a[...,:5]) assert_raises(IndexError, lambda: a[...,:-5]) - assert_raises(IndexError, lambda: a[...,:4:-1]) + assert_raises(IndexError, lambda: a[...,:5:-1]) assert_raises(IndexError, lambda: a[...,:-6:-1]) - assert_raises(IndexError, lambda: a[...,4:]) + assert_raises(IndexError, lambda: a[...,5:]) assert_raises(IndexError, lambda: a[...,-5:]) - assert_raises(IndexError, lambda: a[...,4::-1]) + assert_raises(IndexError, lambda: a[...,5::-1]) assert_raises(IndexError, lambda: a[...,-5::-1]) # Boolean indices cannot be part of a larger tuple index @@ -74,6 +74,11 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[None, ...]) assert_raises(IndexError, lambda: a[..., None]) + # Multiaxis indices must contain exactly as many indices as dimensions + assert_raises(IndexError, lambda: a[()]) + assert_raises(IndexError, lambda: a[0,]) + assert_raises(IndexError, lambda: a[0]) + assert_raises(IndexError, lambda: a[:]) def test_operators(): # For every operator, we test that it works for the required type @@ -291,8 +296,8 @@ def test_device_property(): a = ones((3, 4)) assert a.device == 'cpu' - assert np.array_equal(a.to_device('cpu'), a) + assert all(equal(a.to_device('cpu'), a)) assert_raises(ValueError, lambda: a.to_device('gpu')) - assert np.array_equal(asarray(a, device='cpu'), a) + assert all(equal(asarray(a, device='cpu'), a)) assert_raises(ValueError, lambda: asarray(a, device='gpu')) diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index 7b633eaf1..ebbb6aab3 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -11,11 +11,13 @@ from .._creation_functions import ( full, full_like, linspace, + meshgrid, ones, ones_like, zeros, zeros_like, ) +from .._dtypes import float32, float64 from .._array_object import Array @@ -124,3 +126,11 @@ def test_zeros_like_errors(): assert_raises(ValueError, lambda: zeros_like(asarray(1), device="gpu")) assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int)) assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype="i")) + +def test_meshgrid_dtype_errors(): + # Doesn't raise + meshgrid() + meshgrid(asarray([1.], dtype=float32)) + meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float32)) + + assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64))) -- cgit v1.2.1 From a1813504ad44b70fb139181a9df8465bcb22e24d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Sun, 14 Nov 2021 14:35:06 -0700 Subject: ENH: Add the linalg extension to the array_api submodule (#19980) --- numpy/array_api/__init__.py | 12 +- numpy/array_api/_array_object.py | 2 +- numpy/array_api/_linear_algebra_functions.py | 67 ----- numpy/array_api/_statistical_functions.py | 9 +- numpy/array_api/linalg.py | 408 +++++++++++++++++++++++++++ 5 files changed, 419 insertions(+), 79 deletions(-) delete mode 100644 numpy/array_api/_linear_algebra_functions.py create mode 100644 numpy/array_api/linalg.py (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 36e3f3ed5..bbe2fdce2 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -109,9 +109,6 @@ Still TODO in this module are: - The spec is still in an RFC phase and may still have minor updates, which will need to be reflected here. -- The linear algebra extension in the spec will be added in a future pull - request. - - Complex number support in array API spec is planned but not yet finalized, as are the fft extension and certain linear algebra functions such as eig that require complex dtypes. @@ -334,12 +331,13 @@ __all__ += [ "trunc", ] -# einsum is not yet implemented in the array API spec. +# linalg is an extension in the array API spec, which is a sub-namespace. Only +# a subset of functions in it are imported into the top-level namespace. +from . import linalg -# from ._linear_algebra_functions import einsum -# __all__ += ['einsum'] +__all__ += ["linalg"] -from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot +from .linalg import matmul, tensordot, matrix_transpose, vecdot __all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"] diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index dc74bb8c5..8794c5ea5 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -1030,7 +1030,7 @@ class Array: # Note: mT is new in array API spec (see matrix_transpose) @property def mT(self) -> Array: - from ._linear_algebra_functions import matrix_transpose + from .linalg import matrix_transpose return matrix_transpose(self) @property diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py deleted file mode 100644 index 7a6c9846c..000000000 --- a/numpy/array_api/_linear_algebra_functions.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -from ._array_object import Array -from ._dtypes import _numeric_dtypes, _result_type - -from typing import Optional, Sequence, Tuple, Union - -import numpy as np - -# einsum is not yet implemented in the array API spec. - -# def einsum(): -# """ -# Array API compatible wrapper for :py:func:`np.einsum `. -# -# See its docstring for more information. -# """ -# return np.einsum() - - -def matmul(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.matmul `. - - See its docstring for more information. - """ - # Note: the restriction to numeric dtypes only is different from - # np.matmul. - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in matmul") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - - return Array._new(np.matmul(x1._array, x2._array)) - - -# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. -def tensordot( - x1: Array, - x2: Array, - /, - *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, -) -> Array: - # Note: the restriction to numeric dtypes only is different from - # np.tensordot. - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in tensordot") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - - return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) - - -# Note: this function is new in the array API spec. Unlike transpose, it only -# transposes the last two axes. -def matrix_transpose(x: Array, /) -> Array: - if x.ndim < 2: - raise ValueError("x must be at least 2-dimensional for matrix_transpose") - return Array._new(np.swapaxes(x._array, -1, -2)) - - -# Note: vecdot is not in NumPy -def vecdot(x1: Array, x2: Array, /, *, axis: Optional[int] = None) -> Array: - if axis is None: - axis = -1 - return tensordot(x1, x2, axes=((axis,), (axis,))) diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py index c5abf9468..7bee3f4db 100644 --- a/numpy/array_api/_statistical_functions.py +++ b/numpy/array_api/_statistical_functions.py @@ -93,11 +93,12 @@ def sum( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sum") - # Note: sum() and prod() always upcast float32 to float64 for dtype=None - # We need to do so here before summing to avoid overflow + # Note: sum() and prod() always upcast integers to (u)int64 and float32 to + # float64 for dtype=None. `np.sum` does that too for integers, but not for + # float32, so we need to special-case it here if dtype is None and x.dtype == float32: - x = asarray(x, dtype=float64) - return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims)) + dtype = float64 + return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims)) def var( diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py new file mode 100644 index 000000000..8d7ba659e --- /dev/null +++ b/numpy/array_api/linalg.py @@ -0,0 +1,408 @@ +from __future__ import annotations + +from ._dtypes import _floating_dtypes, _numeric_dtypes +from ._array_object import Array + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ._typing import Literal, Optional, Sequence, Tuple, Union + +from typing import NamedTuple + +import numpy.linalg +import numpy as np + +class EighResult(NamedTuple): + eigenvalues: Array + eigenvectors: Array + +class QRResult(NamedTuple): + Q: Array + R: Array + +class SlogdetResult(NamedTuple): + sign: Array + logabsdet: Array + +class SVDResult(NamedTuple): + U: Array + S: Array + Vh: Array + +# Note: the inclusion of the upper keyword is different from +# np.linalg.cholesky, which does not have it. +def cholesky(x: Array, /, *, upper: bool = False) -> Array: + """ + Array API compatible wrapper for :py:func:`np.linalg.cholesky `. + + See its docstring for more information. + """ + # Note: the restriction to floating-point dtypes only is different from + # np.linalg.cholesky. + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in cholesky') + L = np.linalg.cholesky(x._array) + if upper: + return Array._new(L).mT + return Array._new(L) + +# Note: cross is the numpy top-level namespace, not np.linalg +def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + """ + Array API compatible wrapper for :py:func:`np.cross `. + + See its docstring for more information. + """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in cross') + # Note: this is different from np.cross(), which broadcasts + if x1.shape != x2.shape: + raise ValueError('x1 and x2 must have the same shape') + if x1.ndim == 0: + raise ValueError('cross() requires arrays of dimension at least 1') + # Note: this is different from np.cross(), which allows dimension 2 + if x1.shape[axis] != 3: + raise ValueError('cross() dimension must equal 3') + return Array._new(np.cross(x1._array, x2._array, axis=axis)) + +def det(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.linalg.det `. + + See its docstring for more information. + """ + # Note: the restriction to floating-point dtypes only is different from + # np.linalg.det. + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in det') + return Array._new(np.linalg.det(x._array)) + +# Note: diagonal is the numpy top-level namespace, not np.linalg +def diagonal(x: Array, /, *, offset: int = 0) -> Array: + """ + Array API compatible wrapper for :py:func:`np.diagonal `. + + See its docstring for more information. + """ + # Note: diagonal always operates on the last two axes, whereas np.diagonal + # operates on the first two axes by default + return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1)) + + +# Note: the keyword argument name upper is different from np.linalg.eigh +def eigh(x: Array, /) -> EighResult: + """ + Array API compatible wrapper for :py:func:`np.linalg.eigh `. + + See its docstring for more information. + """ + # Note: the restriction to floating-point dtypes only is different from + # np.linalg.eigh. + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in eigh') + + # Note: the return type here is a namedtuple, which is different from + # np.eigh, which only returns a tuple. + return EighResult(*map(Array._new, np.linalg.eigh(x._array))) + + +# Note: the keyword argument name upper is different from np.linalg.eigvalsh +def eigvalsh(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.linalg.eigvalsh `. + + See its docstring for more information. + """ + # Note: the restriction to floating-point dtypes only is different from + # np.linalg.eigvalsh. + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in eigvalsh') + + return Array._new(np.linalg.eigvalsh(x._array)) + +def inv(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.linalg.inv `. + + See its docstring for more information. + """ + # Note: the restriction to floating-point dtypes only is different from + # np.linalg.inv. + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in inv') + + return Array._new(np.linalg.inv(x._array)) + + +# Note: matmul is the numpy top-level namespace but not in np.linalg +def matmul(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.matmul `. + + See its docstring for more information. + """ + # Note: the restriction to numeric dtypes only is different from + # np.matmul. + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in matmul') + + return Array._new(np.matmul(x1._array, x2._array)) + + +# Note: the name here is different from norm(). The array API norm is split +# into matrix_norm and vector_norm(). + +# The type for ord should be Optional[Union[int, float, Literal[np.inf, +# -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point +# literals. +def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array: + """ + Array API compatible wrapper for :py:func:`np.linalg.norm `. + + See its docstring for more information. + """ + # Note: the restriction to floating-point dtypes only is different from + # np.linalg.norm. + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in matrix_norm') + + return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord)) + + +def matrix_power(x: Array, n: int, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.matrix_power `. + + See its docstring for more information. + """ + # Note: the restriction to floating-point dtypes only is different from + # np.linalg.matrix_power. + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed for the first argument of matrix_power') + + # np.matrix_power already checks if n is an integer + return Array._new(np.linalg.matrix_power(x._array, n)) + +# Note: the keyword argument name rtol is different from np.linalg.matrix_rank +def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.matrix_rank `. + + See its docstring for more information. + """ + # Note: this is different from np.linalg.matrix_rank, which supports 1 + # dimensional arrays. + if x.ndim < 2: + raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") + S = np.linalg.svd(x._array, compute_uv=False) + if rtol is None: + tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * np.finfo(S.dtype).eps + else: + if isinstance(rtol, Array): + rtol = rtol._array + # Note: this is different from np.linalg.matrix_rank, which does not multiply + # the tolerance by the largest singular value. + tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis] + return Array._new(np.count_nonzero(S > tol, axis=-1)) + + +# Note: this function is new in the array API spec. Unlike transpose, it only +# transposes the last two axes. +def matrix_transpose(x: Array, /) -> Array: + if x.ndim < 2: + raise ValueError("x must be at least 2-dimensional for matrix_transpose") + return Array._new(np.swapaxes(x._array, -1, -2)) + +# Note: outer is the numpy top-level namespace, not np.linalg +def outer(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.outer `. + + See its docstring for more information. + """ + # Note: the restriction to numeric dtypes only is different from + # np.outer. + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in outer') + + # Note: the restriction to only 1-dim arrays is different from np.outer + if x1.ndim != 1 or x2.ndim != 1: + raise ValueError('The input arrays to outer must be 1-dimensional') + + return Array._new(np.outer(x1._array, x2._array)) + +# Note: the keyword argument name rtol is different from np.linalg.pinv +def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: + """ + Array API compatible wrapper for :py:func:`np.linalg.pinv `. + + See its docstring for more information. + """ + # Note: the restriction to floating-point dtypes only is different from + # np.linalg.pinv. + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in pinv') + + # Note: this is different from np.linalg.pinv, which does not multiply the + # default tolerance by max(M, N). + if rtol is None: + rtol = max(x.shape[-2:]) * np.finfo(x.dtype).eps + return Array._new(np.linalg.pinv(x._array, rcond=rtol)) + +def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: + """ + Array API compatible wrapper for :py:func:`np.linalg.qr `. + + See its docstring for more information. + """ + # Note: the restriction to floating-point dtypes only is different from + # np.linalg.qr. + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in qr') + + # Note: the return type here is a namedtuple, which is different from + # np.linalg.qr, which only returns a tuple. + return QRResult(*map(Array._new, np.linalg.qr(x._array, mode=mode))) + +def slogdet(x: Array, /) -> SlogdetResult: + """ + Array API compatible wrapper for :py:func:`np.linalg.slogdet `. + + See its docstring for more information. + """ + # Note: the restriction to floating-point dtypes only is different from + # np.linalg.slogdet. + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in slogdet') + + # Note: the return type here is a namedtuple, which is different from + # np.linalg.slogdet, which only returns a tuple. + return SlogdetResult(*map(Array._new, np.linalg.slogdet(x._array))) + +# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a +# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack +# of matrices. The np.linalg.solve behavior of allowing stacks of both +# matrices and vectors is ambiguous c.f. +# https://github.com/numpy/numpy/issues/15349 and +# https://github.com/data-apis/array-api/issues/285. + +# To workaround this, the below is the code from np.linalg.solve except +# only calling solve1 in the exactly 1D case. +def _solve(a, b): + from ..linalg.linalg import (_makearray, _assert_stacked_2d, + _assert_stacked_square, _commonType, + isComplexType, get_linalg_error_extobj, + _raise_linalgerror_singular) + from ..linalg import _umath_linalg + + a, _ = _makearray(a) + _assert_stacked_2d(a) + _assert_stacked_square(a) + b, wrap = _makearray(b) + t, result_t = _commonType(a, b) + + # This part is different from np.linalg.solve + if b.ndim == 1: + gufunc = _umath_linalg.solve1 + else: + gufunc = _umath_linalg.solve + + # This does nothing currently but is left in because it will be relevant + # when complex dtype support is added to the spec in 2022. + signature = 'DD->D' if isComplexType(t) else 'dd->d' + extobj = get_linalg_error_extobj(_raise_linalgerror_singular) + r = gufunc(a, b, signature=signature, extobj=extobj) + + return wrap(r.astype(result_t, copy=False)) + +def solve(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.linalg.solve `. + + See its docstring for more information. + """ + # Note: the restriction to floating-point dtypes only is different from + # np.linalg.solve. + if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in solve') + + return Array._new(_solve(x1._array, x2._array)) + +def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: + """ + Array API compatible wrapper for :py:func:`np.linalg.svd `. + + See its docstring for more information. + """ + # Note: the restriction to floating-point dtypes only is different from + # np.linalg.svd. + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in svd') + + # Note: the return type here is a namedtuple, which is different from + # np.svd, which only returns a tuple. + return SVDResult(*map(Array._new, np.linalg.svd(x._array, full_matrices=full_matrices))) + +# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to +# np.linalg.svd(compute_uv=False). +def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: + return Array._new(np.linalg.svd(x._array, compute_uv=False)) + +# Note: tensordot is the numpy top-level namespace but not in np.linalg + +# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. +def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: + # Note: the restriction to numeric dtypes only is different from + # np.tensordot. + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in tensordot') + + return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) + +# Note: trace is the numpy top-level namespace, not np.linalg +def trace(x: Array, /, *, offset: int = 0) -> Array: + """ + Array API compatible wrapper for :py:func:`np.trace `. + + See its docstring for more information. + """ + # Note: trace always operates on the last two axes, whereas np.trace + # operates on the first two axes by default + return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1))) + +# Note: vecdot is not in NumPy +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + return tensordot(x1, x2, axes=((axis,), (axis,))) + + +# Note: the name here is different from norm(). The array API norm is split +# into matrix_norm and vector_norm(). + +# The type for ord should be Optional[Union[int, float, Literal[np.inf, +# -np.inf]]] but Literal does not support floating-point literals. +def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: + """ + Array API compatible wrapper for :py:func:`np.linalg.norm `. + + See its docstring for more information. + """ + # Note: the restriction to floating-point dtypes only is different from + # np.linalg.norm. + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in norm') + + a = x._array + if axis is None: + a = a.flatten() + axis = 0 + elif isinstance(axis, tuple): + # Note: The axis argument supports any number of axes, whereas norm() + # only supports a single axis for vector norm. + rest = tuple(i for i in range(a.ndim) if i not in axis) + newshape = axis + rest + a = np.transpose(a, newshape).reshape((np.prod([a.shape[i] for i in axis]), *[a.shape[i] for i in rest])) + axis = 0 + return Array._new(np.linalg.norm(a, axis=axis, keepdims=keepdims, ord=ord)) + + +__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm'] -- cgit v1.2.1 From eb865fa4683bca9a9d7c840f8addd30e39b62f5b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 Dec 2021 14:07:24 -0700 Subject: BUG: Fix the .T attribute in the array_api namespace Fixes #20498. --- numpy/array_api/_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 8794c5ea5..ead061882 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -1072,4 +1072,4 @@ class Array: # https://data-apis.org/array-api/latest/API_specification/array_object.html#t if self.ndim != 2: raise ValueError("x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions.") - return self._array.T + return self.__class__._new(self._array.T) -- cgit v1.2.1 From 18fe695dbc66df660039aca9f76a949de8b9348e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 Dec 2021 17:02:43 -0700 Subject: Add tests for T and mT in array_api --- numpy/array_api/tests/test_array_object.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index 12479d765..deab50693 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -4,6 +4,7 @@ from numpy.testing import assert_raises import numpy as np from .. import ones, asarray, result_type, all, equal +from .._array_object import Array from .._dtypes import ( _all_dtypes, _boolean_dtypes, @@ -301,3 +302,16 @@ def test_device_property(): assert all(equal(asarray(a, device='cpu'), a)) assert_raises(ValueError, lambda: asarray(a, device='gpu')) + +def test_array_properties(): + a = ones((1, 2, 3)) + b = ones((2, 3)) + assert_raises(ValueError, lambda: a.T) + + assert isinstance(b.T, Array) + assert b.T.shape == (3, 2) + + assert isinstance(a.mT, Array) + assert a.mT.shape == (1, 3, 2) + assert isinstance(b.mT, Array) + assert b.mT.shape == (3, 2) -- cgit v1.2.1 From 74a3ee7a8b75bf6dc271c9a1a4b55d2ad9758420 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 6 Dec 2021 13:59:08 -0700 Subject: ENH: Add __array__ to the array_api Array object This is *NOT* part of the array API spec (so it should not be relied on for portable code). However, without this, np.asarray(np.array_api.Array) produces an object array instead of doing the conversion to a NumPy array as expected. This would work once np.asarray() implements dlpack support, but until then, it seems reasonable to make the conversion work. Note that the reverse, calling np.array_api.asarray(np.array), already works because np.array_api.asarray() is just a wrapper for np.asarray(). --- numpy/array_api/_array_object.py | 11 +++++++++++ numpy/array_api/tests/test_array_object.py | 7 +++++++ 2 files changed, 18 insertions(+) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index ead061882..d322e6ca6 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -108,6 +108,17 @@ class Array: mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix) return prefix + mid + suffix + # This function is not required by the spec, but we implement it here for + # convenience so that np.asarray(np.array_api.Array) will work. + def __array__(self, dtype=None): + """ + Warning: this method is NOT part of the array API spec. Implementers + of other libraries need not include it, and users should not assume it + will be present in other implementations. + + """ + return np.asarray(self._array, dtype=dtype) + # These are various helper functions to make the array behavior match the # spec in places where it either deviates from or is more strict than # NumPy behavior diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index deab50693..b980bacca 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -315,3 +315,10 @@ def test_array_properties(): assert a.mT.shape == (1, 3, 2) assert isinstance(b.mT, Array) assert b.mT.shape == (3, 2) + +def test___array__(): + a = ones((2, 3), dtype=int16) + assert np.asarray(a) is a._array + b = np.asarray(a, dtype=np.float64) + assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64))) + assert b.dtype == np.float64 -- cgit v1.2.1 From 5f21063cc317d92a866c7259a9509f5e5d6189c2 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 6 Dec 2021 17:19:55 -0700 Subject: Add type hints to the numpy.array_api.Array.__array__ signature Thanks @BvB93 --- numpy/array_api/_array_object.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index d322e6ca6..75baf34b0 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -33,6 +33,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union, Any if TYPE_CHECKING: from ._typing import Any, PyCapsule, Device, Dtype + import numpy.typing as npt import numpy as np @@ -110,7 +111,7 @@ class Array: # This function is not required by the spec, but we implement it here for # convenience so that np.asarray(np.array_api.Array) will work. - def __array__(self, dtype=None): + def __array__(self, dtype: None | np.dtype[Any] = None) -> npt.NDArray[Any]: """ Warning: this method is NOT part of the array API spec. Implementers of other libraries need not include it, and users should not assume it -- cgit v1.2.1 From 19a398ae78c3f35ce3d29d87b35da059558d72b4 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 7 Dec 2021 15:35:20 -0700 Subject: BUG: Fix handling of the dtype parameter to numpy.array_api.prod() --- numpy/array_api/_statistical_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py index 7bee3f4db..5bc831ac2 100644 --- a/numpy/array_api/_statistical_functions.py +++ b/numpy/array_api/_statistical_functions.py @@ -65,8 +65,8 @@ def prod( # Note: sum() and prod() always upcast float32 to float64 for dtype=None # We need to do so here before computing the product to avoid overflow if dtype is None and x.dtype == float32: - x = asarray(x, dtype=float64) - return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims)) + dtype = float64 + return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims)) def std( -- cgit v1.2.1 From 781c9463673af478b2799549d70ac6d16e0555a5 Mon Sep 17 00:00:00 2001 From: Ralf Gommers Date: Mon, 3 Jan 2022 22:34:12 +0100 Subject: TYP: change type annotation for `__array_namespace__` to ModuleType This is more precise, we are returning a module here. Type checkers will be able to use this info in the future - see https://github.com/data-apis/array-api/issues/267 --- numpy/array_api/_array_object.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 75baf34b0..391e9f575 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -30,6 +30,7 @@ from ._dtypes import ( ) from typing import TYPE_CHECKING, Optional, Tuple, Union, Any +import types if TYPE_CHECKING: from ._typing import Any, PyCapsule, Device, Dtype @@ -415,7 +416,7 @@ class Array: def __array_namespace__( self: Array, /, *, api_version: Optional[str] = None - ) -> Any: + ) -> types.ModuleType: if api_version is not None and not api_version.startswith("2021."): raise ValueError(f"Unrecognized array API version: {api_version!r}") return array_api -- cgit v1.2.1 From 5c04e06dacb714fb8381db003eb0f22c08c91417 Mon Sep 17 00:00:00 2001 From: Ralf Gommers Date: Mon, 3 Jan 2022 23:03:10 +0100 Subject: TYP: add a few type annotations to `numpy.array_api.Array` This fixes the majority of the complaints for `$ mypy numpy/array_api`. The comment indicating that one fix is blocked by lack of support in Mypy for `NotImplemented` is responsible for another several dozen errors. [skip ci] --- numpy/array_api/_array_object.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 75baf34b0..900a0ae1e 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -55,6 +55,7 @@ class Array: functions, such as asarray(). """ + _array: np.ndarray # Use a custom constructor instead of __init__, as manually initializing # this class is not supported API. @@ -124,6 +125,9 @@ class Array: # spec in places where it either deviates from or is more strict than # NumPy behavior + # NOTE: no valid type annotation possible. E.g `Union[Array, + # NotImplemented]` is forbidden, see https://github.com/python/mypy/issues/363 + # Maybe change returned object to `Literal['NotImplemented']`? def _check_allowed_dtypes(self, other, dtype_category, op): """ Helper function for operators to only allow specific input dtypes @@ -200,7 +204,7 @@ class Array: return Array._new(np.array(scalar, self.dtype)) @staticmethod - def _normalize_two_args(x1, x2): + def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: """ Normalize inputs to two arg functions to fix type promotion rules -- cgit v1.2.1 From 890c0cfe15f0e3325b57e44f9b9a9f78f9797f25 Mon Sep 17 00:00:00 2001 From: Ralf Gommers Date: Tue, 4 Jan 2022 08:17:10 +0100 Subject: TYP: accept review comment on ignoring NotImplemented in type checking Co-authored-by: Bas van Beek <43369155+BvB93@users.noreply.github.com> --- numpy/array_api/_array_object.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 900a0ae1e..93ae10603 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -125,10 +125,7 @@ class Array: # spec in places where it either deviates from or is more strict than # NumPy behavior - # NOTE: no valid type annotation possible. E.g `Union[Array, - # NotImplemented]` is forbidden, see https://github.com/python/mypy/issues/363 - # Maybe change returned object to `Literal['NotImplemented']`? - def _check_allowed_dtypes(self, other, dtype_category, op): + def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array: """ Helper function for operators to only allow specific input dtypes -- cgit v1.2.1 From e3406ed3ef83f7de0f3419361e85bd8634d0fd2b Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 7 Jan 2022 11:53:09 +0000 Subject: BUG: Allow integer inputs for pow-related functions in `array_api` Updates `xp.power()`, `x.__pow__()`, `x.__ipow()__` and `x.__rpow()__` --- numpy/array_api/_array_object.py | 14 ++++++-------- numpy/array_api/_elementwise_functions.py | 4 ++-- numpy/array_api/tests/test_array_object.py | 2 +- numpy/array_api/tests/test_elementwise_functions.py | 2 +- 4 files changed, 10 insertions(+), 12 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 04d9f8a24..c86cfb3a6 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -656,15 +656,13 @@ class Array: res = self._array.__pos__() return self.__class__._new(res) - # PEP 484 requires int to be a subtype of float, but __pow__ should not - # accept int. - def __pow__(self: Array, other: Union[float, Array], /) -> Array: + def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __pow__. """ from ._elementwise_functions import pow - other = self._check_allowed_dtypes(other, "floating-point", "__pow__") + other = self._check_allowed_dtypes(other, "numeric", "__pow__") if other is NotImplemented: return other # Note: NumPy's __pow__ does not follow type promotion rules for 0-d @@ -914,23 +912,23 @@ class Array: res = self._array.__ror__(other._array) return self.__class__._new(res) - def __ipow__(self: Array, other: Union[float, Array], /) -> Array: + def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ipow__. """ - other = self._check_allowed_dtypes(other, "floating-point", "__ipow__") + other = self._check_allowed_dtypes(other, "numeric", "__ipow__") if other is NotImplemented: return other self._array.__ipow__(other._array) return self - def __rpow__(self: Array, other: Union[float, Array], /) -> Array: + def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __rpow__. """ from ._elementwise_functions import pow - other = self._check_allowed_dtypes(other, "floating-point", "__rpow__") + other = self._check_allowed_dtypes(other, "numeric", "__rpow__") if other is NotImplemented: return other # Note: NumPy's __pow__ does not follow the spec type promotion rules diff --git a/numpy/array_api/_elementwise_functions.py b/numpy/array_api/_elementwise_functions.py index 4408fe833..c758a0944 100644 --- a/numpy/array_api/_elementwise_functions.py +++ b/numpy/array_api/_elementwise_functions.py @@ -591,8 +591,8 @@ def pow(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in pow") + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in pow") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index b980bacca..1fe1dfddf 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -98,7 +98,7 @@ def test_operators(): "__mul__": "numeric", "__ne__": "all", "__or__": "integer_or_boolean", - "__pow__": "floating", + "__pow__": "numeric", "__rshift__": "integer", "__sub__": "numeric", "__truediv__": "floating", diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py index a9274aec9..b2fb44e76 100644 --- a/numpy/array_api/tests/test_elementwise_functions.py +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -66,7 +66,7 @@ def test_function_types(): "negative": "numeric", "not_equal": "all", "positive": "numeric", - "pow": "floating-point", + "pow": "numeric", "remainder": "numeric", "round": "numeric", "sign": "numeric", -- cgit v1.2.1 From d7a43dfa91cc1363db64da8915db2b4b6c847b81 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 12 Jan 2022 16:20:33 +0000 Subject: BUG: `array_api.argsort(descending=True)` respects relative sort order (#20788) * BUG: `array_api.argsort(descending=True)` respects relative order * Regression test for stable descending `array_api.argsort()` --- numpy/array_api/_sorting_functions.py | 17 ++++++++++++++--- numpy/array_api/tests/test_sorting_functions.py | 23 +++++++++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 numpy/array_api/tests/test_sorting_functions.py (limited to 'numpy/array_api') diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py index 9cd49786c..b2a11872f 100644 --- a/numpy/array_api/_sorting_functions.py +++ b/numpy/array_api/_sorting_functions.py @@ -15,9 +15,20 @@ def argsort( """ # Note: this keyword argument is different, and the default is different. kind = "stable" if stable else "quicksort" - res = np.argsort(x._array, axis=axis, kind=kind) - if descending: - res = np.flip(res, axis=axis) + if not descending: + res = np.argsort(x._array, axis=axis, kind=kind) + else: + # As NumPy has no native descending sort, we imitate it here. Note that + # simply flipping the results of np.argsort(x._array, ...) would not + # respect the relative order like it would in native descending sorts. + res = np.flip( + np.argsort(np.flip(x._array, axis=axis), axis=axis, kind=kind), + axis=axis, + ) + # Rely on flip()/argsort() to validate axis + normalised_axis = axis if axis >= 0 else x.ndim + axis + max_i = x.shape[normalised_axis] - 1 + res = max_i - res return Array._new(res) diff --git a/numpy/array_api/tests/test_sorting_functions.py b/numpy/array_api/tests/test_sorting_functions.py new file mode 100644 index 000000000..9848bbfeb --- /dev/null +++ b/numpy/array_api/tests/test_sorting_functions.py @@ -0,0 +1,23 @@ +import pytest + +from numpy import array_api as xp + + +@pytest.mark.parametrize( + "obj, axis, expected", + [ + ([0, 0], -1, [0, 1]), + ([0, 1, 0], -1, [1, 0, 2]), + ([[0, 1], [1, 1]], 0, [[1, 0], [0, 1]]), + ([[0, 1], [1, 1]], 1, [[1, 0], [0, 1]]), + ], +) +def test_stable_desc_argsort(obj, axis, expected): + """ + Indices respect relative order of a descending stable-sort + + See https://github.com/numpy/numpy/issues/20778 + """ + x = xp.asarray(obj) + out = xp.argsort(x, axis=axis, stable=True, descending=True) + assert xp.all(out == xp.asarray(expected)) -- cgit v1.2.1 From 8b967ff2e70afe3a1fd32e33b36f66f34c259139 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 11 Jan 2022 17:04:39 +0000 Subject: BUG: Return correctly shaped inverse indices in `array_api` Specifically for `xp.unique_all()` and `xp.unique_inverse()` --- numpy/array_api/_set_functions.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py index 05ee7e555..db9370f84 100644 --- a/numpy/array_api/_set_functions.py +++ b/numpy/array_api/_set_functions.py @@ -41,14 +41,21 @@ def unique_all(x: Array, /) -> UniqueAllResult: See its docstring for more information. """ - res = np.unique( + values, indices, inverse_indices, counts = np.unique( x._array, return_counts=True, return_index=True, return_inverse=True, ) - - return UniqueAllResult(*[Array._new(i) for i in res]) + # np.unique() flattens inverse indices, but they need to share x's shape + # See https://github.com/numpy/numpy/issues/20638 + inverse_indices = inverse_indices.reshape(x.shape) + return UniqueAllResult( + Array._new(values), + Array._new(indices), + Array._new(inverse_indices), + Array._new(counts), + ) def unique_counts(x: Array, /) -> UniqueCountsResult: @@ -68,13 +75,16 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult: See its docstring for more information. """ - res = np.unique( + values, inverse_indices = np.unique( x._array, return_counts=False, return_index=False, return_inverse=True, ) - return UniqueInverseResult(*[Array._new(i) for i in res]) + # np.unique() flattens inverse indices, but they need to share x's shape + # See https://github.com/numpy/numpy/issues/20638 + inverse_indices = inverse_indices.reshape(x.shape) + return UniqueInverseResult(Array._new(values), Array._new(inverse_indices)) def unique_values(x: Array, /) -> Array: -- cgit v1.2.1 From ccc1091360639de76fab1c1e2a8b31fda81855fb Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 13 Jan 2022 10:03:49 +0000 Subject: Regression test for inverse indices in `array_api` set functions --- numpy/array_api/tests/test_set_functions.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 numpy/array_api/tests/test_set_functions.py (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_set_functions.py b/numpy/array_api/tests/test_set_functions.py new file mode 100644 index 000000000..b8eb65d43 --- /dev/null +++ b/numpy/array_api/tests/test_set_functions.py @@ -0,0 +1,19 @@ +import pytest +from hypothesis import given +from hypothesis.extra.array_api import make_strategies_namespace + +from numpy import array_api as xp + +xps = make_strategies_namespace(xp) + + +@pytest.mark.parametrize("func", [xp.unique_all, xp.unique_inverse]) +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=xps.array_shapes())) +def test_inverse_indices_shape(func, x): + """ + Inverse indices share shape of input array + + See https://github.com/numpy/numpy/issues/20638 + """ + out = func(x) + assert out.inverse_indices.shape == x.shape -- cgit v1.2.1 From 8eac9a4bb5b497ca29ebb852f21169ecfd0191e1 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 24 Jan 2022 10:20:20 +0000 Subject: BUG: Fix `np.array_api.can_cast()` by not relying on `np.can_cast()` --- numpy/array_api/_data_type_functions.py | 14 ++++++++---- numpy/array_api/tests/test_data_type_functions.py | 21 ++++++++++++++++++ numpy/array_api/tests/test_validation.py | 27 +++++++++++++++++++++++ 3 files changed, 58 insertions(+), 4 deletions(-) create mode 100644 numpy/array_api/tests/test_data_type_functions.py create mode 100644 numpy/array_api/tests/test_validation.py (limited to 'numpy/array_api') diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index e4d6db61b..1198ff778 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -50,11 +50,17 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: See its docstring for more information. """ - from ._array_object import Array - if isinstance(from_, Array): - from_ = from_._array - return np.can_cast(from_, to) + from_ = from_.dtype + elif from_ not in _all_dtypes: + raise TypeError(f"{from_=}, but should be an array_api array or dtype") + if to not in _all_dtypes: + raise TypeError(f"{to=}, but should be a dtype") + try: + dtype = _result_type(from_, to) + return to == dtype + except TypeError: + return False # These are internal objects for the return types of finfo and iinfo, since diff --git a/numpy/array_api/tests/test_data_type_functions.py b/numpy/array_api/tests/test_data_type_functions.py new file mode 100644 index 000000000..3f01bb311 --- /dev/null +++ b/numpy/array_api/tests/test_data_type_functions.py @@ -0,0 +1,21 @@ +import pytest + +from numpy import array_api as xp + + +@pytest.mark.parametrize( + "from_, to, expected", + [ + (xp.int8, xp.int16, True), + (xp.int16, xp.int8, False), + # np.can_cast has discrepancies with the Array API + # See https://github.com/numpy/numpy/issues/20870 + (xp.bool, xp.int8, False), + (xp.asarray(0, dtype=xp.uint8), xp.int8, False), + ], +) +def test_can_cast(from_, to, expected): + """ + can_cast() returns correct result + """ + assert xp.can_cast(from_, to) == expected diff --git a/numpy/array_api/tests/test_validation.py b/numpy/array_api/tests/test_validation.py new file mode 100644 index 000000000..0dd100d15 --- /dev/null +++ b/numpy/array_api/tests/test_validation.py @@ -0,0 +1,27 @@ +from typing import Callable + +import pytest + +from numpy import array_api as xp + + +def p(func: Callable, *args, **kwargs): + f_sig = ", ".join( + [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs.items()] + ) + id_ = f"{func.__name__}({f_sig})" + return pytest.param(func, args, kwargs, id=id_) + + +@pytest.mark.parametrize( + "func, args, kwargs", + [ + p(xp.can_cast, 42, xp.int8), + p(xp.can_cast, xp.int8, 42), + p(xp.result_type, 42), + ], +) +def test_raises_on_invalid_types(func, args, kwargs): + """Function raises TypeError when passed invalidly-typed inputs""" + with pytest.raises(TypeError): + func(*args, **kwargs) -- cgit v1.2.1 From 995f5464b6c5d8569e159a96c6af106721a4e6d5 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 9 Feb 2022 08:35:21 +0000 Subject: Note `np.array_api.can_cast()` does not use `np.can_cast()` --- numpy/array_api/_data_type_functions.py | 5 +++++ numpy/array_api/tests/test_data_type_functions.py | 2 -- 2 files changed, 5 insertions(+), 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index 1198ff778..1fb6062f6 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -56,10 +56,15 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: raise TypeError(f"{from_=}, but should be an array_api array or dtype") if to not in _all_dtypes: raise TypeError(f"{to=}, but should be a dtype") + # Note: We avoid np.can_cast() as it has discrepancies with the array API. + # See https://github.com/numpy/numpy/issues/20870 try: + # We promote `from_` and `to` together. We then check if the promoted + # dtype is `to`, which indicates if `from_` can (up)cast to `to`. dtype = _result_type(from_, to) return to == dtype except TypeError: + # _result_type() raises if the dtypes don't promote together return False diff --git a/numpy/array_api/tests/test_data_type_functions.py b/numpy/array_api/tests/test_data_type_functions.py index 3f01bb311..efe3d0abd 100644 --- a/numpy/array_api/tests/test_data_type_functions.py +++ b/numpy/array_api/tests/test_data_type_functions.py @@ -8,8 +8,6 @@ from numpy import array_api as xp [ (xp.int8, xp.int16, True), (xp.int16, xp.int8, False), - # np.can_cast has discrepancies with the Array API - # See https://github.com/numpy/numpy/issues/20870 (xp.bool, xp.int8, False), (xp.asarray(0, dtype=xp.uint8), xp.int8, False), ], -- cgit v1.2.1 From 945e0aee51a316abb77284f2c4206f8f1d963480 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Thu, 3 Mar 2022 13:00:28 +0530 Subject: MAINT: make np._from_dlpack public --- numpy/array_api/_creation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index 741498ff6..3b014d37b 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -154,7 +154,7 @@ def eye( def from_dlpack(x: object, /) -> Array: from ._array_object import Array - return Array._new(np._from_dlpack(x)) + return Array._new(np.from_dlpack(x)) def full( -- cgit v1.2.1 From 7ed0189477ed4a7d2907d0728b3b8121722d8a2f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Mar 2022 16:12:47 -0600 Subject: Update some Note comments in numpy.array_api --- numpy/array_api/_array_object.py | 5 +++++ numpy/array_api/_data_type_functions.py | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index c86cfb3a6..6cf9ec6f3 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -177,6 +177,8 @@ class Array: integer that is too large to fit in a NumPy integer dtype, or TypeError when the scalar type is incompatible with the dtype of self. """ + # Note: Only Python scalar types that match the array dtype are + # allowed. if isinstance(scalar, bool): if self.dtype not in _boolean_dtypes: raise TypeError( @@ -195,6 +197,9 @@ class Array: else: raise TypeError("'scalar' must be a Python scalar") + # Note: scalars are unconditionally cast to the same dtype as the + # array. + # Note: the spec only specifies integer-dtype/int promotion # behavior for integers within the bounds of the integer dtype. # Outside of those bounds we use the default NumPy behavior (either diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index 1fb6062f6..7026bd489 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -56,7 +56,8 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: raise TypeError(f"{from_=}, but should be an array_api array or dtype") if to not in _all_dtypes: raise TypeError(f"{to=}, but should be a dtype") - # Note: We avoid np.can_cast() as it has discrepancies with the array API. + # Note: We avoid np.can_cast() as it has discrepancies with the array API, + # since NumPy allows cross-kind casting (e.g., NumPy allows bool -> int8). # See https://github.com/numpy/numpy/issues/20870 try: # We promote `from_` and `to` together. We then check if the promoted -- cgit v1.2.1 From f306e941ffb0bc49604f5507aa8ea614d933cfd0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Mar 2022 18:12:26 -0600 Subject: Remove some outdated comments --- numpy/array_api/linalg.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index 8d7ba659e..cb04113b6 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -89,7 +89,6 @@ def diagonal(x: Array, /, *, offset: int = 0) -> Array: return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1)) -# Note: the keyword argument name upper is different from np.linalg.eigh def eigh(x: Array, /) -> EighResult: """ Array API compatible wrapper for :py:func:`np.linalg.eigh `. @@ -106,7 +105,6 @@ def eigh(x: Array, /) -> EighResult: return EighResult(*map(Array._new, np.linalg.eigh(x._array))) -# Note: the keyword argument name upper is different from np.linalg.eigvalsh def eigvalsh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.eigvalsh `. -- cgit v1.2.1 From f375d71ca101db9541b1e70476999b574634556d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Mar 2022 18:12:42 -0600 Subject: Properly restrict the input dtypes for the array_api trace, svdvals, and vecdot --- numpy/array_api/linalg.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'numpy/array_api') diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index cb04113b6..8398147ba 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -344,6 +344,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to # np.linalg.svd(compute_uv=False). def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in svdvals') return Array._new(np.linalg.svd(x._array, compute_uv=False)) # Note: tensordot is the numpy top-level namespace but not in np.linalg @@ -364,12 +366,16 @@ def trace(x: Array, /, *, offset: int = 0) -> Array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in trace') # Note: trace always operates on the last two axes, whereas np.trace # operates on the first two axes by default return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1))) # Note: vecdot is not in NumPy def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in vecdot') return tensordot(x1, x2, axes=((axis,), (axis,))) -- cgit v1.2.1 From b5ac835f2e1d1188d2f0ab961cadbb74e1e488d0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Mar 2022 20:32:26 -0600 Subject: Add some missing notes to array_api --- numpy/array_api/_sorting_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py index b2a11872f..afbb412f7 100644 --- a/numpy/array_api/_sorting_functions.py +++ b/numpy/array_api/_sorting_functions.py @@ -5,6 +5,7 @@ from ._array_object import Array import numpy as np +# Note: the descending keyword argument is new in this function def argsort( x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True ) -> Array: @@ -31,7 +32,7 @@ def argsort( res = max_i - res return Array._new(res) - +# Note: the descending keyword argument is new in this function def sort( x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True ) -> Array: -- cgit v1.2.1 From 4457e37b39bfc660beb8be579b282e5acbae1b5f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Mar 2022 20:32:39 -0600 Subject: Fix a type hint in numpy.array_api --- numpy/array_api/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index 8398147ba..f422e1c27 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -384,7 +384,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: # The type for ord should be Optional[Union[int, float, Literal[np.inf, # -np.inf]]] but Literal does not support floating-point literals. -def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: +def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm `. -- cgit v1.2.1 From befef7b26773eddd2b656a3ab87f504e6cc173db Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 6 May 2022 09:27:27 +0000 Subject: API: Allow newaxis indexing for `array_api` arrays (#21377) * TST: Add test checking if newaxis indexing works for `array_api` Also removes previous check against newaxis indexing, which is now outdated * TST, BUG: Allow `None` in `array_api` indexing Introduces test for validating flat indexing when `None` is present * MAINT,DOC,TST: Rework of `_validate_index()` in `numpy.array_api` _validate_index() is now called as self._validate_index(shape), and does not return a key. This rework removes the recursive pattern used. Tests are introduced to cover some edge cases. Additionally, its internal docstring reflects new behaviour, and extends the flat indexing note. * MAINT: `advance` -> `advanced` (integer indexing) Co-authored-by: Aaron Meurer * BUG: array_api arrays use internal arrays from array_api array keys When an array_api array is passed as the key for get/setitem, we access the key's internal np.ndarray array to be used as the key for the internal get/setitem operation. This behaviour was initially removed when `_validate_index()` was reworked. * MAINT: Better flat indexing error message for `array_api` arrays Also better semantics for its prior ellipsis count condition Co-authored-by: Sebastian Berg * MAINT: `array_api` arrays don't special case multi-ellipsis errors This gets handled by NumPy-proper. Co-authored-by: Aaron Meurer Co-authored-by: Sebastian Berg --- numpy/array_api/_array_object.py | 212 ++++++++++++++++------------- numpy/array_api/tests/test_array_object.py | 63 ++++++++- 2 files changed, 176 insertions(+), 99 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 6cf9ec6f3..c4746fad9 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -29,7 +29,7 @@ from ._dtypes import ( _dtype_categories, ) -from typing import TYPE_CHECKING, Optional, Tuple, Union, Any +from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex import types if TYPE_CHECKING: @@ -243,8 +243,7 @@ class Array: # Note: A large fraction of allowed indices are disallowed here (see the # docstring below) - @staticmethod - def _validate_index(key, shape): + def _validate_index(self, key): """ Validate an index according to the array API. @@ -257,8 +256,7 @@ class Array: https://data-apis.org/array-api/latest/API_specification/indexing.html for the full list of required indexing behavior - This function either raises IndexError if the index ``key`` is - invalid, or a new key to be used in place of ``key`` in indexing. It + This function raises IndexError if the index ``key`` is invalid. It only raises ``IndexError`` on indices that are not already rejected by NumPy, as NumPy will already raise the appropriate error on such indices. ``shape`` may be None, in which case, only cases that are @@ -269,7 +267,7 @@ class Array: - Indices to not include an implicit ellipsis at the end. That is, every axis of an array must be explicitly indexed or an ellipsis - included. + included. This behaviour is sometimes referred to as flat indexing. - The start and stop of a slice may not be out of bounds. In particular, for a slice ``i:j:k`` on an axis of size ``n``, only the @@ -292,100 +290,122 @@ class Array: ``Array._new`` constructor, not this function. """ - if isinstance(key, slice): - if shape is None: - return key - if shape == (): - return key - if len(shape) > 1: + _key = key if isinstance(key, tuple) else (key,) + for i in _key: + if isinstance(i, bool) or not ( + isinstance(i, SupportsIndex) # i.e. ints + or isinstance(i, slice) + or i == Ellipsis + or i is None + or isinstance(i, Array) + or isinstance(i, np.ndarray) + ): raise IndexError( - "Multidimensional arrays must include an index for every axis or use an ellipsis" + f"Single-axes index {i} has {type(i)=}, but only " + "integers, slices (:), ellipsis (...), newaxis (None), " + "zero-dimensional integer arrays and boolean arrays " + "are specified in the Array API." ) - size = shape[0] - # Ensure invalid slice entries are passed through. - if key.start is not None: - try: - operator.index(key.start) - except TypeError: - return key - if not (-size <= key.start <= size): - raise IndexError( - "Slices with out-of-bounds start are not allowed in the array API namespace" - ) - if key.stop is not None: - try: - operator.index(key.stop) - except TypeError: - return key - step = 1 if key.step is None else key.step - if (step > 0 and not (-size <= key.stop <= size) - or step < 0 and not (-size - 1 <= key.stop <= max(0, size - 1))): - raise IndexError("Slices with out-of-bounds stop are not allowed in the array API namespace") - return key - - elif isinstance(key, tuple): - key = tuple(Array._validate_index(idx, None) for idx in key) - - for idx in key: - if ( - isinstance(idx, np.ndarray) - and idx.dtype in _boolean_dtypes - or isinstance(idx, (bool, np.bool_)) - ): - if len(key) == 1: - return key - raise IndexError( - "Boolean array indices combined with other indices are not allowed in the array API namespace" - ) - if isinstance(idx, tuple): - raise IndexError( - "Nested tuple indices are not allowed in the array API namespace" - ) - - if shape is None: - return key - n_ellipsis = key.count(...) - if n_ellipsis > 1: - return key - ellipsis_i = key.index(...) if n_ellipsis else len(key) - for idx, size in list(zip(key[:ellipsis_i], shape)) + list( - zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1]) - ): - Array._validate_index(idx, (size,)) - if n_ellipsis == 0 and len(key) < len(shape): + nonexpanding_key = [] + single_axes = [] + n_ellipsis = 0 + key_has_mask = False + for i in _key: + if i is not None: + nonexpanding_key.append(i) + if isinstance(i, Array) or isinstance(i, np.ndarray): + if i.dtype in _boolean_dtypes: + key_has_mask = True + single_axes.append(i) + else: + # i must not be an array here, to avoid elementwise equals + if i == Ellipsis: + n_ellipsis += 1 + else: + single_axes.append(i) + + n_single_axes = len(single_axes) + if n_ellipsis > 1: + return # handled by ndarray + elif n_ellipsis == 0: + # Note boolean masks must be the sole index, which we check for + # later on. + if not key_has_mask and n_single_axes < self.ndim: raise IndexError( - "Multidimensional arrays must include an index for every axis or use an ellipsis" + f"{self.ndim=}, but the multi-axes index only specifies " + f"{n_single_axes} dimensions. If this was intentional, " + "add a trailing ellipsis (...) which expands into as many " + "slices (:) as necessary - this is what np.ndarray arrays " + "implicitly do, but such flat indexing behaviour is not " + "specified in the Array API." ) - return key - elif isinstance(key, bool): - return key - elif isinstance(key, Array): - if key.dtype in _integer_dtypes: - if key.ndim != 0: + + if n_ellipsis == 0: + indexed_shape = self.shape + else: + ellipsis_start = None + for pos, i in enumerate(nonexpanding_key): + if not (isinstance(i, Array) or isinstance(i, np.ndarray)): + if i == Ellipsis: + ellipsis_start = pos + break + assert ellipsis_start is not None # sanity check + ellipsis_end = self.ndim - (n_single_axes - ellipsis_start) + indexed_shape = ( + self.shape[:ellipsis_start] + self.shape[ellipsis_end:] + ) + for i, side in zip(single_axes, indexed_shape): + if isinstance(i, slice): + if side == 0: + f_range = "0 (or None)" + else: + f_range = f"between -{side} and {side - 1} (or None)" + if i.start is not None: + try: + start = operator.index(i.start) + except TypeError: + pass # handled by ndarray + else: + if not (-side <= start <= side): + raise IndexError( + f"Slice {i} contains {start=}, but should be " + f"{f_range} for an axis of size {side} " + "(out-of-bounds starts are not specified in " + "the Array API)" + ) + if i.stop is not None: + try: + stop = operator.index(i.stop) + except TypeError: + pass # handled by ndarray + else: + if not (-side <= stop <= side): + raise IndexError( + f"Slice {i} contains {stop=}, but should be " + f"{f_range} for an axis of size {side} " + "(out-of-bounds stops are not specified in " + "the Array API)" + ) + elif isinstance(i, Array): + if i.dtype in _boolean_dtypes and len(_key) != 1: + assert isinstance(key, tuple) # sanity check raise IndexError( - "Non-zero dimensional integer array indices are not allowed in the array API namespace" + f"Single-axes index {i} is a boolean array and " + f"{len(key)=}, but masking is only specified in the " + "Array API when the array is the sole index." ) - return key._array - elif key is Ellipsis: - return key - elif key is None: - raise IndexError( - "newaxis indices are not allowed in the array API namespace" - ) - try: - key = operator.index(key) - if shape is not None and len(shape) > 1: + elif i.dtype in _integer_dtypes and i.ndim != 0: + raise IndexError( + f"Single-axes index {i} is a non-zero-dimensional " + "integer array, but advanced integer indexing is not " + "specified in the Array API." + ) + elif isinstance(i, tuple): raise IndexError( - "Multidimensional arrays must include an index for every axis or use an ellipsis" + f"Single-axes index {i} is a tuple, but nested tuple " + "indices are not specified in the Array API." ) - return key - except TypeError: - # Note: This also omits boolean arrays that are not already in - # Array() form, like a list of booleans. - raise IndexError( - "Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace" - ) # Everything below this line is required by the spec. @@ -511,7 +531,10 @@ class Array: """ # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index - key = self._validate_index(key, self.shape) + self._validate_index(key) + if isinstance(key, Array): + # Indexing self._array with array_api arrays can be erroneous + key = key._array res = self._array.__getitem__(key) return self._new(res) @@ -698,7 +721,10 @@ class Array: """ # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index - key = self._validate_index(key, self.shape) + self._validate_index(key) + if isinstance(key, Array): + # Indexing self._array with array_api arrays can be erroneous + key = key._array self._array.__setitem__(key, asarray(value)._array) def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index 1fe1dfddf..ba9223532 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -2,8 +2,9 @@ import operator from numpy.testing import assert_raises import numpy as np +import pytest -from .. import ones, asarray, result_type, all, equal +from .. import ones, asarray, reshape, result_type, all, equal from .._array_object import Array from .._dtypes import ( _all_dtypes, @@ -17,6 +18,7 @@ from .._dtypes import ( int32, int64, uint64, + bool as bool_, ) @@ -70,11 +72,6 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[[0, 1]]) assert_raises(IndexError, lambda: a[np.array([[0, 1]])]) - # np.newaxis is not allowed - assert_raises(IndexError, lambda: a[None]) - assert_raises(IndexError, lambda: a[None, ...]) - assert_raises(IndexError, lambda: a[..., None]) - # Multiaxis indices must contain exactly as many indices as dimensions assert_raises(IndexError, lambda: a[()]) assert_raises(IndexError, lambda: a[0,]) @@ -322,3 +319,57 @@ def test___array__(): b = np.asarray(a, dtype=np.float64) assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64))) assert b.dtype == np.float64 + +def test_allow_newaxis(): + a = ones(5) + indexed_a = a[None, :] + assert indexed_a.shape == (1, 5) + +def test_disallow_flat_indexing_with_newaxis(): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[None, 0, 0] + +def test_disallow_mask_with_newaxis(): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[None, asarray(True)] + +@pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)]) +@pytest.mark.parametrize("index", ["string", False, True]) +def test_error_on_invalid_index(shape, index): + a = ones(shape) + with pytest.raises(IndexError): + a[index] + +def test_mask_0d_array_without_errors(): + a = ones(()) + a[asarray(True)] + +@pytest.mark.parametrize( + "i", [slice(5), slice(5, 0), asarray(True), asarray([0, 1])] +) +def test_error_on_invalid_index_with_ellipsis(i): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[..., i] + with pytest.raises(IndexError): + a[i, ...] + +def test_array_keys_use_private_array(): + """ + Indexing operations convert array keys before indexing the internal array + + Fails when array_api array keys are not converted into NumPy-proper arrays + in __getitem__(). This is achieved by passing array_api arrays with 0-sized + dimensions, which NumPy-proper treats erroneously - not sure why! + + TODO: Find and use appropiate __setitem__() case. + """ + a = ones((0, 0), dtype=bool_) + assert a[a].shape == (0,) + + a = ones((0,), dtype=bool_) + key = ones((0, 0), dtype=bool_) + with pytest.raises(IndexError): + a[key] -- cgit v1.2.1 From 12f83eb7337b840e3cd9026779b99b1af8033bf3 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 3 Jun 2022 15:21:19 -0600 Subject: Fix the array API unique_*() functions to not compare nans as equal The spec requires this, but it is only now possible to implement with the new equal_nan flag in np.unique(). --- numpy/array_api/_set_functions.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py index db9370f84..0b4132cf8 100644 --- a/numpy/array_api/_set_functions.py +++ b/numpy/array_api/_set_functions.py @@ -46,6 +46,7 @@ def unique_all(x: Array, /) -> UniqueAllResult: return_counts=True, return_index=True, return_inverse=True, + equal_nan=False, ) # np.unique() flattens inverse indices, but they need to share x's shape # See https://github.com/numpy/numpy/issues/20638 @@ -64,6 +65,7 @@ def unique_counts(x: Array, /) -> UniqueCountsResult: return_counts=True, return_index=False, return_inverse=False, + equal_nan=False, ) return UniqueCountsResult(*[Array._new(i) for i in res]) @@ -80,6 +82,7 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult: return_counts=False, return_index=False, return_inverse=True, + equal_nan=False, ) # np.unique() flattens inverse indices, but they need to share x's shape # See https://github.com/numpy/numpy/issues/20638 @@ -98,5 +101,6 @@ def unique_values(x: Array, /) -> Array: return_counts=False, return_index=False, return_inverse=False, + equal_nan=False, ) return Array._new(res) -- cgit v1.2.1 From 70026c4dde47d89d6a7a4916bfac045e714a5b4f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 15 Jun 2022 18:44:15 -0600 Subject: MAINT: Fix computation of numpy.array_api.linalg.vector_norm (#21084) * Fix computation of numpy.array_api.linalg.vector_norm Various pieces were incorrect due to a lack of complete coverage of this function in the array API test suite. * Fix the output dtype nonstandard vector norm() Previously it would always give float64 because an internal calculation involved a NumPy scalar and a Python float. The fix is to use a 0-D array instead of a NumPy scalar so that it type promotes with the float correctly. Fixes #21083 I don't have a test for this yet because I'm unclear how exactly to test it. * Clean up the numpy.array_api.linalg.vector_norm code a little bit --- numpy/array_api/linalg.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index f422e1c27..a4a2f23e4 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -1,8 +1,11 @@ from __future__ import annotations from ._dtypes import _floating_dtypes, _numeric_dtypes +from ._manipulation_functions import reshape from ._array_object import Array +from ..core.numeric import normalize_axis_tuple + from typing import TYPE_CHECKING if TYPE_CHECKING: from ._typing import Literal, Optional, Sequence, Tuple, Union @@ -395,18 +398,38 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in norm') + # np.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or + # when axis=None and the input is 2-D, so to force a vector norm, we make + # it so the input is 1-D (for axis=None), or reshape so that norm is done + # on a single dimension. a = x._array if axis is None: - a = a.flatten() - axis = 0 + # Note: np.linalg.norm() doesn't handle 0-D arrays + a = a.ravel() + _axis = 0 elif isinstance(axis, tuple): - # Note: The axis argument supports any number of axes, whereas norm() - # only supports a single axis for vector norm. - rest = tuple(i for i in range(a.ndim) if i not in axis) + # Note: The axis argument supports any number of axes, whereas + # np.linalg.norm() only supports a single axis for vector norm. + normalized_axis = normalize_axis_tuple(axis, x.ndim) + rest = tuple(i for i in range(a.ndim) if i not in normalized_axis) newshape = axis + rest - a = np.transpose(a, newshape).reshape((np.prod([a.shape[i] for i in axis]), *[a.shape[i] for i in rest])) - axis = 0 - return Array._new(np.linalg.norm(a, axis=axis, keepdims=keepdims, ord=ord)) + a = np.transpose(a, newshape).reshape( + (np.prod([a.shape[i] for i in axis], dtype=int), *[a.shape[i] for i in rest])) + _axis = 0 + else: + _axis = axis + + res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord)) + + if keepdims: + # We can't reuse np.linalg.norm(keepdims) because of the reshape hacks + # above to avoid matrix norm logic. + shape = list(x.shape) + _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) + for i in _axis: + shape[i] = 1 + res = reshape(res, tuple(shape)) + return res __all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm'] -- cgit v1.2.1 From 81b97607339ac68b27cf72ba7923345d58e2895e Mon Sep 17 00:00:00 2001 From: "M. Eric Irrgang" Date: Sat, 16 Jul 2022 15:27:38 -0500 Subject: Add unit testing. --- numpy/array_api/tests/test_asarray.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 numpy/array_api/tests/test_asarray.py (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_asarray.py b/numpy/array_api/tests/test_asarray.py new file mode 100644 index 000000000..d9bbc675b --- /dev/null +++ b/numpy/array_api/tests/test_asarray.py @@ -0,0 +1,24 @@ +import numpy as np + + +def test_fast_return(): + """""" + a = np.array([1, 2, 3], dtype='i') + assert np.asarray(a) is a + assert np.asarray(a, dtype='i') is a + # This may produce a new view or a copy, but is never the same object. + assert np.asarray(a, dtype='l') is not a + + unequal_type = np.dtype('i', metadata={'spam': True}) + b = np.asarray(a, dtype=unequal_type) + assert b is not a + assert b.base is a + + equivalent_requirement = np.dtype('i', metadata={'spam': True}) + c = np.asarray(b, dtype=equivalent_requirement) + # A quirk of the metadata test is that equivalent metadata dicts are still + # separate objects and so don't evaluate as the same array type description. + assert unequal_type == equivalent_requirement + assert unequal_type is not equivalent_requirement + assert c is not b + assert c.dtype is equivalent_requirement -- cgit v1.2.1 From 5651445944bce163a2c3f746d6ac1acd9ae76032 Mon Sep 17 00:00:00 2001 From: "M. Eric Irrgang" Date: Sat, 16 Jul 2022 15:51:51 -0500 Subject: Update comment and obey formatting requirements. --- numpy/array_api/tests/test_asarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_asarray.py b/numpy/array_api/tests/test_asarray.py index d9bbc675b..bd1859d71 100644 --- a/numpy/array_api/tests/test_asarray.py +++ b/numpy/array_api/tests/test_asarray.py @@ -16,8 +16,8 @@ def test_fast_return(): equivalent_requirement = np.dtype('i', metadata={'spam': True}) c = np.asarray(b, dtype=equivalent_requirement) - # A quirk of the metadata test is that equivalent metadata dicts are still - # separate objects and so don't evaluate as the same array type description. + # The descriptors are equivalent, but we have created + # distinct dtype instances. assert unequal_type == equivalent_requirement assert unequal_type is not equivalent_requirement assert c is not b -- cgit v1.2.1 From e286f461b54c43e2a16b3f6fc6d829936ea28c27 Mon Sep 17 00:00:00 2001 From: "M. Eric Irrgang" Date: Sun, 17 Jul 2022 10:49:56 -0500 Subject: Expand test_asarray.py. * Improve comments/docs. * Improve descriptiveness of variable names. * Add additional test expressions that would not pass without this patch. --- numpy/array_api/tests/test_asarray.py | 46 +++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 13 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_asarray.py b/numpy/array_api/tests/test_asarray.py index bd1859d71..fdcc77506 100644 --- a/numpy/array_api/tests/test_asarray.py +++ b/numpy/array_api/tests/test_asarray.py @@ -1,24 +1,44 @@ import numpy as np -def test_fast_return(): - """""" - a = np.array([1, 2, 3], dtype='i') - assert np.asarray(a) is a - assert np.asarray(a, dtype='i') is a - # This may produce a new view or a copy, but is never the same object. - assert np.asarray(a, dtype='l') is not a +def test_dtype_identity(): + """Confirm the intended behavior for ``asarray`` results. + The result of ``asarray()`` should have the dtype provided through the + keyword argument, when used. This forces unique array handles to be + produced for unique np.dtype objects, but (for equivalent dtypes), the + underlying data (the base object) is shared with the original array object. + + Ref https://github.com/numpy/numpy/issues/1468 + """ + int_array = np.array([1, 2, 3], dtype='i') + assert np.asarray(int_array) is int_array + + # The character code resolves to the singleton dtype object provided + # by the numpy package. + assert np.asarray(int_array, dtype='i') is int_array + + # Derive a dtype from n.dtype('i'), but add a metadata object to force + # the dtype to be distinct. unequal_type = np.dtype('i', metadata={'spam': True}) - b = np.asarray(a, dtype=unequal_type) - assert b is not a - assert b.base is a + annotated_int_array = np.asarray(int_array, dtype=unequal_type) + assert annotated_int_array is not int_array + assert annotated_int_array.base is int_array + + # These ``asarray()`` calls may produce a new view or a copy, + # but never the same object. + long_int_array = np.asarray(int_array, dtype='l') + assert long_int_array is not int_array + assert np.asarray(int_array, dtype='q') is not int_array + assert np.asarray(long_int_array, dtype='q') is not long_int_array + assert np.asarray(int_array, dtype='l') is not np.asarray(int_array, dtype='l') + assert np.asarray(int_array, dtype='l').base is np.asarray(int_array, dtype='l').base equivalent_requirement = np.dtype('i', metadata={'spam': True}) - c = np.asarray(b, dtype=equivalent_requirement) + annotated_int_array_alt = np.asarray(annotated_int_array, dtype=equivalent_requirement) # The descriptors are equivalent, but we have created # distinct dtype instances. assert unequal_type == equivalent_requirement assert unequal_type is not equivalent_requirement - assert c is not b - assert c.dtype is equivalent_requirement + assert annotated_int_array_alt is not annotated_int_array + assert annotated_int_array_alt.dtype is equivalent_requirement -- cgit v1.2.1 From 01438a848b029b4fb3d3509c7fd313bc0588bd38 Mon Sep 17 00:00:00 2001 From: "M. Eric Irrgang" Date: Sun, 17 Jul 2022 10:55:59 -0500 Subject: Lint. Shorten some lines. --- numpy/array_api/tests/test_asarray.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_asarray.py b/numpy/array_api/tests/test_asarray.py index fdcc77506..4a9dd77a0 100644 --- a/numpy/array_api/tests/test_asarray.py +++ b/numpy/array_api/tests/test_asarray.py @@ -31,11 +31,12 @@ def test_dtype_identity(): assert long_int_array is not int_array assert np.asarray(int_array, dtype='q') is not int_array assert np.asarray(long_int_array, dtype='q') is not long_int_array - assert np.asarray(int_array, dtype='l') is not np.asarray(int_array, dtype='l') - assert np.asarray(int_array, dtype='l').base is np.asarray(int_array, dtype='l').base + assert long_int_array is not np.asarray(int_array, dtype='l') + assert long_int_array.base is np.asarray(int_array, dtype='l').base equivalent_requirement = np.dtype('i', metadata={'spam': True}) - annotated_int_array_alt = np.asarray(annotated_int_array, dtype=equivalent_requirement) + annotated_int_array_alt = np.asarray(annotated_int_array, + dtype=equivalent_requirement) # The descriptors are equivalent, but we have created # distinct dtype instances. assert unequal_type == equivalent_requirement -- cgit v1.2.1 From b1a8ff8fa73b744416e12cdd4bb70594717b5336 Mon Sep 17 00:00:00 2001 From: "M. Eric Irrgang" Date: Sun, 17 Jul 2022 13:12:31 -0500 Subject: Add release note and further clarify tests. --- numpy/array_api/tests/test_asarray.py | 43 +++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 12 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_asarray.py b/numpy/array_api/tests/test_asarray.py index 4a9dd77a0..5c269823f 100644 --- a/numpy/array_api/tests/test_asarray.py +++ b/numpy/array_api/tests/test_asarray.py @@ -1,3 +1,5 @@ +import itertools + import numpy as np @@ -24,22 +26,39 @@ def test_dtype_identity(): annotated_int_array = np.asarray(int_array, dtype=unequal_type) assert annotated_int_array is not int_array assert annotated_int_array.base is int_array - - # These ``asarray()`` calls may produce a new view or a copy, - # but never the same object. - long_int_array = np.asarray(int_array, dtype='l') - assert long_int_array is not int_array - assert np.asarray(int_array, dtype='q') is not int_array - assert np.asarray(long_int_array, dtype='q') is not long_int_array - assert long_int_array is not np.asarray(int_array, dtype='l') - assert long_int_array.base is np.asarray(int_array, dtype='l').base - + # Create an equivalent descriptor with a new and distinct dtype instance. equivalent_requirement = np.dtype('i', metadata={'spam': True}) annotated_int_array_alt = np.asarray(annotated_int_array, dtype=equivalent_requirement) - # The descriptors are equivalent, but we have created - # distinct dtype instances. assert unequal_type == equivalent_requirement assert unequal_type is not equivalent_requirement assert annotated_int_array_alt is not annotated_int_array assert annotated_int_array_alt.dtype is equivalent_requirement + + # Check the same logic for a pair of C types whose equivalence may vary + # between computing environments. + # Find an equivalent pair. + integer_type_codes = ('i', 'l', 'q') + integer_dtypes = [np.dtype(code) for code in integer_type_codes] + typeA = None + typeB = None + for typeA, typeB in itertools.permutations(integer_dtypes, r=2): + if typeA == typeB: + assert typeA is not typeB + break + assert isinstance(typeA, np.dtype) and isinstance(typeB, np.dtype) + + # These ``asarray()`` calls may produce a new view or a copy, + # but never the same object. + long_int_array = np.asarray(int_array, dtype='l') + long_long_int_array = np.asarray(int_array, dtype='q') + assert long_int_array is not int_array + assert long_long_int_array is not int_array + assert np.asarray(long_int_array, dtype='q') is not long_int_array + array_a = np.asarray(int_array, dtype=typeA) + assert typeA == typeB + assert typeA is not typeB + assert array_a.dtype is typeA + assert array_a is not np.asarray(array_a, dtype=typeB) + assert np.asarray(array_a, dtype=typeB).dtype is typeB + assert array_a is np.asarray(array_a, dtype=typeB).base -- cgit v1.2.1 From 7c8e1134ae86f8a8c002e1068337bec63ddf7f0d Mon Sep 17 00:00:00 2001 From: Ikko Ashimine Date: Tue, 2 Aug 2022 19:26:42 +0900 Subject: MAINT: fix typo in test_array_object test description (#22071) --- numpy/array_api/tests/test_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index ba9223532..f6efacefa 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -364,7 +364,7 @@ def test_array_keys_use_private_array(): in __getitem__(). This is achieved by passing array_api arrays with 0-sized dimensions, which NumPy-proper treats erroneously - not sure why! - TODO: Find and use appropiate __setitem__() case. + TODO: Find and use appropriate __setitem__() case. """ a = ones((0, 0), dtype=bool_) assert a[a].shape == (0,) -- cgit v1.2.1 From 0e960b985843ff99db06f89eadaa9f387b5a65f8 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 7 Sep 2022 11:13:22 -0500 Subject: BUG: Fix the implementation of numpy.array_api.vecdot (#21928) * Fix the implementation of numpy.array_api.vecdot See https://data-apis.org/array-api/latest/API_specification/generated/signatures.linear_algebra_functions.vecdot.html * Use moveaxis + matmul instead of einsum in vecdot --- numpy/array_api/linalg.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index a4a2f23e4..d214046ef 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -379,7 +379,18 @@ def trace(x: Array, /, *, offset: int = 0) -> Array: def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in vecdot') - return tensordot(x1, x2, axes=((axis,), (axis,))) + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + x1_, x2_ = np.broadcast_arrays(x1._array, x2._array) + x1_ = np.moveaxis(x1_, axis, -1) + x2_ = np.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return Array._new(res[..., 0, 0]) # Note: the name here is different from norm(). The array API norm is split -- cgit v1.2.1 From 4d4fca0cef2e5594a50c1d10a9db6f21aaba0a58 Mon Sep 17 00:00:00 2001 From: "M. Eric Irrgang" Date: Tue, 13 Sep 2022 09:20:45 +0300 Subject: TST: Move new `asarray` test to a more appropriate place. (#22251) As noted at #21995 (comment), the new test from #21995 was placed in a directory intended for the Array API, and unrelated to the change. * Consolidate test_dtype_identity into an existing test file. Remove `test_asarray.py`. Create a new `TestAsArray` suite in `test_array_coercion.py` * Linting. Wrap some comments that got too long after function became a method (with additional indentation). --- numpy/array_api/tests/test_asarray.py | 64 ----------------------------------- 1 file changed, 64 deletions(-) delete mode 100644 numpy/array_api/tests/test_asarray.py (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_asarray.py b/numpy/array_api/tests/test_asarray.py deleted file mode 100644 index 5c269823f..000000000 --- a/numpy/array_api/tests/test_asarray.py +++ /dev/null @@ -1,64 +0,0 @@ -import itertools - -import numpy as np - - -def test_dtype_identity(): - """Confirm the intended behavior for ``asarray`` results. - - The result of ``asarray()`` should have the dtype provided through the - keyword argument, when used. This forces unique array handles to be - produced for unique np.dtype objects, but (for equivalent dtypes), the - underlying data (the base object) is shared with the original array object. - - Ref https://github.com/numpy/numpy/issues/1468 - """ - int_array = np.array([1, 2, 3], dtype='i') - assert np.asarray(int_array) is int_array - - # The character code resolves to the singleton dtype object provided - # by the numpy package. - assert np.asarray(int_array, dtype='i') is int_array - - # Derive a dtype from n.dtype('i'), but add a metadata object to force - # the dtype to be distinct. - unequal_type = np.dtype('i', metadata={'spam': True}) - annotated_int_array = np.asarray(int_array, dtype=unequal_type) - assert annotated_int_array is not int_array - assert annotated_int_array.base is int_array - # Create an equivalent descriptor with a new and distinct dtype instance. - equivalent_requirement = np.dtype('i', metadata={'spam': True}) - annotated_int_array_alt = np.asarray(annotated_int_array, - dtype=equivalent_requirement) - assert unequal_type == equivalent_requirement - assert unequal_type is not equivalent_requirement - assert annotated_int_array_alt is not annotated_int_array - assert annotated_int_array_alt.dtype is equivalent_requirement - - # Check the same logic for a pair of C types whose equivalence may vary - # between computing environments. - # Find an equivalent pair. - integer_type_codes = ('i', 'l', 'q') - integer_dtypes = [np.dtype(code) for code in integer_type_codes] - typeA = None - typeB = None - for typeA, typeB in itertools.permutations(integer_dtypes, r=2): - if typeA == typeB: - assert typeA is not typeB - break - assert isinstance(typeA, np.dtype) and isinstance(typeB, np.dtype) - - # These ``asarray()`` calls may produce a new view or a copy, - # but never the same object. - long_int_array = np.asarray(int_array, dtype='l') - long_long_int_array = np.asarray(int_array, dtype='q') - assert long_int_array is not int_array - assert long_long_int_array is not int_array - assert np.asarray(long_int_array, dtype='q') is not long_int_array - array_a = np.asarray(int_array, dtype=typeA) - assert typeA == typeB - assert typeA is not typeB - assert array_a.dtype is typeA - assert array_a is not np.asarray(array_a, dtype=typeB) - assert np.asarray(array_a, dtype=typeB).dtype is typeB - assert array_a is np.asarray(array_a, dtype=typeB).base -- cgit v1.2.1 From 6579448e6c76b2ad4487903688ac9fe3c02ff39c Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Sun, 2 Oct 2022 12:35:59 +0100 Subject: Add `__array_api_version__` to `numpy.array_api` namespace --- numpy/array_api/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index bbe2fdce2..5e58ee0a8 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -121,7 +121,9 @@ warnings.warn( "The numpy.array_api submodule is still experimental. See NEP 47.", stacklevel=2 ) -__all__ = [] +__array_api_version__ = "2021.12" + +__all__ = ["__array_api_version__"] from ._constants import e, inf, nan, pi -- cgit v1.2.1 From de0521fc22e641be5e819a2fec785c6f89ebca8c Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Fri, 25 Feb 2022 18:05:39 +0100 Subject: MAINT: Let `ndarray.__imatmul__` handle inplace matrix multiplication in the array-api --- numpy/array_api/_array_object.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index c4746fad9..592ca09df 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -850,23 +850,13 @@ class Array: """ Performs the operation __imatmul__. """ - # Note: NumPy does not implement __imatmul__. - # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") if other is NotImplemented: return other - - # __imatmul__ can only be allowed when it would not change the shape - # of self. - other_shape = other.shape - if self.shape == () or other_shape == (): - raise ValueError("@= requires at least one dimension") - if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]: - raise ValueError("@= cannot change the shape of the input array") - self._array[:] = self._array.__matmul__(other._array) - return self + res = self._array.__imatmul__(other._array) + return self.__class__._new(res) def __rmatmul__(self: Array, other: Array, /) -> Array: """ -- cgit v1.2.1 From a6740500576475a43a7121087e9f5f96d48ef1d2 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Dec 2022 17:48:32 -0700 Subject: DOC: Some updates to the array_api compat document (#22747) * Add reshape differences to the array API compat document * Add an item to the array API compat document about reverse broadcasting * Make some wording easier to read --- numpy/array_api/_manipulation_functions.py | 1 + 1 file changed, 1 insertion(+) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py index 4f2114ff5..7991f46a2 100644 --- a/numpy/array_api/_manipulation_functions.py +++ b/numpy/array_api/_manipulation_functions.py @@ -52,6 +52,7 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: return Array._new(np.transpose(x._array, axes)) +# Note: the optional argument is called 'shape', not 'newshape' def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array: """ Array API compatible wrapper for :py:func:`np.reshape `. -- cgit v1.2.1 From 8816c76631af79c46a8cf33ac1a8a79b2717c9ac Mon Sep 17 00:00:00 2001 From: Francesc Elies Date: Tue, 21 Feb 2023 14:50:04 +0100 Subject: TYP,MAINT: Add a missing explicit Any parameter --- numpy/array_api/_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index c4746fad9..eee117be6 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -56,7 +56,7 @@ class Array: functions, such as asarray(). """ - _array: np.ndarray + _array: np.ndarray[Any, Any] # Use a custom constructor instead of __init__, as manually initializing # this class is not supported API. -- cgit v1.2.1 From f07d55b27671a4575e3b9b2fc7ca9ec897d4db9e Mon Sep 17 00:00:00 2001 From: Alex Rogozhnikov Date: Sun, 26 Feb 2023 07:56:47 +0000 Subject: add support for xp.take --- numpy/array_api/__init__.py | 4 ++++ numpy/array_api/_indexing_functions.py | 18 ++++++++++++++++++ numpy/array_api/tests/test_indexing_functions.py | 24 ++++++++++++++++++++++++ 3 files changed, 46 insertions(+) create mode 100644 numpy/array_api/_indexing_functions.py create mode 100644 numpy/array_api/tests/test_indexing_functions.py (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 5e58ee0a8..e154b9952 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -333,6 +333,10 @@ __all__ += [ "trunc", ] +from ._indexing_functions import take + +__all__ += ["take"] + # linalg is an extension in the array API spec, which is a sub-namespace. Only # a subset of functions in it are imported into the top-level namespace. from . import linalg diff --git a/numpy/array_api/_indexing_functions.py b/numpy/array_api/_indexing_functions.py new file mode 100644 index 000000000..ba56bcd6f --- /dev/null +++ b/numpy/array_api/_indexing_functions.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from ._array_object import Array +from ._dtypes import _integer_dtypes + +import numpy as np + +def take(x: Array, indices: Array, /, *, axis: int) -> Array: + """ + Array API compatible wrapper for :py:func:`np.take `. + + See its docstring for more information. + """ + if indices.dtype not in _integer_dtypes: + raise TypeError("Only integer dtypes are allowed in indexing") + if indices.ndim != 1: + raise ValueError("Only 1-dim indices array is supported") + return Array._new(np.take(x._array, indices._array, axis=axis)) diff --git a/numpy/array_api/tests/test_indexing_functions.py b/numpy/array_api/tests/test_indexing_functions.py new file mode 100644 index 000000000..26667e32f --- /dev/null +++ b/numpy/array_api/tests/test_indexing_functions.py @@ -0,0 +1,24 @@ +import pytest + +from numpy import array_api as xp + + +@pytest.mark.parametrize( + "x, indices, axis, expected", + [ + ([2, 3], [1, 1, 0], 0, [3, 3, 2]), + ([2, 3], [1, 1, 0], -1, [3, 3, 2]), + ([[2, 3]], [1], -1, [[3]]), + ([[2, 3]], [0, 0], 0, [[2, 3], [2, 3]]), + ], +) +def test_stable_desc_argsort(x, indices, axis, expected): + """ + Indices respect relative order of a descending stable-sort + + See https://github.com/numpy/numpy/issues/20778 + """ + x = xp.asarray(x) + indices = xp.asarray(indices) + out = xp.take(x, indices, axis=axis) + assert xp.all(out == xp.asarray(expected)) -- cgit v1.2.1 From 786bd366b22675b1e4067653d4729031c258f35f Mon Sep 17 00:00:00 2001 From: Alex Rogozhnikov Date: Sun, 26 Feb 2023 08:19:45 +0000 Subject: rename test function --- numpy/array_api/tests/test_indexing_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_indexing_functions.py b/numpy/array_api/tests/test_indexing_functions.py index 26667e32f..9e05c6386 100644 --- a/numpy/array_api/tests/test_indexing_functions.py +++ b/numpy/array_api/tests/test_indexing_functions.py @@ -12,7 +12,7 @@ from numpy import array_api as xp ([[2, 3]], [0, 0], 0, [[2, 3], [2, 3]]), ], ) -def test_stable_desc_argsort(x, indices, axis, expected): +def test_take_function(x, indices, axis, expected): """ Indices respect relative order of a descending stable-sort -- cgit v1.2.1 From e570c6f5ff6f5aa966193d51560d3cde30fc09bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Tue, 14 Mar 2023 11:24:19 +0100 Subject: MAINT: cleanup unused Python3.8-only code and references --- numpy/array_api/_typing.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index dfa87b358..3f9b7186a 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -17,14 +17,12 @@ __all__ = [ "PyCapsule", ] -import sys from typing import ( Any, Literal, Sequence, Type, Union, - TYPE_CHECKING, TypeVar, Protocol, ) @@ -51,21 +49,20 @@ class NestedSequence(Protocol[_T_co]): def __len__(self, /) -> int: ... Device = Literal["cpu"] -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] -else: - Dtype = dtype + +Dtype = dtype[Union[ + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, +]] + SupportsBufferProtocol = Any PyCapsule = Any -- cgit v1.2.1