summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorRalf Gommers <ralf.gommers@gmail.com>2021-05-08 20:51:49 +0200
committerGitHub <noreply@github.com>2021-05-08 20:51:49 +0200
commite3583316cab5e71f2b361c32a3eee748905f40c5 (patch)
tree7c68f9504341f5026985149fec7450b719e7e508 /numpy
parentd490589e01a5d232d8ed0b8f37ce8bffc1f21ec6 (diff)
parente377d071ea502f396a7da299633bad74922c04eb (diff)
downloadnumpy-e3583316cab5e71f2b361c32a3eee748905f40c5.tar.gz
Merge pull request #18944 from BvB93/utils
ENH: Add annotations for `np.lib.utils`
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/utils.pyi100
-rw-r--r--numpy/typing/tests/data/fail/lib_utils.py13
-rw-r--r--numpy/typing/tests/data/pass/lib_utils.py26
-rw-r--r--numpy/typing/tests/data/reveal/lib_utils.py30
4 files changed, 159 insertions, 10 deletions
diff --git a/numpy/lib/utils.pyi b/numpy/lib/utils.pyi
index 5a1594149..0518655c6 100644
--- a/numpy/lib/utils.pyi
+++ b/numpy/lib/utils.pyi
@@ -1,4 +1,19 @@
-from typing import List
+import sys
+from ast import AST
+from typing import (
+ Any,
+ Callable,
+ List,
+ Mapping,
+ Optional,
+ overload,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+)
+
+from numpy import ndarray, generic
from numpy.core.numerictypes import (
issubclass_ as issubclass_,
@@ -6,14 +21,79 @@ from numpy.core.numerictypes import (
issubsctype as issubsctype,
)
+if sys.version_info >= (3, 8):
+ from typing import Protocol
+else:
+ from typing_extensions import Protocol
+
+_T_contra = TypeVar("_T_contra", contravariant=True)
+_FuncType = TypeVar("_FuncType", bound=Callable[..., Any])
+
+# A file-like object opened in `w` mode
+class _SupportsWrite(Protocol[_T_contra]):
+ def write(self, __s: _T_contra) -> Any: ...
+
__all__: List[str]
-def get_include(): ...
-def deprecate(*args, **kwargs): ...
-def deprecate_with_doc(msg): ...
-def byte_bounds(a): ...
-def who(vardict=...): ...
-def info(object=..., maxwidth=..., output=..., toplevel=...): ...
-def source(object, output=...): ...
-def lookfor(what, module=..., import_modules=..., regenerate=..., output=...): ...
-def safe_eval(source): ...
+class _Deprecate:
+ old_name: Optional[str]
+ new_name: Optional[str]
+ message: Optional[str]
+ def __init__(
+ self,
+ old_name: Optional[str] = ...,
+ new_name: Optional[str] = ...,
+ message: Optional[str] = ...,
+ ) -> None: ...
+ # NOTE: `__call__` can in principle take arbitrary `*args` and `**kwargs`,
+ # even though they aren't used for anything
+ def __call__(self, func: _FuncType) -> _FuncType: ...
+
+def get_include() -> str: ...
+
+@overload
+def deprecate(
+ *,
+ old_name: Optional[str] = ...,
+ new_name: Optional[str] = ...,
+ message: Optional[str] = ...,
+) -> _Deprecate: ...
+@overload
+def deprecate(
+ __func: _FuncType,
+ old_name: Optional[str] = ...,
+ new_name: Optional[str] = ...,
+ message: Optional[str] = ...,
+) -> _FuncType: ...
+
+def deprecate_with_doc(msg: Optional[str]) -> _Deprecate: ...
+
+# NOTE: In practice `byte_bounds` can (potentially) take any object
+# implementing the `__array_interface__` protocol. The caveat is
+# that certain keys, marked as optional in the spec, must be present for
+# `byte_bounds`. This concerns `"strides"` and `"data"`.
+def byte_bounds(a: Union[generic, ndarray[Any, Any]]) -> Tuple[int, int]: ...
+
+def who(vardict: Optional[Mapping[str, ndarray[Any, Any]]] = ...) -> None: ...
+
+def info(
+ object: object = ...,
+ maxwidth: int = ...,
+ output: Optional[_SupportsWrite[str]] = ...,
+ toplevel: str = ...,
+) -> None: ...
+
+def source(
+ object: object,
+ output: Optional[_SupportsWrite[str]] = ...,
+) -> None: ...
+
+def lookfor(
+ what: str,
+ module: Union[None, str, Sequence[str]] = ...,
+ import_modules: bool = ...,
+ regenerate: bool = ...,
+ output: Optional[_SupportsWrite[str]] =...,
+) -> None: ...
+
+def safe_eval(source: Union[str, AST]) -> Any: ...
diff --git a/numpy/typing/tests/data/fail/lib_utils.py b/numpy/typing/tests/data/fail/lib_utils.py
new file mode 100644
index 000000000..e16c926aa
--- /dev/null
+++ b/numpy/typing/tests/data/fail/lib_utils.py
@@ -0,0 +1,13 @@
+import numpy as np
+
+np.deprecate(1) # E: No overload variant
+
+np.deprecate_with_doc(1) # E: incompatible type
+
+np.byte_bounds(1) # E: incompatible type
+
+np.who(1) # E: incompatible type
+
+np.lookfor(None) # E: incompatible type
+
+np.safe_eval(None) # E: incompatible type
diff --git a/numpy/typing/tests/data/pass/lib_utils.py b/numpy/typing/tests/data/pass/lib_utils.py
new file mode 100644
index 000000000..c602923d9
--- /dev/null
+++ b/numpy/typing/tests/data/pass/lib_utils.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+from io import StringIO
+from typing import Any
+
+import numpy as np
+
+FILE = StringIO()
+AR: np.ndarray[Any, np.dtype[np.float64]] = np.arange(10).astype(np.float64)
+
+def func(a: int) -> bool: ...
+
+np.deprecate(func)
+np.deprecate()
+
+np.deprecate_with_doc("test")
+np.deprecate_with_doc(None)
+
+np.byte_bounds(AR)
+np.byte_bounds(np.float64())
+
+np.info(1, output=FILE)
+
+np.source(np.interp, output=FILE)
+
+np.lookfor("binary representation", output=FILE)
diff --git a/numpy/typing/tests/data/reveal/lib_utils.py b/numpy/typing/tests/data/reveal/lib_utils.py
new file mode 100644
index 000000000..d82012707
--- /dev/null
+++ b/numpy/typing/tests/data/reveal/lib_utils.py
@@ -0,0 +1,30 @@
+from io import StringIO
+from typing import Any, Dict
+
+import numpy as np
+
+AR: np.ndarray[Any, np.dtype[np.float64]]
+AR_DICT: Dict[str, np.ndarray[Any, np.dtype[np.float64]]]
+FILE: StringIO
+
+def func(a: int) -> bool: ...
+
+reveal_type(np.deprecate(func)) # E: def (a: builtins.int) -> builtins.bool
+reveal_type(np.deprecate()) # E: _Deprecate
+
+reveal_type(np.deprecate_with_doc("test")) # E: _Deprecate
+reveal_type(np.deprecate_with_doc(None)) # E: _Deprecate
+
+reveal_type(np.byte_bounds(AR)) # E: Tuple[builtins.int, builtins.int]
+reveal_type(np.byte_bounds(np.float64())) # E: Tuple[builtins.int, builtins.int]
+
+reveal_type(np.who(None)) # E: None
+reveal_type(np.who(AR_DICT)) # E: None
+
+reveal_type(np.info(1, output=FILE)) # E: None
+
+reveal_type(np.source(np.interp, output=FILE)) # E: None
+
+reveal_type(np.lookfor("binary representation", output=FILE)) # E: None
+
+reveal_type(np.safe_eval("1 + 1")) # E: Any