summaryrefslogtreecommitdiff
path: root/numpy/array_api/_typing.py
diff options
context:
space:
mode:
authorGagandeep Singh <gdp.1807@gmail.com>2021-11-02 11:28:17 +0530
committerGagandeep Singh <gdp.1807@gmail.com>2021-11-02 11:28:17 +0530
commitc04509e86e97a69a0b5fcbeebdbec66faad3dbe0 (patch)
treeb5940db3ad46e55b88d566ec058007dc3c6b7e72 /numpy/array_api/_typing.py
parent56647dd47345a7fd24b4ee8d9d52025fcdc3b9ae (diff)
parentfae6fa47a3cf9b9c64af2f5bd11a3b644b1763d2 (diff)
downloadnumpy-c04509e86e97a69a0b5fcbeebdbec66faad3dbe0.tar.gz
resolved conflicts
Diffstat (limited to 'numpy/array_api/_typing.py')
-rw-r--r--numpy/array_api/_typing.py52
1 files changed, 41 insertions, 11 deletions
diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py
index d530a91ae..dfa87b358 100644
--- a/numpy/array_api/_typing.py
+++ b/numpy/array_api/_typing.py
@@ -6,6 +6,8 @@ annotations in the function signatures. The functions in the module are only
valid for inputs that match the given type annotations.
"""
+from __future__ import annotations
+
__all__ = [
"Array",
"Device",
@@ -15,10 +17,21 @@ __all__ = [
"PyCapsule",
]
-from typing import Any, Sequence, Type, Union
+import sys
+from typing import (
+ Any,
+ Literal,
+ Sequence,
+ Type,
+ Union,
+ TYPE_CHECKING,
+ TypeVar,
+ Protocol,
+)
-from . import (
- Array,
+from ._array_object import Array
+from numpy import (
+ dtype,
int8,
int16,
int32,
@@ -31,14 +44,31 @@ from . import (
float64,
)
-# This should really be recursive, but that isn't supported yet. See the
-# similar comment in numpy/typing/_array_like.py
-NestedSequence = Sequence[Sequence[Any]]
+_T_co = TypeVar("_T_co", covariant=True)
+
+class NestedSequence(Protocol[_T_co]):
+ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
+ def __len__(self, /) -> int: ...
+
+Device = Literal["cpu"]
+if TYPE_CHECKING or sys.version_info >= (3, 9):
+ Dtype = dtype[Union[
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+ ]]
+else:
+ Dtype = dtype
-Device = Any
-Dtype = Type[
- Union[[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64]]
-]
-SupportsDLPack = Any
SupportsBufferProtocol = Any
PyCapsule = Any
+
+class SupportsDLPack(Protocol):
+ def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ...