summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2021-10-22 15:37:34 -0500
committerGitHub <noreply@github.com>2021-10-22 15:37:34 -0500
commit203936c1e0950637fc2852b7f70852c49bcea291 (patch)
tree6aa5277e3fee7f75cd61a973b34981c831acaacd /numpy/array_api
parent31bb98bab3c90e781d94473dde0c0c952a3dc213 (diff)
parentd74bea12d19dd92c9cf07cac35e94d45fb331832 (diff)
downloadnumpy-203936c1e0950637fc2852b7f70852c49bcea291.tar.gz
Merge pull request #20129 from BvB93/api
ENH: Misc typing improvements to `np.array_api`
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/_array_object.py4
-rw-r--r--numpy/array_api/_typing.py26
2 files changed, 22 insertions, 8 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index 830319e8c..ef66c5efd 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -29,7 +29,7 @@ from ._dtypes import (
_dtype_categories,
)
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union, Any
if TYPE_CHECKING:
from ._typing import PyCapsule, Device, Dtype
@@ -382,7 +382,7 @@ class Array:
def __array_namespace__(
self: Array, /, *, api_version: Optional[str] = None
- ) -> object:
+ ) -> Any:
if api_version is not None and not api_version.startswith("2021."):
raise ValueError(f"Unrecognized array API version: {api_version!r}")
return array_api
diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py
index 519e8463c..dfa87b358 100644
--- a/numpy/array_api/_typing.py
+++ b/numpy/array_api/_typing.py
@@ -6,6 +6,8 @@ annotations in the function signatures. The functions in the module are only
valid for inputs that match the given type annotations.
"""
+from __future__ import annotations
+
__all__ = [
"Array",
"Device",
@@ -16,7 +18,16 @@ __all__ = [
]
import sys
-from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar
+from typing import (
+ Any,
+ Literal,
+ Sequence,
+ Type,
+ Union,
+ TYPE_CHECKING,
+ TypeVar,
+ Protocol,
+)
from ._array_object import Array
from numpy import (
@@ -33,10 +44,11 @@ from numpy import (
float64,
)
-# This should really be recursive, but that isn't supported yet. See the
-# similar comment in numpy/typing/_array_like.py
-_T = TypeVar("_T")
-NestedSequence = Sequence[Sequence[_T]]
+_T_co = TypeVar("_T_co", covariant=True)
+
+class NestedSequence(Protocol[_T_co]):
+ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
+ def __len__(self, /) -> int: ...
Device = Literal["cpu"]
if TYPE_CHECKING or sys.version_info >= (3, 9):
@@ -55,6 +67,8 @@ if TYPE_CHECKING or sys.version_info >= (3, 9):
else:
Dtype = dtype
-SupportsDLPack = Any
SupportsBufferProtocol = Any
PyCapsule = Any
+
+class SupportsDLPack(Protocol):
+ def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ...