summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2020-12-11 13:22:53 +0100
committerBas van Beek <b.f.van.beek@vu.nl>2020-12-11 13:31:39 +0100
commit1237be83356e20f568f28e54bc0099f5acd3e2db (patch)
tree63455f2c8dce415c4c34385c62dc8778ca1b155a
parent9e26d1d2be7a961a16f8fa9ff7820c33b25415e2 (diff)
downloadnumpy-1237be83356e20f568f28e54bc0099f5acd3e2db.tar.gz
ENH: Add dtype-support for the `__array__` protocol
-rw-r--r--numpy/__init__.pyi16
-rw-r--r--numpy/typing/_array_like.py21
2 files changed, 26 insertions, 11 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 83afd2e49..ac21c9907 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -971,8 +971,11 @@ class flatiter(Generic[_ArraySelf]):
@overload
def __getitem__(
self, key: Union[_ArrayLikeInt, slice, ellipsis],
- ) -> _ArraySelf: ...
- def __array__(self, __dtype: DTypeLike = ...) -> ndarray: ...
+ ) -> _NdArraySubClass: ...
+ @overload
+ def __array__(self: flatiter[ndarray[Any, _DType]], __dtype: None = ...) -> ndarray[Any, _DType]: ...
+ @overload
+ def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ...
_OrderKACF = Optional[Literal["K", "A", "C", "F"]]
_OrderACF = Optional[Literal["A", "C", "F"]]
@@ -1004,7 +1007,6 @@ class _ArrayOrScalarCommon:
def itemsize(self) -> int: ...
@property
def nbytes(self) -> int: ...
- def __array__(self, __dtype: DTypeLike = ...) -> ndarray: ...
def __bool__(self) -> bool: ...
def __bytes__(self) -> bytes: ...
def __str__(self) -> str: ...
@@ -1468,6 +1470,10 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType]):
strides: _ShapeLike = ...,
order: _OrderKACF = ...,
) -> _ArraySelf: ...
+ @overload
+ def __array__(self, __dtype: None = ...) -> ndarray[Any, _DType]: ...
+ @overload
+ def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ...
@property
def ctypes(self) -> _ctypes: ...
@property
@@ -1646,6 +1652,10 @@ _NBit_co2 = TypeVar("_NBit_co2", covariant=True, bound=NBitBase)
class generic(_ArrayOrScalarCommon):
@abstractmethod
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
+ @overload
+ def __array__(self: _ScalarType, __dtype: None = ...) -> ndarray[Any, dtype[_ScalarType]]: ...
+ @overload
+ def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ...
@property
def base(self) -> None: ...
@property
diff --git a/numpy/typing/_array_like.py b/numpy/typing/_array_like.py
index a1a604239..4ea6974b2 100644
--- a/numpy/typing/_array_like.py
+++ b/numpy/typing/_array_like.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
import sys
-from typing import Any, overload, Sequence, TYPE_CHECKING, Union
+from typing import Any, overload, Sequence, TYPE_CHECKING, Union, TypeVar
-from numpy import ndarray
+from numpy import ndarray, dtype
from ._scalars import _ScalarLike
from ._dtype_like import DTypeLike
@@ -16,12 +18,15 @@ else:
else:
HAVE_PROTOCOL = True
+_DType = TypeVar("_DType", bound="dtype[Any]")
+
if TYPE_CHECKING or HAVE_PROTOCOL:
- class _SupportsArray(Protocol):
- @overload
- def __array__(self, __dtype: DTypeLike = ...) -> ndarray: ...
- @overload
- def __array__(self, dtype: DTypeLike = ...) -> ndarray: ...
+ # The `_SupportsArray` protocol only cares about the default dtype
+ # (i.e. `dtype=None`) ofthe to-be returned array.
+ # Concrete implementations of the protocol are responsible for adding
+ # any and all remaining overloads
+ class _SupportsArray(Protocol[_DType]):
+ def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ...
else:
_SupportsArray = Any
@@ -36,5 +41,5 @@ ArrayLike = Union[
_ScalarLike,
Sequence[_ScalarLike],
Sequence[Sequence[Any]], # TODO: Wait for support for recursive types
- _SupportsArray,
+ "_SupportsArray[Any]",
]