summaryrefslogtreecommitdiff
path: root/numpy/array_api/_typing.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api/_typing.py')
-rw-r--r--numpy/array_api/_typing.py74
1 files changed, 74 insertions, 0 deletions
diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py
new file mode 100644
index 000000000..dfa87b358
--- /dev/null
+++ b/numpy/array_api/_typing.py
@@ -0,0 +1,74 @@
+"""
+This file defines the types for type annotations.
+
+These names aren't part of the module namespace, but they are used in the
+annotations in the function signatures. The functions in the module are only
+valid for inputs that match the given type annotations.
+"""
+
+from __future__ import annotations
+
+__all__ = [
+ "Array",
+ "Device",
+ "Dtype",
+ "SupportsDLPack",
+ "SupportsBufferProtocol",
+ "PyCapsule",
+]
+
+import sys
+from typing import (
+ Any,
+ Literal,
+ Sequence,
+ Type,
+ Union,
+ TYPE_CHECKING,
+ TypeVar,
+ Protocol,
+)
+
+from ._array_object import Array
+from numpy import (
+ dtype,
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+)
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+class NestedSequence(Protocol[_T_co]):
+ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
+ def __len__(self, /) -> int: ...
+
+Device = Literal["cpu"]
+if TYPE_CHECKING or sys.version_info >= (3, 9):
+ Dtype = dtype[Union[
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+ ]]
+else:
+ Dtype = dtype
+
+SupportsBufferProtocol = Any
+PyCapsule = Any
+
+class SupportsDLPack(Protocol):
+ def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ...