summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2022-01-24 13:04:44 -0700
committerGitHub <noreply@github.com>2022-01-24 13:04:44 -0700
commit986a8797b9bf7ccb499861bb46ec8ce13fb720ec (patch)
tree2c00d456bc387b54df3c57e4930c43aadd0240c7 /numpy
parent932202d24c399f46161caa7464446b55e27fa947 (diff)
parent2906e917e6befb68b02eeaec78a7a2a024073686 (diff)
downloadnumpy-986a8797b9bf7ccb499861bb46ec8ce13fb720ec.tar.gz
Merge pull request #20885 from BvB93/param_spec
TYP,ENH: Improve typing with the help of `ParamSpec`
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/function_base.pyi2
-rw-r--r--numpy/lib/shape_base.pyi2
-rw-r--r--numpy/testing/_private/utils.pyi32
-rw-r--r--numpy/typing/tests/data/fail/testing.pyi2
-rw-r--r--numpy/typing/tests/data/reveal/testing.pyi4
5 files changed, 27 insertions, 15 deletions
diff --git a/numpy/lib/function_base.pyi b/numpy/lib/function_base.pyi
index 6e2f886cf..3b40d3f1c 100644
--- a/numpy/lib/function_base.pyi
+++ b/numpy/lib/function_base.pyi
@@ -195,6 +195,8 @@ def asarray_chkfinite(
order: _OrderKACF = ...,
) -> NDArray[Any]: ...
+# TODO: Use PEP 612 `ParamSpec` once mypy supports `Concatenate`
+# xref python/mypy#8645
@overload
def piecewise(
x: _ArrayLike[_SCT],
diff --git a/numpy/lib/shape_base.pyi b/numpy/lib/shape_base.pyi
index 82702e67c..f8f86128c 100644
--- a/numpy/lib/shape_base.pyi
+++ b/numpy/lib/shape_base.pyi
@@ -77,6 +77,8 @@ def put_along_axis(
axis: None | int,
) -> None: ...
+# TODO: Use PEP 612 `ParamSpec` once mypy supports `Concatenate`
+# xref python/mypy#8645
@overload
def apply_along_axis(
func1d: Callable[..., _ArrayLike[_SCT]],
diff --git a/numpy/testing/_private/utils.pyi b/numpy/testing/_private/utils.pyi
index 8117f18ae..f4b22834d 100644
--- a/numpy/testing/_private/utils.pyi
+++ b/numpy/testing/_private/utils.pyi
@@ -20,6 +20,7 @@ from typing import (
Final,
SupportsIndex,
)
+from typing_extensions import ParamSpec
from numpy import generic, dtype, number, object_, bool_, _FloatValue
from numpy.typing import (
@@ -36,6 +37,7 @@ from unittest.case import (
SkipTest as SkipTest,
)
+_P = ParamSpec("_P")
_T = TypeVar("_T")
_ET = TypeVar("_ET", bound=BaseException)
_FT = TypeVar("_FT", bound=Callable[..., Any])
@@ -254,10 +256,10 @@ def raises(*args: type[BaseException]) -> Callable[[_FT], _FT]: ...
@overload
def assert_raises( # type: ignore
expected_exception: type[BaseException] | tuple[type[BaseException], ...],
- callable: Callable[..., Any],
+ callable: Callable[_P, Any],
/,
- *args: Any,
- **kwargs: Any,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
) -> None: ...
@overload
def assert_raises(
@@ -270,10 +272,10 @@ def assert_raises(
def assert_raises_regex(
expected_exception: type[BaseException] | tuple[type[BaseException], ...],
expected_regex: str | bytes | Pattern[Any],
- callable: Callable[..., Any],
+ callable: Callable[_P, Any],
/,
- *args: Any,
- **kwargs: Any,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
) -> None: ...
@overload
def assert_raises_regex(
@@ -336,20 +338,20 @@ def assert_warns(
@overload
def assert_warns(
warning_class: type[Warning],
- func: Callable[..., _T],
+ func: Callable[_P, _T],
/,
- *args: Any,
- **kwargs: Any,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
) -> _T: ...
@overload
def assert_no_warnings() -> contextlib._GeneratorContextManager[None]: ...
@overload
def assert_no_warnings(
- func: Callable[..., _T],
+ func: Callable[_P, _T],
/,
- *args: Any,
- **kwargs: Any,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
) -> _T: ...
@overload
@@ -384,10 +386,10 @@ def temppath(
def assert_no_gc_cycles() -> contextlib._GeneratorContextManager[None]: ...
@overload
def assert_no_gc_cycles(
- func: Callable[..., Any],
+ func: Callable[_P, Any],
/,
- *args: Any,
- **kwargs: Any,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
) -> None: ...
def break_cycles() -> None: ...
diff --git a/numpy/typing/tests/data/fail/testing.pyi b/numpy/typing/tests/data/fail/testing.pyi
index e753a9810..803870e2f 100644
--- a/numpy/typing/tests/data/fail/testing.pyi
+++ b/numpy/typing/tests/data/fail/testing.pyi
@@ -22,5 +22,7 @@ np.testing.assert_array_max_ulp(AR_U, AR_U) # E: incompatible type
np.testing.assert_warns(warning_class=RuntimeWarning, func=func) # E: No overload variant
np.testing.assert_no_warnings(func=func) # E: No overload variant
+np.testing.assert_no_warnings(func, None) # E: Too many arguments
+np.testing.assert_no_warnings(func, test=None) # E: Unexpected keyword argument
np.testing.assert_no_gc_cycles(func=func) # E: No overload variant
diff --git a/numpy/typing/tests/data/reveal/testing.pyi b/numpy/typing/tests/data/reveal/testing.pyi
index fb419d48d..edd4bb3bf 100644
--- a/numpy/typing/tests/data/reveal/testing.pyi
+++ b/numpy/typing/tests/data/reveal/testing.pyi
@@ -154,8 +154,12 @@ reveal_type(np.testing.assert_array_max_ulp(AR_i8, AR_f8, dtype=np.float32)) #
reveal_type(np.testing.assert_warns(RuntimeWarning)) # E: _GeneratorContextManager[None]
reveal_type(np.testing.assert_warns(RuntimeWarning, func3, 5)) # E: bool
+def func4(a: int, b: str) -> bool: ...
+
reveal_type(np.testing.assert_no_warnings()) # E: _GeneratorContextManager[None]
reveal_type(np.testing.assert_no_warnings(func3, 5)) # E: bool
+reveal_type(np.testing.assert_no_warnings(func4, a=1, b="test")) # E: bool
+reveal_type(np.testing.assert_no_warnings(func4, 1, "test")) # E: bool
reveal_type(np.testing.tempdir("test_dir")) # E: _GeneratorContextManager[builtins.str]
reveal_type(np.testing.tempdir(prefix=b"test")) # E: _GeneratorContextManager[builtins.bytes]