summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2020-08-26 23:34:28 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2020-10-26 22:14:52 +0100
commit947fea4c2f8db6cd9f8628b1b34c8697b89eaf82 (patch)
tree564437b1bdf085e1d54f1e50d9830fb1cb13436e /numpy
parentbffc8b422d7741ae47d0c5604bda475d3c92aad6 (diff)
downloadnumpy-947fea4c2f8db6cd9f8628b1b34c8697b89eaf82.tar.gz
ENH: Add annotations for `np.core.shape_base`
Diffstat (limited to 'numpy')
-rw-r--r--numpy/__init__.pyi17
-rw-r--r--numpy/core/shape_base.pyi41
2 files changed, 51 insertions, 7 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index ce4aa1ae6..339a651c0 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -193,6 +193,16 @@ from numpy.core.numerictypes import (
find_common_type as find_common_type,
)
+from numpy.core.shape_base import (
+ atleast_1d as atleast_1d,
+ atleast_2d as atleast_2d,
+ atleast_3d as atleast_3d,
+ block as block,
+ hstack as hstack,
+ stack as stack,
+ vstack as vstack,
+)
+
# Add an object to `__all__` if their stubs are defined in an external file;
# their stubs will not be recognized otherwise.
# NOTE: This is redundant for objects defined within this file.
@@ -256,15 +266,11 @@ asarray_chkfinite: Any
asfarray: Any
asmatrix: Any
asscalar: Any
-atleast_1d: Any
-atleast_2d: Any
-atleast_3d: Any
average: Any
bartlett: Any
bincount: Any
bitwise_not: Any
blackman: Any
-block: Any
bmat: Any
bool8: Any
broadcast: Any
@@ -354,7 +360,6 @@ histogram2d: Any
histogram_bin_edges: Any
histogramdd: Any
hsplit: Any
-hstack: Any
i0: Any
iinfo: Any
imag: Any
@@ -488,7 +493,6 @@ singlecomplex: Any
sort_complex: Any
source: Any
split: Any
-stack: Any
string_: Any
take_along_axis: Any
tile: Any
@@ -521,7 +525,6 @@ vdot: Any
vectorize: Any
void0: Any
vsplit: Any
-vstack: Any
where: Any
who: Any
diff --git a/numpy/core/shape_base.pyi b/numpy/core/shape_base.pyi
new file mode 100644
index 000000000..2fe945f3b
--- /dev/null
+++ b/numpy/core/shape_base.pyi
@@ -0,0 +1,41 @@
+import sys
+from typing import TypeVar, overload, List, Sequence, Optional
+
+from numpy import ndarray
+from numpy.typing import ArrayLike
+
+if sys.version_info >= (3, 8):
+ from typing import SupportsIndex, Literal
+else:
+ from typing_extensions import Literal, Protocol
+ class SupportsIndex(Protocol):
+ def __index__(self) -> int: ...
+
+_ArrayType = TypeVar("_ArrayType", bound=ndarray)
+
+@overload
+def atleast_1d(__arys: ArrayLike) -> ndarray: ...
+@overload
+def atleast_1d(*arys: ArrayLike) -> List[ndarray]: ...
+
+@overload
+def atleast_2d(__arys: ArrayLike) -> ndarray: ...
+@overload
+def atleast_2d(*arys: ArrayLike) -> List[ndarray]: ...
+
+@overload
+def atleast_3d(__arys: ArrayLike) -> ndarray: ...
+@overload
+def atleast_3d(*arys: ArrayLike) -> List[ndarray]: ...
+
+def vstack(tup: Sequence[ArrayLike]) -> ndarray: ...
+def hstack(tup: Sequence[ArrayLike]) -> ndarray: ...
+@overload
+def stack(
+ arrays: Sequence[ArrayLike], axis: SupportsIndex = ..., out: None = ...
+) -> ndarray: ...
+@overload
+def stack(
+ arrays: Sequence[ArrayLike], axis: SupportsIndex = ..., out: _ArrayType = ...
+) -> _ArrayType: ...
+def block(arrays: ArrayLike) -> ndarray: ...