diff options
author | Bas van Beek <b.f.van.beek@vu.nl> | 2020-09-22 01:15:35 +0200 |
---|---|---|
committer | Bas van Beek <b.f.van.beek@vu.nl> | 2021-02-25 14:05:51 +0100 |
commit | 668f1aa1082fb0316edd9e6069c4b16fb4a2d6c5 (patch) | |
tree | 413283afbfd2210de95d1f60af7516310237f10c /numpy | |
parent | 129f3f1b6b0154a175d2abd2289119c85bd705d9 (diff) | |
download | numpy-668f1aa1082fb0316edd9e6069c4b16fb4a2d6c5.tar.gz |
ENH: Add annotations for `np.lib.index_tricks`
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/__init__.pyi | 30 | ||||
-rw-r--r-- | numpy/lib/index_tricks.pyi | 179 |
2 files changed, 207 insertions, 2 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index 148a63583..fa498f508 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -496,9 +496,7 @@ nanstd: Any nansum: Any nanvar: Any nbytes: Any -ndenumerate: Any ndfromtxt: Any -ndindex: Any nditer: Any nested_iters: Any newaxis: Any @@ -2899,3 +2897,31 @@ class errstate(Generic[_CallType], ContextDecorator): __exc_value: Optional[BaseException], __traceback: Optional[TracebackType], ) -> None: ... + +class ndenumerate(Generic[_ScalarType]): + iter: flatiter[_ArrayND[_ScalarType]] + @overload + def __new__( + cls, arr: _NestedSequence[_SupportsArray[dtype[_ScalarType]]], + ) -> ndenumerate[_ScalarType]: ... + @overload + def __new__(cls, arr: _NestedSequence[str]) -> ndenumerate[str_]: ... + @overload + def __new__(cls, arr: _NestedSequence[bytes]) -> ndenumerate[bytes_]: ... + @overload + def __new__(cls, arr: _NestedSequence[bool]) -> ndenumerate[bool_]: ... + @overload + def __new__(cls, arr: _NestedSequence[int]) -> ndenumerate[int_]: ... + @overload + def __new__(cls, arr: _NestedSequence[float]) -> ndenumerate[float_]: ... + @overload + def __new__(cls, arr: _NestedSequence[complex]) -> ndenumerate[complex_]: ... + @overload + def __new__(cls, arr: _RecursiveSequence) -> ndenumerate[Any]: ... + def __next__(self: ndenumerate[_ScalarType]) -> Tuple[_Shape, _ScalarType]: ... + def __iter__(self: _T) -> _T: ... + +class ndindex: + def __init__(self, *shape: SupportsIndex) -> None: ... + def __iter__(self: _T) -> _T: ... + def __next__(self) -> _Shape: ... diff --git a/numpy/lib/index_tricks.pyi b/numpy/lib/index_tricks.pyi new file mode 100644 index 000000000..e602f9907 --- /dev/null +++ b/numpy/lib/index_tricks.pyi @@ -0,0 +1,179 @@ +import sys +from typing import ( + Any, + Tuple, + TypeVar, + Generic, + overload, + List, + Union, + Sequence, +) + +from numpy import ( + # Circumvent a naming conflict with `AxisConcatenator.matrix` + matrix as _Matrix, + ndenumerate as ndenumerate, + ndindex as ndindex, + ndarray, + dtype, + str_, + bytes_, + bool_, + int_, + float_, + complex_, + intp, + _OrderCF, + _ModeKind, +) +from numpy.typing import ( + # Arrays + ArrayLike, + _NestedSequence, + _RecursiveSequence, + _ArrayND, + _ArrayOrScalar, + _ArrayLikeInt, + + # DTypes + DTypeLike, + _SupportsDType, + + # Shapes + _ShapeLike, +) + +if sys.version_info >= (3, 8): + from typing import Literal, SupportsIndex +else: + from typing_extensions import Literal, SupportsIndex + +_T = TypeVar("_T") +_DType = TypeVar("_DType", bound=dtype[Any]) +_BoolType = TypeVar("_BoolType", Literal[True], Literal[False]) +_SliceOrTuple = TypeVar("_SliceOrTuple", bound=Union[slice, Tuple[slice, ...]]) +_ArrayType = TypeVar("_ArrayType", bound=ndarray[Any, Any]) + +__all__: List[str] + +def unravel_index( + indices: _ArrayLikeInt, + shape: _ShapeLike, + order: _OrderCF = ... +) -> Tuple[_ArrayOrScalar[intp], ...]: ... + +def ravel_multi_index( + multi_index: ArrayLike, + dims: _ShapeLike, + mode: Union[_ModeKind, Tuple[_ModeKind, ...]] = ..., + order: _OrderCF = ... +) -> _ArrayOrScalar[intp]: ... + +@overload +def ix_(*args: _NestedSequence[_SupportsDType[_DType]]) -> Tuple[ndarray[Any, _DType], ...]: ... +@overload +def ix_(*args: _NestedSequence[str]) -> Tuple[_ArrayND[str_], ...]: ... +@overload +def ix_(*args: _NestedSequence[bytes]) -> Tuple[_ArrayND[bytes_], ...]: ... +@overload +def ix_(*args: _NestedSequence[bool]) -> Tuple[_ArrayND[bool_], ...]: ... +@overload +def ix_(*args: _NestedSequence[int]) -> Tuple[_ArrayND[int_], ...]: ... +@overload +def ix_(*args: _NestedSequence[float]) -> Tuple[_ArrayND[float_], ...]: ... +@overload +def ix_(*args: _NestedSequence[complex]) -> Tuple[_ArrayND[complex_], ...]: ... +@overload +def ix_(*args: _RecursiveSequence) -> Tuple[_ArrayND[Any], ...]: ... + +class nd_grid(Generic[_BoolType]): + sparse: _BoolType + def __init__(self, sparse: _BoolType = ...) -> None: ... + @overload + def __getitem__( + self: nd_grid[Literal[False]], + key: Union[slice, Sequence[slice]], + ) -> _ArrayND[Any]: ... + @overload + def __getitem__( + self: nd_grid[Literal[True]], + key: Union[slice, Sequence[slice]], + ) -> List[_ArrayND[Any]]: ... + +class MGridClass(nd_grid[Literal[False]]): + def __init__(self) -> None: ... + +mgrid: MGridClass + +class OGridClass(nd_grid[Literal[True]]): + def __init__(self) -> None: ... + +ogrid: OGridClass + +class AxisConcatenator: + axis: int + matrix: bool + ndmin: int + trans1d: int + def __init__( + self, + axis: int = ..., + matrix: bool = ..., + ndmin: int = ..., + trans1d: int = ..., + ) -> None: ... + @staticmethod + @overload + def concatenate( # type: ignore[misc] + *a: ArrayLike, axis: SupportsIndex = ..., out: None = ... + ) -> _ArrayND[Any]: ... + @staticmethod + @overload + def concatenate( + *a: ArrayLike, axis: SupportsIndex = ..., out: _ArrayType = ... + ) -> _ArrayType: ... + @staticmethod + def makemat( + data: ArrayLike, dtype: DTypeLike = ..., copy: bool = ... + ) -> _Matrix: ... + + # TODO: Sort out this `__getitem__` method + def __getitem__(self, key: Any) -> Any: ... + +class RClass(AxisConcatenator): + axis: Literal[0] + matrix: Literal[False] + ndmin: Literal[1] + trans1d: Literal[-1] + def __init__(self) -> None: ... + +r_: RClass + +class CClass(AxisConcatenator): + axis: Literal[-1] + matrix: Literal[False] + ndmin: Literal[2] + trans1d: Literal[0] + def __init__(self) -> None: ... + +c_: CClass + +class IndexExpression(Generic[_BoolType]): + maketuple: _BoolType + def __init__(self, maketuple: _BoolType) -> None: ... + @overload + def __getitem__( # type: ignore[misc] + self: IndexExpression[Literal[True]], item: slice + ) -> Tuple[slice]: ... + @overload + def __getitem__(self, item: _SliceOrTuple) -> _SliceOrTuple: ... + +index_exp: IndexExpression[Literal[True]] +s_: IndexExpression[Literal[False]] + +def fill_diagonal(a: ndarray[Any, Any], val: Any, wrap: bool = ...) -> None: ... +def diag_indices(n: int, ndim: int = ...) -> Tuple[_ArrayND[int_], ...]: ... +def diag_indices_from(arr: ArrayLike) -> Tuple[_ArrayND[int_], ...]: ... + +# NOTE: see `numpy/__init__.pyi` for `ndenumerate` and `ndindex` |