summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2020-12-11 13:22:46 +0100
committerBas van Beek <b.f.van.beek@vu.nl>2020-12-11 13:59:51 +0100
commit9d36f09cae3b09e61105e9222e0b56c0b1546a0c (patch)
tree5377501699063abb8e653c1add379ed2a7bca453 /numpy
parent1237be83356e20f568f28e54bc0099f5acd3e2db (diff)
downloadnumpy-9d36f09cae3b09e61105e9222e0b56c0b1546a0c.tar.gz
ENH: Add dtype-support for `np.flatiter`
Diffstat (limited to 'numpy')
-rw-r--r--numpy/__init__.pyi17
1 files changed, 10 insertions, 7 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index ac21c9907..79c53adc7 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -955,19 +955,22 @@ _ArrayLikeInt = Union[
_FlatIterSelf = TypeVar("_FlatIterSelf", bound=flatiter)
-class flatiter(Generic[_ArraySelf]):
+class flatiter(Generic[_NdArraySubClass]):
@property
- def base(self) -> _ArraySelf: ...
+ def base(self) -> _NdArraySubClass: ...
@property
def coords(self) -> _Shape: ...
@property
def index(self) -> int: ...
- def copy(self) -> _ArraySelf: ...
+ def copy(self) -> _NdArraySubClass: ...
def __iter__(self: _FlatIterSelf) -> _FlatIterSelf: ...
- def __next__(self) -> generic: ...
+ def __next__(self: flatiter[ndarray[Any, dtype[_ScalarType]]]) -> _ScalarType: ...
def __len__(self) -> int: ...
@overload
- def __getitem__(self, key: Union[int, integer]) -> generic: ...
+ def __getitem__(
+ self: flatiter[ndarray[Any, dtype[_ScalarType]]],
+ key: Union[int, integer],
+ ) -> _ScalarType: ...
@overload
def __getitem__(
self, key: Union[_ArrayLikeInt, slice, ellipsis],
@@ -1487,7 +1490,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType]):
def byteswap(self: _ArraySelf, inplace: bool = ...) -> _ArraySelf: ...
def fill(self, value: Any) -> None: ...
@property
- def flat(self: _ArraySelf) -> flatiter[_ArraySelf]: ...
+ def flat(self: _NdArraySubClass) -> flatiter[_NdArraySubClass]: ...
@overload
def item(self, *args: int) -> Any: ...
@overload
@@ -1668,7 +1671,7 @@ class generic(_ArrayOrScalarCommon):
def strides(self) -> Tuple[()]: ...
def byteswap(self: _ScalarType, inplace: Literal[False] = ...) -> _ScalarType: ...
@property
- def flat(self) -> flatiter[ndarray]: ...
+ def flat(self: _ScalarType) -> flatiter[ndarray[Any, dtype[_ScalarType]]]: ...
def item(
self: _ScalarType,
__args: Union[Literal[0], Tuple[()], Tuple[Literal[0]]] = ...,