summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2020-09-22 01:15:35 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-02-25 14:05:51 +0100
commit668f1aa1082fb0316edd9e6069c4b16fb4a2d6c5 (patch)
tree413283afbfd2210de95d1f60af7516310237f10c /numpy
parent129f3f1b6b0154a175d2abd2289119c85bd705d9 (diff)
downloadnumpy-668f1aa1082fb0316edd9e6069c4b16fb4a2d6c5.tar.gz
ENH: Add annotations for `np.lib.index_tricks`
Diffstat (limited to 'numpy')
-rw-r--r--numpy/__init__.pyi30
-rw-r--r--numpy/lib/index_tricks.pyi179
2 files changed, 207 insertions, 2 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 148a63583..fa498f508 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -496,9 +496,7 @@ nanstd: Any
nansum: Any
nanvar: Any
nbytes: Any
-ndenumerate: Any
ndfromtxt: Any
-ndindex: Any
nditer: Any
nested_iters: Any
newaxis: Any
@@ -2899,3 +2897,31 @@ class errstate(Generic[_CallType], ContextDecorator):
__exc_value: Optional[BaseException],
__traceback: Optional[TracebackType],
) -> None: ...
+
+class ndenumerate(Generic[_ScalarType]):
+ iter: flatiter[_ArrayND[_ScalarType]]
+ @overload
+ def __new__(
+ cls, arr: _NestedSequence[_SupportsArray[dtype[_ScalarType]]],
+ ) -> ndenumerate[_ScalarType]: ...
+ @overload
+ def __new__(cls, arr: _NestedSequence[str]) -> ndenumerate[str_]: ...
+ @overload
+ def __new__(cls, arr: _NestedSequence[bytes]) -> ndenumerate[bytes_]: ...
+ @overload
+ def __new__(cls, arr: _NestedSequence[bool]) -> ndenumerate[bool_]: ...
+ @overload
+ def __new__(cls, arr: _NestedSequence[int]) -> ndenumerate[int_]: ...
+ @overload
+ def __new__(cls, arr: _NestedSequence[float]) -> ndenumerate[float_]: ...
+ @overload
+ def __new__(cls, arr: _NestedSequence[complex]) -> ndenumerate[complex_]: ...
+ @overload
+ def __new__(cls, arr: _RecursiveSequence) -> ndenumerate[Any]: ...
+ def __next__(self: ndenumerate[_ScalarType]) -> Tuple[_Shape, _ScalarType]: ...
+ def __iter__(self: _T) -> _T: ...
+
+class ndindex:
+ def __init__(self, *shape: SupportsIndex) -> None: ...
+ def __iter__(self: _T) -> _T: ...
+ def __next__(self) -> _Shape: ...
diff --git a/numpy/lib/index_tricks.pyi b/numpy/lib/index_tricks.pyi
new file mode 100644
index 000000000..e602f9907
--- /dev/null
+++ b/numpy/lib/index_tricks.pyi
@@ -0,0 +1,179 @@
+import sys
+from typing import (
+ Any,
+ Tuple,
+ TypeVar,
+ Generic,
+ overload,
+ List,
+ Union,
+ Sequence,
+)
+
+from numpy import (
+ # Circumvent a naming conflict with `AxisConcatenator.matrix`
+ matrix as _Matrix,
+ ndenumerate as ndenumerate,
+ ndindex as ndindex,
+ ndarray,
+ dtype,
+ str_,
+ bytes_,
+ bool_,
+ int_,
+ float_,
+ complex_,
+ intp,
+ _OrderCF,
+ _ModeKind,
+)
+from numpy.typing import (
+ # Arrays
+ ArrayLike,
+ _NestedSequence,
+ _RecursiveSequence,
+ _ArrayND,
+ _ArrayOrScalar,
+ _ArrayLikeInt,
+
+ # DTypes
+ DTypeLike,
+ _SupportsDType,
+
+ # Shapes
+ _ShapeLike,
+)
+
+if sys.version_info >= (3, 8):
+ from typing import Literal, SupportsIndex
+else:
+ from typing_extensions import Literal, SupportsIndex
+
+_T = TypeVar("_T")
+_DType = TypeVar("_DType", bound=dtype[Any])
+_BoolType = TypeVar("_BoolType", Literal[True], Literal[False])
+_SliceOrTuple = TypeVar("_SliceOrTuple", bound=Union[slice, Tuple[slice, ...]])
+_ArrayType = TypeVar("_ArrayType", bound=ndarray[Any, Any])
+
+__all__: List[str]
+
+def unravel_index(
+ indices: _ArrayLikeInt,
+ shape: _ShapeLike,
+ order: _OrderCF = ...
+) -> Tuple[_ArrayOrScalar[intp], ...]: ...
+
+def ravel_multi_index(
+ multi_index: ArrayLike,
+ dims: _ShapeLike,
+ mode: Union[_ModeKind, Tuple[_ModeKind, ...]] = ...,
+ order: _OrderCF = ...
+) -> _ArrayOrScalar[intp]: ...
+
+@overload
+def ix_(*args: _NestedSequence[_SupportsDType[_DType]]) -> Tuple[ndarray[Any, _DType], ...]: ...
+@overload
+def ix_(*args: _NestedSequence[str]) -> Tuple[_ArrayND[str_], ...]: ...
+@overload
+def ix_(*args: _NestedSequence[bytes]) -> Tuple[_ArrayND[bytes_], ...]: ...
+@overload
+def ix_(*args: _NestedSequence[bool]) -> Tuple[_ArrayND[bool_], ...]: ...
+@overload
+def ix_(*args: _NestedSequence[int]) -> Tuple[_ArrayND[int_], ...]: ...
+@overload
+def ix_(*args: _NestedSequence[float]) -> Tuple[_ArrayND[float_], ...]: ...
+@overload
+def ix_(*args: _NestedSequence[complex]) -> Tuple[_ArrayND[complex_], ...]: ...
+@overload
+def ix_(*args: _RecursiveSequence) -> Tuple[_ArrayND[Any], ...]: ...
+
+class nd_grid(Generic[_BoolType]):
+ sparse: _BoolType
+ def __init__(self, sparse: _BoolType = ...) -> None: ...
+ @overload
+ def __getitem__(
+ self: nd_grid[Literal[False]],
+ key: Union[slice, Sequence[slice]],
+ ) -> _ArrayND[Any]: ...
+ @overload
+ def __getitem__(
+ self: nd_grid[Literal[True]],
+ key: Union[slice, Sequence[slice]],
+ ) -> List[_ArrayND[Any]]: ...
+
+class MGridClass(nd_grid[Literal[False]]):
+ def __init__(self) -> None: ...
+
+mgrid: MGridClass
+
+class OGridClass(nd_grid[Literal[True]]):
+ def __init__(self) -> None: ...
+
+ogrid: OGridClass
+
+class AxisConcatenator:
+ axis: int
+ matrix: bool
+ ndmin: int
+ trans1d: int
+ def __init__(
+ self,
+ axis: int = ...,
+ matrix: bool = ...,
+ ndmin: int = ...,
+ trans1d: int = ...,
+ ) -> None: ...
+ @staticmethod
+ @overload
+ def concatenate( # type: ignore[misc]
+ *a: ArrayLike, axis: SupportsIndex = ..., out: None = ...
+ ) -> _ArrayND[Any]: ...
+ @staticmethod
+ @overload
+ def concatenate(
+ *a: ArrayLike, axis: SupportsIndex = ..., out: _ArrayType = ...
+ ) -> _ArrayType: ...
+ @staticmethod
+ def makemat(
+ data: ArrayLike, dtype: DTypeLike = ..., copy: bool = ...
+ ) -> _Matrix: ...
+
+ # TODO: Sort out this `__getitem__` method
+ def __getitem__(self, key: Any) -> Any: ...
+
+class RClass(AxisConcatenator):
+ axis: Literal[0]
+ matrix: Literal[False]
+ ndmin: Literal[1]
+ trans1d: Literal[-1]
+ def __init__(self) -> None: ...
+
+r_: RClass
+
+class CClass(AxisConcatenator):
+ axis: Literal[-1]
+ matrix: Literal[False]
+ ndmin: Literal[2]
+ trans1d: Literal[0]
+ def __init__(self) -> None: ...
+
+c_: CClass
+
+class IndexExpression(Generic[_BoolType]):
+ maketuple: _BoolType
+ def __init__(self, maketuple: _BoolType) -> None: ...
+ @overload
+ def __getitem__( # type: ignore[misc]
+ self: IndexExpression[Literal[True]], item: slice
+ ) -> Tuple[slice]: ...
+ @overload
+ def __getitem__(self, item: _SliceOrTuple) -> _SliceOrTuple: ...
+
+index_exp: IndexExpression[Literal[True]]
+s_: IndexExpression[Literal[False]]
+
+def fill_diagonal(a: ndarray[Any, Any], val: Any, wrap: bool = ...) -> None: ...
+def diag_indices(n: int, ndim: int = ...) -> Tuple[_ArrayND[int_], ...]: ...
+def diag_indices_from(arr: ArrayLike) -> Tuple[_ArrayND[int_], ...]: ...
+
+# NOTE: see `numpy/__init__.pyi` for `ndenumerate` and `ndindex`