diff options
-rw-r--r-- | numpy/__init__.pyi | 16 | ||||
-rw-r--r-- | numpy/typing/_array_like.py | 21 |
2 files changed, 26 insertions, 11 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index 83afd2e49..ac21c9907 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -971,8 +971,11 @@ class flatiter(Generic[_ArraySelf]): @overload def __getitem__( self, key: Union[_ArrayLikeInt, slice, ellipsis], - ) -> _ArraySelf: ... - def __array__(self, __dtype: DTypeLike = ...) -> ndarray: ... + ) -> _NdArraySubClass: ... + @overload + def __array__(self: flatiter[ndarray[Any, _DType]], __dtype: None = ...) -> ndarray[Any, _DType]: ... + @overload + def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ... _OrderKACF = Optional[Literal["K", "A", "C", "F"]] _OrderACF = Optional[Literal["A", "C", "F"]] @@ -1004,7 +1007,6 @@ class _ArrayOrScalarCommon: def itemsize(self) -> int: ... @property def nbytes(self) -> int: ... - def __array__(self, __dtype: DTypeLike = ...) -> ndarray: ... def __bool__(self) -> bool: ... def __bytes__(self) -> bytes: ... def __str__(self) -> str: ... @@ -1468,6 +1470,10 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType]): strides: _ShapeLike = ..., order: _OrderKACF = ..., ) -> _ArraySelf: ... + @overload + def __array__(self, __dtype: None = ...) -> ndarray[Any, _DType]: ... + @overload + def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ... @property def ctypes(self) -> _ctypes: ... @property @@ -1646,6 +1652,10 @@ _NBit_co2 = TypeVar("_NBit_co2", covariant=True, bound=NBitBase) class generic(_ArrayOrScalarCommon): @abstractmethod def __init__(self, *args: Any, **kwargs: Any) -> None: ... + @overload + def __array__(self: _ScalarType, __dtype: None = ...) -> ndarray[Any, dtype[_ScalarType]]: ... + @overload + def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ... @property def base(self) -> None: ... @property diff --git a/numpy/typing/_array_like.py b/numpy/typing/_array_like.py index a1a604239..4ea6974b2 100644 --- a/numpy/typing/_array_like.py +++ b/numpy/typing/_array_like.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import sys -from typing import Any, overload, Sequence, TYPE_CHECKING, Union +from typing import Any, overload, Sequence, TYPE_CHECKING, Union, TypeVar -from numpy import ndarray +from numpy import ndarray, dtype from ._scalars import _ScalarLike from ._dtype_like import DTypeLike @@ -16,12 +18,15 @@ else: else: HAVE_PROTOCOL = True +_DType = TypeVar("_DType", bound="dtype[Any]") + if TYPE_CHECKING or HAVE_PROTOCOL: - class _SupportsArray(Protocol): - @overload - def __array__(self, __dtype: DTypeLike = ...) -> ndarray: ... - @overload - def __array__(self, dtype: DTypeLike = ...) -> ndarray: ... + # The `_SupportsArray` protocol only cares about the default dtype + # (i.e. `dtype=None`) ofthe to-be returned array. + # Concrete implementations of the protocol are responsible for adding + # any and all remaining overloads + class _SupportsArray(Protocol[_DType]): + def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ... else: _SupportsArray = Any @@ -36,5 +41,5 @@ ArrayLike = Union[ _ScalarLike, Sequence[_ScalarLike], Sequence[Sequence[Any]], # TODO: Wait for support for recursive types - _SupportsArray, + "_SupportsArray[Any]", ] |