summaryrefslogtreecommitdiff
path: root/numpy/typing/tests/test_typing.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/typing/tests/test_typing.py')
-rw-r--r--numpy/typing/tests/test_typing.py104
1 files changed, 79 insertions, 25 deletions
diff --git a/numpy/typing/tests/test_typing.py b/numpy/typing/tests/test_typing.py
index f303ebea3..fe58a8f4c 100644
--- a/numpy/typing/tests/test_typing.py
+++ b/numpy/typing/tests/test_typing.py
@@ -11,6 +11,7 @@ from typing import IO, TYPE_CHECKING
import pytest
import numpy as np
+import numpy.typing as npt
from numpy.typing.mypy_plugin import (
_PRECISION_DICT,
_EXTENDED_PRECISION_LIST,
@@ -57,6 +58,11 @@ def _strip_filename(msg: str) -> str:
return tail.split(":", 1)[-1]
+def strip_func(match: re.Match[str]) -> str:
+ """`re.sub` helper function for stripping module names."""
+ return match.groups()[1]
+
+
@pytest.mark.slow
@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
@pytest.fixture(scope="module", autouse=True)
@@ -98,16 +104,10 @@ def run_mypy() -> None:
def get_test_cases(directory: str) -> Iterator[ParameterSet]:
for root, _, files in os.walk(directory):
for fname in files:
- if os.path.splitext(fname)[-1] in (".pyi", ".py"):
+ short_fname, ext = os.path.splitext(fname)
+ if ext in (".pyi", ".py"):
fullpath = os.path.join(root, fname)
- # Use relative path for nice py.test name
- relpath = os.path.relpath(fullpath, start=directory)
-
- yield pytest.param(
- fullpath,
- # Manually specify a name for the test
- id=relpath,
- )
+ yield pytest.param(fullpath, id=short_fname)
@pytest.mark.slow
@@ -156,9 +156,10 @@ def test_fail(path: str) -> None:
target_line = lines[lineno - 1]
if "# E:" in target_line:
- marker = target_line.split("# E:")[-1].strip()
- expected_error = errors.get(lineno)
- _test_fail(path, marker, expected_error, lineno)
+ expression, _, marker = target_line.partition(" # E: ")
+ expected_error = errors[lineno].strip()
+ marker = marker.strip()
+ _test_fail(path, expression, marker, expected_error, lineno)
else:
pytest.fail(
f"Unexpected mypy output at line {lineno}\n\n{errors[lineno]}"
@@ -167,11 +168,13 @@ def test_fail(path: str) -> None:
_FAIL_MSG1 = """Extra error at line {}
+Expression: {}
Extra error: {!r}
"""
_FAIL_MSG2 = """Error mismatch at line {}
+Expression: {}
Expected error: {!r}
Observed error: {!r}
"""
@@ -179,14 +182,49 @@ Observed error: {!r}
def _test_fail(
path: str,
+ expression: str,
error: str,
expected_error: None | str,
lineno: int,
) -> None:
if expected_error is None:
- raise AssertionError(_FAIL_MSG1.format(lineno, error))
+ raise AssertionError(_FAIL_MSG1.format(lineno, expression, error))
elif error not in expected_error:
- raise AssertionError(_FAIL_MSG2.format(lineno, expected_error, error))
+ raise AssertionError(_FAIL_MSG2.format(
+ lineno, expression, expected_error, error
+ ))
+
+
+def _construct_ctypes_dict() -> dict[str, str]:
+ dct = {
+ "ubyte": "c_ubyte",
+ "ushort": "c_ushort",
+ "uintc": "c_uint",
+ "uint": "c_ulong",
+ "ulonglong": "c_ulonglong",
+ "byte": "c_byte",
+ "short": "c_short",
+ "intc": "c_int",
+ "int_": "c_long",
+ "longlong": "c_longlong",
+ "single": "c_float",
+ "double": "c_double",
+ "longdouble": "c_longdouble",
+ }
+
+ # Match `ctypes` names to the first ctypes type with a given kind and
+ # precision, e.g. {"c_double": "c_double", "c_longdouble": "c_double"}
+ # if both types represent 64-bit floats.
+ # In this context "first" is defined by the order of `dct`
+ ret = {}
+ visited: dict[tuple[str, int], str] = {}
+ for np_name, ct_name in dct.items():
+ np_scalar = getattr(np, np_name)()
+
+ # Find the first `ctypes` type for a given `kind`/`itemsize` combo
+ key = (np_scalar.dtype.kind, np_scalar.dtype.itemsize)
+ ret[ct_name] = visited.setdefault(key, f"ctypes.{ct_name}")
+ return ret
def _construct_format_dict() -> dict[str, str]:
@@ -263,9 +301,10 @@ def _construct_format_dict() -> dict[str, str]:
#: A dictionary with all supported format keys (as keys)
#: and matching values
FORMAT_DICT: dict[str, str] = _construct_format_dict()
+FORMAT_DICT.update(_construct_ctypes_dict())
-def _parse_reveals(file: IO[str]) -> list[str]:
+def _parse_reveals(file: IO[str]) -> tuple[npt.NDArray[np.str_], list[str]]:
"""Extract and parse all ``" # E: "`` comments from the passed
file-like object.
@@ -275,8 +314,10 @@ def _parse_reveals(file: IO[str]) -> list[str]:
"""
string = file.read().replace("*", "")
- # Grab all `# E:`-based comments
- comments_array = np.char.partition(string.split("\n"), sep=" # E: ")[:, 2]
+ # Grab all `# E:`-based comments and matching expressions
+ expression_array, _, comments_array = np.char.partition(
+ string.split("\n"), sep=" # E: "
+ ).T
comments = "/n".join(comments_array)
# Only search for the `{*}` pattern within comments, otherwise
@@ -288,7 +329,7 @@ def _parse_reveals(file: IO[str]) -> list[str]:
}
fmt_str = comments.format(**kwargs)
- return fmt_str.split("/n")
+ return expression_array, fmt_str.split("/n")
@pytest.mark.slow
@@ -301,7 +342,7 @@ def test_reveal(path: str) -> None:
__tracebackhide__ = True
with open(path) as fin:
- lines = _parse_reveals(fin)
+ expression_array, reveal_list = _parse_reveals(fin)
output_mypy = OUTPUT_MYPY
assert path in output_mypy
@@ -316,12 +357,14 @@ def test_reveal(path: str) -> None:
lineno = int(match.group('lineno')) - 1
assert "Revealed type is" in error_line
- marker = lines[lineno]
- _test_reveal(path, marker, error_line, 1 + lineno)
+ marker = reveal_list[lineno]
+ expression = expression_array[lineno]
+ _test_reveal(path, expression, marker, error_line, 1 + lineno)
_REVEAL_MSG = """Reveal mismatch at line {}
+Expression: {}
Expected reveal: {!r}
Observed reveal: {!r}
"""
@@ -329,14 +372,21 @@ Observed reveal: {!r}
def _test_reveal(
path: str,
+ expression: str,
reveal: str,
expected_reveal: str,
lineno: int,
) -> None:
"""Error-reporting helper function for `test_reveal`."""
- if reveal not in expected_reveal:
+ strip_pattern = re.compile(r"(\w+\.)+(\w+)")
+ stripped_reveal = strip_pattern.sub(strip_func, reveal)
+ stripped_expected_reveal = strip_pattern.sub(strip_func, expected_reveal)
+ if stripped_reveal not in stripped_expected_reveal:
raise AssertionError(
- _REVEAL_MSG.format(lineno, expected_reveal, reveal)
+ _REVEAL_MSG.format(lineno,
+ expression,
+ stripped_expected_reveal,
+ stripped_reveal)
)
@@ -381,11 +431,15 @@ def test_extended_precision() -> None:
output_mypy = OUTPUT_MYPY
assert path in output_mypy
+ with open(path, "r") as f:
+ expression_list = f.readlines()
+
for _msg in output_mypy[path]:
*_, _lineno, msg_typ, msg = _msg.split(":")
msg = _strip_filename(msg)
lineno = int(_lineno)
+ expression = expression_list[lineno - 1].rstrip("\n")
msg_typ = msg_typ.strip()
assert msg_typ in {"error", "note"}
@@ -394,8 +448,8 @@ def test_extended_precision() -> None:
raise ValueError(f"Unexpected reveal line format: {lineno}")
else:
marker = FORMAT_DICT[LINENO_MAPPING[lineno]]
- _test_reveal(path, marker, msg, lineno)
+ _test_reveal(path, expression, marker, msg, lineno)
else:
if msg_typ == "error":
marker = "Module has no attribute"
- _test_fail(path, marker, msg, lineno)
+ _test_fail(path, expression, marker, msg, lineno)