summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-04-21 18:25:39 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-04-30 22:09:51 +0200
commit44c3e1f2ffbccaa5e0877deb282a0ce49ed04c35 (patch)
treea06a40d913287305cb0718ece4583a9524b0f8d9
parent669826524248d15ea4f9383cf5a230685351f06c (diff)
downloadnumpy-44c3e1f2ffbccaa5e0877deb282a0ce49ed04c35.tar.gz
MAINT: Remove unsafe unions from `np.lib.index_tricks`
-rw-r--r--numpy/lib/index_tricks.pyi23
-rw-r--r--numpy/typing/tests/data/fail/index_tricks.py2
-rw-r--r--numpy/typing/tests/data/reveal/index_tricks.py19
3 files changed, 30 insertions, 14 deletions
diff --git a/numpy/lib/index_tricks.pyi b/numpy/lib/index_tricks.pyi
index 3e5bc1adb..e486fe8f2 100644
--- a/numpy/lib/index_tricks.pyi
+++ b/numpy/lib/index_tricks.pyi
@@ -17,6 +17,7 @@ from numpy import (
ndindex as ndindex,
ndarray,
dtype,
+ integer,
str_,
bytes_,
bool_,
@@ -33,7 +34,6 @@ from numpy.typing import (
_NestedSequence,
_RecursiveSequence,
_ArrayND,
- _ArrayOrScalar,
_ArrayLikeInt,
# DTypes
@@ -57,18 +57,33 @@ _ArrayType = TypeVar("_ArrayType", bound=ndarray[Any, Any])
__all__: List[str]
+@overload
+def unravel_index( # type: ignore[misc]
+ indices: Union[int, integer[Any]],
+ shape: _ShapeLike,
+ order: _OrderCF = ...
+) -> Tuple[intp, ...]: ...
+@overload
def unravel_index(
indices: _ArrayLikeInt,
shape: _ShapeLike,
order: _OrderCF = ...
-) -> Tuple[_ArrayOrScalar[intp], ...]: ...
+) -> Tuple[_ArrayND[intp], ...]: ...
+@overload
+def ravel_multi_index( # type: ignore[misc]
+ multi_index: Sequence[Union[int, integer[Any]]],
+ dims: _ShapeLike,
+ mode: Union[_ModeKind, Tuple[_ModeKind, ...]] = ...,
+ order: _OrderCF = ...
+) -> intp: ...
+@overload
def ravel_multi_index(
- multi_index: ArrayLike,
+ multi_index: Sequence[_ArrayLikeInt],
dims: _ShapeLike,
mode: Union[_ModeKind, Tuple[_ModeKind, ...]] = ...,
order: _OrderCF = ...
-) -> _ArrayOrScalar[intp]: ...
+) -> _ArrayND[intp]: ...
@overload
def ix_(*args: _NestedSequence[_SupportsDType[_DType]]) -> Tuple[ndarray[Any, _DType], ...]: ...
diff --git a/numpy/typing/tests/data/fail/index_tricks.py b/numpy/typing/tests/data/fail/index_tricks.py
index cbc43fd54..c508bf3ae 100644
--- a/numpy/typing/tests/data/fail/index_tricks.py
+++ b/numpy/typing/tests/data/fail/index_tricks.py
@@ -5,7 +5,7 @@ AR_LIKE_i: List[int]
AR_LIKE_f: List[float]
np.unravel_index(AR_LIKE_f, (1, 2, 3)) # E: incompatible type
-np.ravel_multi_index(AR_LIKE_i, (1, 2, 3), mode="bob") # E: incompatible type
+np.ravel_multi_index(AR_LIKE_i, (1, 2, 3), mode="bob") # E: No overload variant
np.mgrid[1] # E: Invalid index type
np.mgrid[...] # E: Invalid index type
np.ogrid[1] # E: Invalid index type
diff --git a/numpy/typing/tests/data/reveal/index_tricks.py b/numpy/typing/tests/data/reveal/index_tricks.py
index ec2013025..863d60220 100644
--- a/numpy/typing/tests/data/reveal/index_tricks.py
+++ b/numpy/typing/tests/data/reveal/index_tricks.py
@@ -27,15 +27,16 @@ reveal_type(iter(np.ndenumerate(AR_LIKE_U))) # E: Iterator[Tuple[builtins.tuple
reveal_type(iter(np.ndindex(1, 2, 3))) # E: Iterator[builtins.tuple[builtins.int]]
reveal_type(next(np.ndindex(1, 2, 3))) # E: builtins.tuple[builtins.int]
-reveal_type(np.unravel_index([22, 41, 37], (7, 6))) # E: tuple[Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]]]
-reveal_type(np.unravel_index([31, 41, 13], (7, 6), order="F")) # E: tuple[Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]]]
-reveal_type(np.unravel_index(1621, (6, 7, 8, 9))) # E: tuple[Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]]]
-
-reveal_type(np.ravel_multi_index(AR_LIKE_i, (7, 6))) # E: Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]]
-reveal_type(np.ravel_multi_index(AR_LIKE_i, (7, 6), order="F")) # E: Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]]
-reveal_type(np.ravel_multi_index(AR_LIKE_i, (4, 6), mode="clip")) # E: Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]]
-reveal_type(np.ravel_multi_index(AR_LIKE_i, (4, 4), mode=("clip", "wrap"))) # E: Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]]
-reveal_type(np.ravel_multi_index((3, 1, 4, 1), (6, 7, 8, 9))) # E: Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]]
+reveal_type(np.unravel_index([22, 41, 37], (7, 6))) # E: tuple[numpy.ndarray[Any, numpy.dtype[{intp}]]]
+reveal_type(np.unravel_index([31, 41, 13], (7, 6), order="F")) # E: tuple[numpy.ndarray[Any, numpy.dtype[{intp}]]]
+reveal_type(np.unravel_index(1621, (6, 7, 8, 9))) # E: tuple[{intp}]
+
+reveal_type(np.ravel_multi_index([[1]], (7, 6))) # E: numpy.ndarray[Any, numpy.dtype[{intp}]]
+reveal_type(np.ravel_multi_index(AR_LIKE_i, (7, 6))) # E: {intp}
+reveal_type(np.ravel_multi_index(AR_LIKE_i, (7, 6), order="F")) # E: {intp}
+reveal_type(np.ravel_multi_index(AR_LIKE_i, (4, 6), mode="clip")) # E: {intp}
+reveal_type(np.ravel_multi_index(AR_LIKE_i, (4, 4), mode=("clip", "wrap"))) # E: {intp}
+reveal_type(np.ravel_multi_index((3, 1, 4, 1), (6, 7, 8, 9))) # E: {intp}
reveal_type(np.mgrid[1:1:2]) # E: numpy.ndarray[Any, numpy.dtype[Any]]
reveal_type(np.mgrid[1:1:2, None:10]) # E: numpy.ndarray[Any, numpy.dtype[Any]]