summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-05-12 14:48:27 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-05-12 15:35:19 +0200
commit4bd5fd2abf8f681d2af900696eca1e8a99860f83 (patch)
tree7f62f839598dcf2fc32b0ac8fed3e1da4b513f9f
parentae6960de23afc8b33e0dbfc3d5d361ab126d8734 (diff)
downloadnumpy-4bd5fd2abf8f681d2af900696eca1e8a99860f83.tar.gz
ENH: Add annotations for `ndarray.item`
-rw-r--r--numpy/__init__.pyi17
1 files changed, 15 insertions, 2 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 466af9110..d14b6de46 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -1676,6 +1676,7 @@ _NumberType = TypeVar("_NumberType", bound=number[Any])
_BufferType = Union[ndarray, bytes, bytearray, memoryview]
_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
_2Tuple = Tuple[_T, _T]
_Casting = Literal["no", "equiv", "safe", "same_kind", "unsafe"]
@@ -1686,6 +1687,9 @@ _ArrayComplex_co = _ArrayND[Union[bool_, integer[Any], floating[Any], complexflo
_ArrayNumber_co = _ArrayND[Union[bool_, number[Any]]]
_ArrayTD64_co = _ArrayND[Union[bool_, integer[Any], timedelta64]]
+class _SupportsItem(Protocol[_T_co]):
+ def item(self, __args: Any) -> _T_co: ...
+
class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
@property
def base(self) -> Optional[ndarray]: ...
@@ -1728,10 +1732,19 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
def fill(self, value: Any) -> None: ...
@property
def flat(self: _NdArraySubClass) -> flatiter[_NdArraySubClass]: ...
+
+ # Use the same output type as that of the underlying `generic`
@overload
- def item(self, *args: SupportsIndex) -> Any: ...
+ def item(
+ self: ndarray[Any, dtype[_SupportsItem[_T]]], # type: ignore[type-var]
+ *args: SupportsIndex,
+ ) -> _T: ...
@overload
- def item(self, __args: Tuple[SupportsIndex, ...]) -> Any: ...
+ def item(
+ self: ndarray[Any, dtype[_SupportsItem[_T]]], # type: ignore[type-var]
+ __args: Tuple[SupportsIndex, ...],
+ ) -> _T: ...
+
@overload
def itemset(self, __value: Any) -> None: ...
@overload