summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-06-02 02:58:54 +0200
committerBas van Beek <43369155+BvB93@users.noreply.github.com>2021-06-06 17:58:27 +0200
commit085ccdc4876f08b742fcf1fdf79274c24c649e92 (patch)
tree46b109793f97a435d670424ebe3c2c1ec4796f30 /numpy
parent3c8e5da46c1440ea0eb3fe66fdc897d3fbae194d (diff)
downloadnumpy-085ccdc4876f08b742fcf1fdf79274c24c649e92.tar.gz
ENH: Add annotations for `np.testing`
Diffstat (limited to 'numpy')
-rw-r--r--numpy/testing/__init__.pyi5
-rw-r--r--numpy/testing/_private/utils.pyi412
2 files changed, 337 insertions, 80 deletions
diff --git a/numpy/testing/__init__.pyi b/numpy/testing/__init__.pyi
index f40c06e9a..955dae862 100644
--- a/numpy/testing/__init__.pyi
+++ b/numpy/testing/__init__.pyi
@@ -47,4 +47,7 @@ from numpy.testing._private.utils import (
__all__: List[str]
-def run_module_suite(file_to_run=..., argv=...): ...
+def run_module_suite(
+ file_to_run: None | str = ...,
+ argv: None | List[str] = ...,
+) -> None: ...
diff --git a/numpy/testing/_private/utils.pyi b/numpy/testing/_private/utils.pyi
index efb6bcd91..a297b3379 100644
--- a/numpy/testing/_private/utils.pyi
+++ b/numpy/testing/_private/utils.pyi
@@ -1,107 +1,361 @@
+import os
import sys
+import ast
+import types
import warnings
-from typing import Any, List, ClassVar, Tuple, Set
+import unittest
+import contextlib
+from typing import (
+ Any,
+ AnyStr,
+ Callable,
+ ClassVar,
+ Dict,
+ Iterable,
+ List,
+ NoReturn,
+ overload,
+ Pattern,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ type_check_only,
+ TypeVar,
+ Union,
+)
-if sys.version_info >= (3, 8):
- from typing import Final
-else:
- from typing_extensions import Final
+from numpy import generic, dtype, number, object_, bool_
+from numpy.typing import NDArray, ArrayLike, DTypeLike
from unittest.case import (
SkipTest as SkipTest,
)
+if sys.version_info >= (3, 8):
+ from typing import Final, SupportsIndex, Literal as L
+else:
+ from typing_extensions import Final, SupportsIndex, Literal as L
+
+_T = TypeVar("_T")
+_ET = TypeVar("_ET", bound=BaseException)
+_FT = TypeVar("_FT", bound=Callable[..., Any])
+
+# Must return a bool or an ndarray/generic type
+# that is supported by `np.logical_and.reduce`
+_ComparisonFunc = Callable[
+ [NDArray[Any], NDArray[Any]],
+ Union[
+ bool,
+ bool_,
+ number[Any],
+ NDArray[Union[bool_, number[Any], object_]],
+ ],
+]
+
__all__: List[str]
class KnownFailureException(Exception): ...
class IgnoreException(Exception): ...
class clear_and_catch_warnings(warnings.catch_warnings):
- class_modules: ClassVar[Tuple[str, ...]]
- modules: Set[str]
- def __init__(self, record=..., modules=...): ...
- def __enter__(self): ...
- def __exit__(self, *exc_info): ...
+ class_modules: ClassVar[Tuple[types.ModuleType, ...]]
+ modules: Set[types.ModuleType]
+ @overload
+ def __new__(
+ cls,
+ record: L[False] = ...,
+ modules: Iterable[types.ModuleType] = ...,
+ ) -> _clear_and_catch_warnings_without_records: ...
+ @overload
+ def __new__(
+ cls,
+ record: L[True],
+ modules: Iterable[types.ModuleType] = ...,
+ ) -> _clear_and_catch_warnings_with_records: ...
+ @overload
+ def __new__(
+ cls,
+ record: bool,
+ modules: Iterable[types.ModuleType] = ...,
+ ) -> clear_and_catch_warnings: ...
+ def __enter__(self) -> None | List[warnings.WarningMessage]: ...
+ def __exit__(
+ self,
+ __exc_type: None | Type[BaseException] = ...,
+ __exc_val: None | BaseException = ...,
+ __exc_tb: None | types.TracebackType = ...,
+ ) -> None: ...
+
+# Type-check only `clear_and_catch_warnings` subclasses for both values of the
+# `record` parameter. Copied from the stdlib `warnings` stubs.
+
+@type_check_only
+class _clear_and_catch_warnings_with_records(clear_and_catch_warnings):
+ def __enter__(self) -> List[warnings.WarningMessage]: ...
+
+@type_check_only
+class _clear_and_catch_warnings_without_records(clear_and_catch_warnings):
+ def __enter__(self) -> None: ...
class suppress_warnings:
log: List[warnings.WarningMessage]
- def __init__(self, forwarding_rule=...): ...
- def filter(self, category=..., message=..., module=...): ...
- def record(self, category=..., message=..., module=...): ...
- def __enter__(self): ...
- def __exit__(self, *exc_info): ...
- def __call__(self, func): ...
+ def __init__(
+ self,
+ forwarding_rule: L["always", "module", "once", "location"] = ...,
+ ) -> None: ...
+ def filter(
+ self,
+ category: Type[Warning] = ...,
+ message: str = ...,
+ module: None | types.ModuleType = ...,
+ ) -> None: ...
+ def record(
+ self,
+ category: Type[Warning] = ...,
+ message: str = ...,
+ module: None | types.ModuleType = ...,
+ ) -> List[warnings.WarningMessage]: ...
+ def __enter__(self: _T) -> _T: ...
+ def __exit__(
+ self,
+ __exc_type: None | Type[BaseException] = ...,
+ __exc_val: None | BaseException = ...,
+ __exc_tb: None | types.TracebackType = ...,
+ ) -> None: ...
+ def __call__(self, func: _FT) -> _FT: ...
verbose: int
IS_PYPY: Final[bool]
HAS_REFCOUNT: Final[bool]
HAS_LAPACK64: Final[bool]
-def assert_(val, msg=...): ...
-def memusage(processName=..., instance=...): ...
-def jiffies(_proc_pid_stat=..., _load_time=...): ...
+def assert_(val: object, msg: str | Callable[[], str] = ...) -> None: ...
+
+# Contrary to runtime we can't do `os.name` checks while type checking,
+# only `sys.platform` checks
+if sys.platform == "win32" or sys.platform == "cygwin":
+ def memusage(processName: str = ..., instance: int = ...) -> int: ...
+elif sys.platform == "linux":
+ def memusage(_proc_pid_stat: str | bytes | os.PathLike[Any] = ...) -> None | int: ...
+else:
+ def memusage() -> NoReturn: ...
+
+if sys.platform == "linux":
+ def jiffies(
+ _proc_pid_stat: str | bytes | os.PathLike[Any] = ...,
+ _load_time: List[float] = ...,
+ ) -> int: ...
+else:
+ def jiffies(_load_time: List[float] = ...) -> int: ...
+
def build_err_msg(
- arrays,
- err_msg,
- header=...,
- verbose=...,
- names=...,
- precision=...,
-): ...
-def assert_equal(actual, desired, err_msg=..., verbose=...): ...
-def print_assert_equal(test_string, actual, desired): ...
+ arrays: Iterable[object],
+ err_msg: str,
+ header: str = ...,
+ verbose: bool = ...,
+ names: Sequence[str] = ...,
+ precision: None | SupportsIndex = ...,
+) -> str: ...
+
+def assert_equal(
+ actual: object,
+ desired: object,
+ err_msg: str = ...,
+ verbose: bool = ...,
+) -> None: ...
+
+def print_assert_equal(
+ test_string: str,
+ actual: object,
+ desired: object,
+) -> None: ...
+
def assert_almost_equal(
- actual,
- desired,
- decimal=...,
- err_msg=...,
- verbose=...,
-): ...
+ actual: ArrayLike,
+ desired: ArrayLike,
+ decimal: int = ...,
+ err_msg: str = ...,
+ verbose: bool = ...,
+) -> None: ...
+
def assert_approx_equal(
- actual,
- desired,
- significant=...,
- err_msg=...,
- verbose=...,
-): ...
+ actual: ArrayLike,
+ desired: ArrayLike,
+ significant: int = ...,
+ err_msg: str = ...,
+ verbose: bool = ...,
+) -> None: ...
+
def assert_array_compare(
- comparison,
- x,
- y,
- err_msg=...,
- verbose=...,
- header=...,
- precision=...,
- equal_nan=...,
- equal_inf=...,
-): ...
-def assert_array_equal(x, y, err_msg=..., verbose=...): ...
-def assert_array_almost_equal(x, y, decimal=..., err_msg=..., verbose=...): ...
-def assert_array_less(x, y, err_msg=..., verbose=...): ...
-def runstring(astr, dict): ...
-def assert_string_equal(actual, desired): ...
-def rundocs(filename=..., raise_on_error=...): ...
-def raises(*args): ...
-def assert_raises(*args, **kwargs): ...
-def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs): ...
-def decorate_methods(cls, decorator, testmatch=...): ...
-def measure(code_str, times=..., label=...): ...
+ comparison: _ComparisonFunc,
+ x: ArrayLike,
+ y: ArrayLike,
+ err_msg: str = ...,
+ verbose: bool = ...,
+ header: str = ...,
+ precision: SupportsIndex = ...,
+ equal_nan: bool = ...,
+ equal_inf: bool = ...,
+) -> None: ...
+
+def assert_array_equal(
+ x: ArrayLike,
+ y: ArrayLike,
+ err_msg: str = ...,
+ verbose: bool = ...,
+) -> None: ...
+
+def assert_array_almost_equal(
+ x: ArrayLike,
+ y: ArrayLike,
+ decimal: int = ...,
+ err_msg: str = ...,
+ verbose: bool = ...,
+) -> None: ...
+
+def assert_array_less(
+ x: ArrayLike,
+ y: ArrayLike,
+ err_msg: str = ...,
+ verbose: bool = ...,
+) -> None: ...
+
+def runstring(
+ astr: str | bytes | types.CodeType,
+ dict: None | Dict[str, Any],
+) -> Any: ...
+
+def assert_string_equal(actual: str, desired: str) -> None: ...
+
+def rundocs(
+ filename: None | str | os.PathLike[str] = ...,
+ raise_on_error: bool = ...,
+) -> None: ...
+
+def raises(*args: Type[BaseException]) -> Callable[[_FT], _FT]: ...
+
+@overload
+def assert_raises( # type: ignore
+ __expected_exception: Type[BaseException] | Tuple[Type[BaseException], ...],
+ __callable: Callable[..., Any],
+ *args: Any,
+ **kwargs: Any,
+) -> None: ...
+@overload
+def assert_raises(
+ expected_exception: Type[_ET] | Tuple[Type[_ET], ...],
+ *,
+ msg: None | str = ...,
+) -> unittest.case._AssertRaisesContext[_ET]: ...
+
+@overload
+def assert_raises_regex(
+ __expected_exception: Type[BaseException] | Tuple[Type[BaseException], ...],
+ __expected_regex: str | bytes | Pattern[Any],
+ __callable: Callable[..., Any],
+ *args: Any,
+ **kwargs: Any,
+) -> None: ...
+@overload
+def assert_raises_regex(
+ expected_exception: Type[_ET] | Tuple[Type[_ET], ...],
+ expected_regex: str | bytes | Pattern[Any],
+ *,
+ msg: None | str = ...,
+) -> unittest.case._AssertRaisesContext[_ET]: ...
+
+def decorate_methods(
+ cls: Type[Any],
+ decorator: Callable[[Callable[..., Any]], Any],
+ testmatch: None | str | bytes | Pattern[Any] = ...,
+) -> None: ...
+
+def measure(
+ code_str: str | bytes | ast.mod | ast.AST,
+ times: int = ...,
+ label: None | str = ...,
+) -> float: ...
+
def assert_allclose(
- actual,
- desired,
- rtol=...,
- atol=...,
- equal_nan=...,
- err_msg=...,
- verbose=...,
-): ...
-def assert_array_almost_equal_nulp(x, y, nulp=...): ...
-def assert_array_max_ulp(a, b, maxulp=..., dtype=...): ...
-def assert_warns(warning_class, *args, **kwargs): ...
-def assert_no_warnings(*args, **kwargs): ...
-def tempdir(*args, **kwargs): ...
-def temppath(*args, **kwargs): ...
-def assert_no_gc_cycles(*args, **kwargs): ...
-def break_cycles(): ...
-def _assert_valid_refcount(op): ...
-def _gen_alignment_data(dtype=..., type=..., max_size=...): ...
+ actual: ArrayLike,
+ desired: ArrayLike,
+ rtol: float = ...,
+ atol: float = ...,
+ equal_nan: bool = ...,
+ err_msg: str = ...,
+ verbose: bool = ...,
+) -> None: ...
+
+def assert_array_almost_equal_nulp(
+ x: ArrayLike,
+ y: ArrayLike,
+ nulp: float = ...,
+) -> None: ...
+
+def assert_array_max_ulp(
+ a: ArrayLike,
+ b: ArrayLike,
+ maxulp: float = ...,
+ dtype: DTypeLike = ...,
+) -> NDArray[Any]: ...
+
+@overload
+def assert_warns(
+ warning_class: Type[Warning],
+) -> contextlib._GeneratorContextManager[None]: ...
+@overload
+def assert_warns(
+ __warning_class: Type[Warning],
+ __func: Callable[..., _T],
+ *args: Any,
+ **kwargs: Any,
+) -> _T: ...
+
+@overload
+def assert_no_warnings() -> contextlib._GeneratorContextManager[None]: ...
+@overload
+def assert_no_warnings(
+ __func: Callable[..., _T],
+ *args: Any,
+ **kwargs: Any,
+) -> _T: ...
+
+@overload
+def tempdir(
+ suffix: None = ...,
+ prefix: None = ...,
+ dir: None = ...,
+) -> contextlib._GeneratorContextManager[str]: ...
+@overload
+def tempdir(
+ suffix: None | AnyStr = ...,
+ prefix: None | AnyStr = ...,
+ dir: None | AnyStr | os.PathLike[AnyStr] = ...,
+) -> contextlib._GeneratorContextManager[AnyStr]: ...
+
+@overload
+def temppath(
+ suffix: None = ...,
+ prefix: None = ...,
+ dir: None = ...,
+ text: bool = ...,
+) -> contextlib._GeneratorContextManager[str]: ...
+@overload
+def temppath(
+ suffix: None | AnyStr = ...,
+ prefix: None | AnyStr = ...,
+ dir: None | AnyStr | os.PathLike[AnyStr] = ...,
+ text: bool = ...,
+) -> contextlib._GeneratorContextManager[AnyStr]: ...
+
+@overload
+def assert_no_gc_cycles() -> contextlib._GeneratorContextManager[None]: ...
+@overload
+def assert_no_gc_cycles(
+ __func: Callable[..., Any],
+ *args: Any,
+ **kwargs: Any,
+) -> None: ...
+
+def break_cycles() -> None: ...