diff options
Diffstat (limited to 'numpy/typing/tests/test_typing.py')
-rw-r--r-- | numpy/typing/tests/test_typing.py | 104 |
1 files changed, 79 insertions, 25 deletions
diff --git a/numpy/typing/tests/test_typing.py b/numpy/typing/tests/test_typing.py index f303ebea3..fe58a8f4c 100644 --- a/numpy/typing/tests/test_typing.py +++ b/numpy/typing/tests/test_typing.py @@ -11,6 +11,7 @@ from typing import IO, TYPE_CHECKING import pytest import numpy as np +import numpy.typing as npt from numpy.typing.mypy_plugin import ( _PRECISION_DICT, _EXTENDED_PRECISION_LIST, @@ -57,6 +58,11 @@ def _strip_filename(msg: str) -> str: return tail.split(":", 1)[-1] +def strip_func(match: re.Match[str]) -> str: + """`re.sub` helper function for stripping module names.""" + return match.groups()[1] + + @pytest.mark.slow @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed") @pytest.fixture(scope="module", autouse=True) @@ -98,16 +104,10 @@ def run_mypy() -> None: def get_test_cases(directory: str) -> Iterator[ParameterSet]: for root, _, files in os.walk(directory): for fname in files: - if os.path.splitext(fname)[-1] in (".pyi", ".py"): + short_fname, ext = os.path.splitext(fname) + if ext in (".pyi", ".py"): fullpath = os.path.join(root, fname) - # Use relative path for nice py.test name - relpath = os.path.relpath(fullpath, start=directory) - - yield pytest.param( - fullpath, - # Manually specify a name for the test - id=relpath, - ) + yield pytest.param(fullpath, id=short_fname) @pytest.mark.slow @@ -156,9 +156,10 @@ def test_fail(path: str) -> None: target_line = lines[lineno - 1] if "# E:" in target_line: - marker = target_line.split("# E:")[-1].strip() - expected_error = errors.get(lineno) - _test_fail(path, marker, expected_error, lineno) + expression, _, marker = target_line.partition(" # E: ") + expected_error = errors[lineno].strip() + marker = marker.strip() + _test_fail(path, expression, marker, expected_error, lineno) else: pytest.fail( f"Unexpected mypy output at line {lineno}\n\n{errors[lineno]}" @@ -167,11 +168,13 @@ def test_fail(path: str) -> None: _FAIL_MSG1 = """Extra error at line {} +Expression: {} Extra error: {!r} """ _FAIL_MSG2 = """Error mismatch at line {} +Expression: {} Expected error: {!r} Observed error: {!r} """ @@ -179,14 +182,49 @@ Observed error: {!r} def _test_fail( path: str, + expression: str, error: str, expected_error: None | str, lineno: int, ) -> None: if expected_error is None: - raise AssertionError(_FAIL_MSG1.format(lineno, error)) + raise AssertionError(_FAIL_MSG1.format(lineno, expression, error)) elif error not in expected_error: - raise AssertionError(_FAIL_MSG2.format(lineno, expected_error, error)) + raise AssertionError(_FAIL_MSG2.format( + lineno, expression, expected_error, error + )) + + +def _construct_ctypes_dict() -> dict[str, str]: + dct = { + "ubyte": "c_ubyte", + "ushort": "c_ushort", + "uintc": "c_uint", + "uint": "c_ulong", + "ulonglong": "c_ulonglong", + "byte": "c_byte", + "short": "c_short", + "intc": "c_int", + "int_": "c_long", + "longlong": "c_longlong", + "single": "c_float", + "double": "c_double", + "longdouble": "c_longdouble", + } + + # Match `ctypes` names to the first ctypes type with a given kind and + # precision, e.g. {"c_double": "c_double", "c_longdouble": "c_double"} + # if both types represent 64-bit floats. + # In this context "first" is defined by the order of `dct` + ret = {} + visited: dict[tuple[str, int], str] = {} + for np_name, ct_name in dct.items(): + np_scalar = getattr(np, np_name)() + + # Find the first `ctypes` type for a given `kind`/`itemsize` combo + key = (np_scalar.dtype.kind, np_scalar.dtype.itemsize) + ret[ct_name] = visited.setdefault(key, f"ctypes.{ct_name}") + return ret def _construct_format_dict() -> dict[str, str]: @@ -263,9 +301,10 @@ def _construct_format_dict() -> dict[str, str]: #: A dictionary with all supported format keys (as keys) #: and matching values FORMAT_DICT: dict[str, str] = _construct_format_dict() +FORMAT_DICT.update(_construct_ctypes_dict()) -def _parse_reveals(file: IO[str]) -> list[str]: +def _parse_reveals(file: IO[str]) -> tuple[npt.NDArray[np.str_], list[str]]: """Extract and parse all ``" # E: "`` comments from the passed file-like object. @@ -275,8 +314,10 @@ def _parse_reveals(file: IO[str]) -> list[str]: """ string = file.read().replace("*", "") - # Grab all `# E:`-based comments - comments_array = np.char.partition(string.split("\n"), sep=" # E: ")[:, 2] + # Grab all `# E:`-based comments and matching expressions + expression_array, _, comments_array = np.char.partition( + string.split("\n"), sep=" # E: " + ).T comments = "/n".join(comments_array) # Only search for the `{*}` pattern within comments, otherwise @@ -288,7 +329,7 @@ def _parse_reveals(file: IO[str]) -> list[str]: } fmt_str = comments.format(**kwargs) - return fmt_str.split("/n") + return expression_array, fmt_str.split("/n") @pytest.mark.slow @@ -301,7 +342,7 @@ def test_reveal(path: str) -> None: __tracebackhide__ = True with open(path) as fin: - lines = _parse_reveals(fin) + expression_array, reveal_list = _parse_reveals(fin) output_mypy = OUTPUT_MYPY assert path in output_mypy @@ -316,12 +357,14 @@ def test_reveal(path: str) -> None: lineno = int(match.group('lineno')) - 1 assert "Revealed type is" in error_line - marker = lines[lineno] - _test_reveal(path, marker, error_line, 1 + lineno) + marker = reveal_list[lineno] + expression = expression_array[lineno] + _test_reveal(path, expression, marker, error_line, 1 + lineno) _REVEAL_MSG = """Reveal mismatch at line {} +Expression: {} Expected reveal: {!r} Observed reveal: {!r} """ @@ -329,14 +372,21 @@ Observed reveal: {!r} def _test_reveal( path: str, + expression: str, reveal: str, expected_reveal: str, lineno: int, ) -> None: """Error-reporting helper function for `test_reveal`.""" - if reveal not in expected_reveal: + strip_pattern = re.compile(r"(\w+\.)+(\w+)") + stripped_reveal = strip_pattern.sub(strip_func, reveal) + stripped_expected_reveal = strip_pattern.sub(strip_func, expected_reveal) + if stripped_reveal not in stripped_expected_reveal: raise AssertionError( - _REVEAL_MSG.format(lineno, expected_reveal, reveal) + _REVEAL_MSG.format(lineno, + expression, + stripped_expected_reveal, + stripped_reveal) ) @@ -381,11 +431,15 @@ def test_extended_precision() -> None: output_mypy = OUTPUT_MYPY assert path in output_mypy + with open(path, "r") as f: + expression_list = f.readlines() + for _msg in output_mypy[path]: *_, _lineno, msg_typ, msg = _msg.split(":") msg = _strip_filename(msg) lineno = int(_lineno) + expression = expression_list[lineno - 1].rstrip("\n") msg_typ = msg_typ.strip() assert msg_typ in {"error", "note"} @@ -394,8 +448,8 @@ def test_extended_precision() -> None: raise ValueError(f"Unexpected reveal line format: {lineno}") else: marker = FORMAT_DICT[LINENO_MAPPING[lineno]] - _test_reveal(path, marker, msg, lineno) + _test_reveal(path, expression, marker, msg, lineno) else: if msg_typ == "error": marker = "Module has no attribute" - _test_fail(path, marker, msg, lineno) + _test_fail(path, expression, marker, msg, lineno) |