summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-10-18 18:10:47 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-10-18 18:29:38 +0200
commitf931a434839222bb00282a432d6d6a0c2c52eb7d (patch)
tree169b83da827b1448c3c4833341256bb0cd903c59 /numpy/array_api
parentf4ef0fd5d97afc7788c46eb51889e90623f5577e (diff)
downloadnumpy-f931a434839222bb00282a432d6d6a0c2c52eb7d.tar.gz
ENH: Replace `NestedSequence` with a proper nested sequence protocol
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/_typing.py22
1 files changed, 17 insertions, 5 deletions
diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py
index 519e8463c..5e980b16f 100644
--- a/numpy/array_api/_typing.py
+++ b/numpy/array_api/_typing.py
@@ -6,6 +6,8 @@ annotations in the function signatures. The functions in the module are only
valid for inputs that match the given type annotations.
"""
+from __future__ import annotations
+
__all__ = [
"Array",
"Device",
@@ -16,7 +18,16 @@ __all__ = [
]
import sys
-from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar
+from typing import (
+ Any,
+ Literal,
+ Sequence,
+ Type,
+ Union,
+ TYPE_CHECKING,
+ TypeVar,
+ Protocol,
+)
from ._array_object import Array
from numpy import (
@@ -33,10 +44,11 @@ from numpy import (
float64,
)
-# This should really be recursive, but that isn't supported yet. See the
-# similar comment in numpy/typing/_array_like.py
-_T = TypeVar("_T")
-NestedSequence = Sequence[Sequence[_T]]
+_T_co = TypeVar("_T_co", covariant=True)
+
+class NestedSequence(Protocol[_T_co]):
+ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
+ def __len__(self, /) -> int: ...
Device = Literal["cpu"]
if TYPE_CHECKING or sys.version_info >= (3, 9):