From 4d23ebeb068c8d6ba6edfc11d32ab2af8bb89c74 Mon Sep 17 00:00:00 2001 From: Alessia Marcolini <98marcolini@gmail.com> Date: Fri, 8 Oct 2021 09:49:11 +0000 Subject: MAINT: remove unused imports --- numpy/array_api/tests/test_creation_functions.py | 15 --------------- 1 file changed, 15 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index 3cb8865cd..7b633eaf1 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -8,30 +8,15 @@ from .._creation_functions import ( empty, empty_like, eye, - from_dlpack, full, full_like, linspace, - meshgrid, ones, ones_like, zeros, zeros_like, ) from .._array_object import Array -from .._dtypes import ( - _all_dtypes, - _boolean_dtypes, - _floating_dtypes, - _integer_dtypes, - _integer_or_boolean_dtypes, - _numeric_dtypes, - int8, - int16, - int32, - int64, - uint64, -) def test_asarray_errors(): -- cgit v1.2.1 From f931a434839222bb00282a432d6d6a0c2c52eb7d Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Mon, 18 Oct 2021 18:10:47 +0200 Subject: ENH: Replace `NestedSequence` with a proper nested sequence protocol --- numpy/array_api/_typing.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 519e8463c..5e980b16f 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -6,6 +6,8 @@ annotations in the function signatures. The functions in the module are only valid for inputs that match the given type annotations. """ +from __future__ import annotations + __all__ = [ "Array", "Device", @@ -16,7 +18,16 @@ __all__ = [ ] import sys -from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar +from typing import ( + Any, + Literal, + Sequence, + Type, + Union, + TYPE_CHECKING, + TypeVar, + Protocol, +) from ._array_object import Array from numpy import ( @@ -33,10 +44,11 @@ from numpy import ( float64, ) -# This should really be recursive, but that isn't supported yet. See the -# similar comment in numpy/typing/_array_like.py -_T = TypeVar("_T") -NestedSequence = Sequence[Sequence[_T]] +_T_co = TypeVar("_T_co", covariant=True) + +class NestedSequence(Protocol[_T_co]): + def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... + def __len__(self, /) -> int: ... Device = Literal["cpu"] if TYPE_CHECKING or sys.version_info >= (3, 9): -- cgit v1.2.1 From 3952e8f1390629078fdb229236b3b1ce40140c32 Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Mon, 18 Oct 2021 18:11:06 +0200 Subject: ENH: Change `SupportsDLPack` into a protocol --- numpy/array_api/_typing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 5e980b16f..dfa87b358 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -67,6 +67,8 @@ if TYPE_CHECKING or sys.version_info >= (3, 9): else: Dtype = dtype -SupportsDLPack = Any SupportsBufferProtocol = Any PyCapsule = Any + +class SupportsDLPack(Protocol): + def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ... -- cgit v1.2.1 From d74bea12d19dd92c9cf07cac35e94d45fb331832 Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Mon, 18 Oct 2021 18:13:39 +0200 Subject: MAINT: Replace the `__array_namespace__` return type with `Any` Replace `object` as it cannot be used for expressing the objects in the array namespace. --- numpy/array_api/_array_object.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 830319e8c..ef66c5efd 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -29,7 +29,7 @@ from ._dtypes import ( _dtype_categories, ) -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union, Any if TYPE_CHECKING: from ._typing import PyCapsule, Device, Dtype @@ -382,7 +382,7 @@ class Array: def __array_namespace__( self: Array, /, *, api_version: Optional[str] = None - ) -> object: + ) -> Any: if api_version is not None and not api_version.startswith("2021."): raise ValueError(f"Unrecognized array API version: {api_version!r}") return array_api -- cgit v1.2.1