diff options
author | Bas van Beek <b.f.van.beek@vu.nl> | 2021-05-12 14:48:27 +0200 |
---|---|---|
committer | Bas van Beek <b.f.van.beek@vu.nl> | 2021-05-12 15:35:19 +0200 |
commit | 4bd5fd2abf8f681d2af900696eca1e8a99860f83 (patch) | |
tree | 7f62f839598dcf2fc32b0ac8fed3e1da4b513f9f | |
parent | ae6960de23afc8b33e0dbfc3d5d361ab126d8734 (diff) | |
download | numpy-4bd5fd2abf8f681d2af900696eca1e8a99860f83.tar.gz |
ENH: Add annotations for `ndarray.item`
-rw-r--r-- | numpy/__init__.pyi | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index 466af9110..d14b6de46 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -1676,6 +1676,7 @@ _NumberType = TypeVar("_NumberType", bound=number[Any]) _BufferType = Union[ndarray, bytes, bytearray, memoryview] _T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) _2Tuple = Tuple[_T, _T] _Casting = Literal["no", "equiv", "safe", "same_kind", "unsafe"] @@ -1686,6 +1687,9 @@ _ArrayComplex_co = _ArrayND[Union[bool_, integer[Any], floating[Any], complexflo _ArrayNumber_co = _ArrayND[Union[bool_, number[Any]]] _ArrayTD64_co = _ArrayND[Union[bool_, integer[Any], timedelta64]] +class _SupportsItem(Protocol[_T_co]): + def item(self, __args: Any) -> _T_co: ... + class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]): @property def base(self) -> Optional[ndarray]: ... @@ -1728,10 +1732,19 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]): def fill(self, value: Any) -> None: ... @property def flat(self: _NdArraySubClass) -> flatiter[_NdArraySubClass]: ... + + # Use the same output type as that of the underlying `generic` @overload - def item(self, *args: SupportsIndex) -> Any: ... + def item( + self: ndarray[Any, dtype[_SupportsItem[_T]]], # type: ignore[type-var] + *args: SupportsIndex, + ) -> _T: ... @overload - def item(self, __args: Tuple[SupportsIndex, ...]) -> Any: ... + def item( + self: ndarray[Any, dtype[_SupportsItem[_T]]], # type: ignore[type-var] + __args: Tuple[SupportsIndex, ...], + ) -> _T: ... + @overload def itemset(self, __value: Any) -> None: ... @overload |