summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/typing/tests/test_typing.py43
1 files changed, 28 insertions, 15 deletions
diff --git a/numpy/typing/tests/test_typing.py b/numpy/typing/tests/test_typing.py
index 4dd6530ff..71e459df6 100644
--- a/numpy/typing/tests/test_typing.py
+++ b/numpy/typing/tests/test_typing.py
@@ -11,6 +11,7 @@ from typing import IO, TYPE_CHECKING
import pytest
import numpy as np
+import numpy.typing as npt
from numpy.typing.mypy_plugin import (
_PRECISION_DICT,
_EXTENDED_PRECISION_LIST,
@@ -150,9 +151,9 @@ def test_fail(path: str) -> None:
target_line = lines[lineno - 1]
if "# E:" in target_line:
- marker = target_line.split("# E:")[-1].strip()
- expected_error = errors.get(lineno)
- _test_fail(path, marker, expected_error, lineno)
+ expression, _, marker = target_line.partition(" # E: ")
+ expected_error = errors[lineno]
+ _test_fail(path, expression, marker.strip(), expected_error.strip(), lineno)
else:
pytest.fail(
f"Unexpected mypy output at line {lineno}\n\n{errors[lineno]}"
@@ -161,11 +162,13 @@ def test_fail(path: str) -> None:
_FAIL_MSG1 = """Extra error at line {}
+Expression: {}
Extra error: {!r}
"""
_FAIL_MSG2 = """Error mismatch at line {}
+Expression: {}
Expected error: {!r}
Observed error: {!r}
"""
@@ -173,14 +176,15 @@ Observed error: {!r}
def _test_fail(
path: str,
+ expression: str,
error: str,
expected_error: None | str,
lineno: int,
) -> None:
if expected_error is None:
- raise AssertionError(_FAIL_MSG1.format(lineno, error))
+ raise AssertionError(_FAIL_MSG1.format(lineno, expression, error))
elif error not in expected_error:
- raise AssertionError(_FAIL_MSG2.format(lineno, expected_error, error))
+ raise AssertionError(_FAIL_MSG2.format(lineno, expression, expected_error, error))
def _construct_format_dict() -> dict[str, str]:
@@ -259,7 +263,7 @@ def _construct_format_dict() -> dict[str, str]:
FORMAT_DICT: dict[str, str] = _construct_format_dict()
-def _parse_reveals(file: IO[str]) -> list[str]:
+def _parse_reveals(file: IO[str]) -> tuple[npt.NDArray[np.str_], list[str]]:
"""Extract and parse all ``" # E: "`` comments from the passed
file-like object.
@@ -269,8 +273,10 @@ def _parse_reveals(file: IO[str]) -> list[str]:
"""
string = file.read().replace("*", "")
- # Grab all `# E:`-based comments
- comments_array = np.char.partition(string.split("\n"), sep=" # E: ")[:, 2]
+ # Grab all `# E:`-based comments and matching expressions
+ expression_array, _, comments_array = np.char.partition(
+ string.split("\n"), sep=" # E: "
+ ).T
comments = "/n".join(comments_array)
# Only search for the `{*}` pattern within comments, otherwise
@@ -282,7 +288,7 @@ def _parse_reveals(file: IO[str]) -> list[str]:
}
fmt_str = comments.format(**kwargs)
- return fmt_str.split("/n")
+ return expression_array, fmt_str.split("/n")
@pytest.mark.slow
@@ -295,7 +301,7 @@ def test_reveal(path: str) -> None:
__tracebackhide__ = True
with open(path) as fin:
- lines = _parse_reveals(fin)
+ expression_array, reveal_list = _parse_reveals(fin)
output_mypy = OUTPUT_MYPY
assert path in output_mypy
@@ -310,12 +316,14 @@ def test_reveal(path: str) -> None:
lineno = int(match.group('lineno')) - 1
assert "Revealed type is" in error_line
- marker = lines[lineno]
- _test_reveal(path, marker, error_line, 1 + lineno)
+ marker = reveal_list[lineno]
+ expression = expression_array[lineno]
+ _test_reveal(path, expression, marker, error_line, 1 + lineno)
_REVEAL_MSG = """Reveal mismatch at line {}
+Expression: {}
Expected reveal: {!r}
Observed reveal: {!r}
"""
@@ -323,6 +331,7 @@ Observed reveal: {!r}
def _test_reveal(
path: str,
+ expression: str,
reveal: str,
expected_reveal: str,
lineno: int,
@@ -330,7 +339,7 @@ def _test_reveal(
"""Error-reporting helper function for `test_reveal`."""
if reveal not in expected_reveal:
raise AssertionError(
- _REVEAL_MSG.format(lineno, expected_reveal, reveal)
+ _REVEAL_MSG.format(lineno, expression, expected_reveal, reveal)
)
@@ -375,11 +384,15 @@ def test_extended_precision() -> None:
output_mypy = OUTPUT_MYPY
assert path in output_mypy
+ with open(path, "r") as f:
+ expression_list = f.readlines()
+
for _msg in output_mypy[path]:
*_, _lineno, msg_typ, msg = _msg.split(":")
msg = _strip_filename(msg)
lineno = int(_lineno)
+ expression = expression_list[lineno - 1].rstrip("\n")
msg_typ = msg_typ.strip()
assert msg_typ in {"error", "note"}
@@ -388,8 +401,8 @@ def test_extended_precision() -> None:
raise ValueError(f"Unexpected reveal line format: {lineno}")
else:
marker = FORMAT_DICT[LINENO_MAPPING[lineno]]
- _test_reveal(path, marker, msg, lineno)
+ _test_reveal(path, expression, marker, msg, lineno)
else:
if msg_typ == "error":
marker = "Module has no attribute"
- _test_fail(path, marker, msg, lineno)
+ _test_fail(path, expression, marker, msg, lineno)