diff options
author | Bas van Beek <b.f.van.beek@vu.nl> | 2021-10-18 18:10:47 +0200 |
---|---|---|
committer | Bas van Beek <b.f.van.beek@vu.nl> | 2021-10-18 18:29:38 +0200 |
commit | f931a434839222bb00282a432d6d6a0c2c52eb7d (patch) | |
tree | 169b83da827b1448c3c4833341256bb0cd903c59 /numpy/array_api | |
parent | f4ef0fd5d97afc7788c46eb51889e90623f5577e (diff) | |
download | numpy-f931a434839222bb00282a432d6d6a0c2c52eb7d.tar.gz |
ENH: Replace `NestedSequence` with a proper nested sequence protocol
Diffstat (limited to 'numpy/array_api')
-rw-r--r-- | numpy/array_api/_typing.py | 22 |
1 files changed, 17 insertions, 5 deletions
diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 519e8463c..5e980b16f 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -6,6 +6,8 @@ annotations in the function signatures. The functions in the module are only valid for inputs that match the given type annotations. """ +from __future__ import annotations + __all__ = [ "Array", "Device", @@ -16,7 +18,16 @@ __all__ = [ ] import sys -from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar +from typing import ( + Any, + Literal, + Sequence, + Type, + Union, + TYPE_CHECKING, + TypeVar, + Protocol, +) from ._array_object import Array from numpy import ( @@ -33,10 +44,11 @@ from numpy import ( float64, ) -# This should really be recursive, but that isn't supported yet. See the -# similar comment in numpy/typing/_array_like.py -_T = TypeVar("_T") -NestedSequence = Sequence[Sequence[_T]] +_T_co = TypeVar("_T_co", covariant=True) + +class NestedSequence(Protocol[_T_co]): + def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... + def __len__(self, /) -> int: ... Device = Literal["cpu"] if TYPE_CHECKING or sys.version_info >= (3, 9): |