summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2022-01-02 14:57:30 +0100
committerBas van Beek <b.f.van.beek@vu.nl>2022-01-02 14:57:30 +0100
commit5b9e7fb7e9a7cf4fa67aa37dd634e8996560d227 (patch)
treecb6dc01bc825af7f31f23efe687607dd15f93c4d
parent81c5acfedc887b6bdcbc346c491f2e9d1c373803 (diff)
downloadnumpy-5b9e7fb7e9a7cf4fa67aa37dd634e8996560d227.tar.gz
TYP,MAINT: Allow `ndindex` to accept integer tuples
-rw-r--r--numpy/__init__.pyi3
-rw-r--r--numpy/typing/tests/data/fail/index_tricks.pyi1
-rw-r--r--numpy/typing/tests/data/reveal/index_tricks.pyi2
3 files changed, 6 insertions, 0 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index ee99ffb36..0facafaa8 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -3314,6 +3314,9 @@ class ndenumerate(Generic[_ScalarType]):
def __iter__(self: _T) -> _T: ...
class ndindex:
+ @overload
+ def __init__(self, shape: tuple[SupportsIndex, ...], /) -> None: ...
+ @overload
def __init__(self, *shape: SupportsIndex) -> None: ...
def __iter__(self: _T) -> _T: ...
def __next__(self) -> _Shape: ...
diff --git a/numpy/typing/tests/data/fail/index_tricks.pyi b/numpy/typing/tests/data/fail/index_tricks.pyi
index 2bf2337db..22f6f4a61 100644
--- a/numpy/typing/tests/data/fail/index_tricks.pyi
+++ b/numpy/typing/tests/data/fail/index_tricks.pyi
@@ -3,6 +3,7 @@ import numpy as np
AR_LIKE_i: list[int]
AR_LIKE_f: list[float]
+np.ndindex([1, 2, 3]) # E: No overload variant
np.unravel_index(AR_LIKE_f, (1, 2, 3)) # E: incompatible type
np.ravel_multi_index(AR_LIKE_i, (1, 2, 3), mode="bob") # E: No overload variant
np.mgrid[1] # E: Invalid index type
diff --git a/numpy/typing/tests/data/reveal/index_tricks.pyi b/numpy/typing/tests/data/reveal/index_tricks.pyi
index 365a8ee51..4018605ea 100644
--- a/numpy/typing/tests/data/reveal/index_tricks.pyi
+++ b/numpy/typing/tests/data/reveal/index_tricks.pyi
@@ -24,6 +24,8 @@ reveal_type(iter(np.ndenumerate(AR_i8))) # E: Iterator[Tuple[builtins.tuple[bui
reveal_type(iter(np.ndenumerate(AR_LIKE_f))) # E: Iterator[Tuple[builtins.tuple[builtins.int], {double}]]
reveal_type(iter(np.ndenumerate(AR_LIKE_U))) # E: Iterator[Tuple[builtins.tuple[builtins.int], str_]]
+reveal_type(np.ndindex(1, 2, 3)) # E: numpy.ndindex
+reveal_type(np.ndindex((1, 2, 3))) # E: numpy.ndindex
reveal_type(iter(np.ndindex(1, 2, 3))) # E: Iterator[builtins.tuple[builtins.int]]
reveal_type(next(np.ndindex(1, 2, 3))) # E: builtins.tuple[builtins.int]