diff options
Diffstat (limited to 'numpy/typing/tests/test_typing.py')
-rw-r--r-- | numpy/typing/tests/test_typing.py | 207 |
1 files changed, 151 insertions, 56 deletions
diff --git a/numpy/typing/tests/test_typing.py b/numpy/typing/tests/test_typing.py index 35558c880..2dcfd6082 100644 --- a/numpy/typing/tests/test_typing.py +++ b/numpy/typing/tests/test_typing.py @@ -1,13 +1,17 @@ +from __future__ import annotations + import importlib.util import itertools import os import re import shutil from collections import defaultdict -from typing import Optional, IO, Dict, List +from collections.abc import Iterator +from typing import IO, TYPE_CHECKING import pytest import numpy as np +import numpy.typing as npt from numpy.typing.mypy_plugin import ( _PRECISION_DICT, _EXTENDED_PRECISION_LIST, @@ -21,6 +25,10 @@ except ImportError: else: NO_MYPY = False +if TYPE_CHECKING: + # We need this as annotation, but it's located in a private namespace. + # As a compromise, do *not* import it during runtime + from _pytest.mark.structures import ParameterSet DATA_DIR = os.path.join(os.path.dirname(__file__), "data") PASS_DIR = os.path.join(DATA_DIR, "pass") @@ -32,11 +40,11 @@ CACHE_DIR = os.path.join(DATA_DIR, ".mypy_cache") #: A dictionary with file names as keys and lists of the mypy stdout as values. #: To-be populated by `run_mypy`. -OUTPUT_MYPY: Dict[str, List[str]] = {} +OUTPUT_MYPY: dict[str, list[str]] = {} def _key_func(key: str) -> str: - """Split at the first occurance of the ``:`` character. + """Split at the first occurrence of the ``:`` character. Windows drive-letters (*e.g.* ``C:``) are ignored herein. """ @@ -62,7 +70,10 @@ def run_mypy() -> None: NUMPY_TYPING_TEST_CLEAR_CACHE=0 pytest numpy/typing/tests """ - if os.path.isdir(CACHE_DIR) and bool(os.environ.get("NUMPY_TYPING_TEST_CLEAR_CACHE", True)): + if ( + os.path.isdir(CACHE_DIR) + and bool(os.environ.get("NUMPY_TYPING_TEST_CLEAR_CACHE", True)) + ): shutil.rmtree(CACHE_DIR) for directory in (PASS_DIR, REVEAL_DIR, FAIL_DIR, MISC_DIR): @@ -85,25 +96,19 @@ def run_mypy() -> None: OUTPUT_MYPY.update((k, list(v)) for k, v in iterator if k) -def get_test_cases(directory): +def get_test_cases(directory: str) -> Iterator[ParameterSet]: for root, _, files in os.walk(directory): for fname in files: - if os.path.splitext(fname)[-1] == ".py": + short_fname, ext = os.path.splitext(fname) + if ext in (".pyi", ".py"): fullpath = os.path.join(root, fname) - # Use relative path for nice py.test name - relpath = os.path.relpath(fullpath, start=directory) - - yield pytest.param( - fullpath, - # Manually specify a name for the test - id=relpath, - ) + yield pytest.param(fullpath, id=short_fname) @pytest.mark.slow @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed") @pytest.mark.parametrize("path", get_test_cases(PASS_DIR)) -def test_success(path): +def test_success(path) -> None: # Alias `OUTPUT_MYPY` so that it appears in the local namespace output_mypy = OUTPUT_MYPY if path in output_mypy: @@ -115,7 +120,7 @@ def test_success(path): @pytest.mark.slow @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed") @pytest.mark.parametrize("path", get_test_cases(FAIL_DIR)) -def test_fail(path): +def test_fail(path: str) -> None: __tracebackhide__ = True with open(path) as fin: @@ -138,38 +143,86 @@ def test_fail(path): for i, line in enumerate(lines): lineno = i + 1 - if line.startswith('#') or (" E:" not in line and lineno not in errors): + if ( + line.startswith('#') + or (" E:" not in line and lineno not in errors) + ): continue target_line = lines[lineno - 1] if "# E:" in target_line: - marker = target_line.split("# E:")[-1].strip() - expected_error = errors.get(lineno) - _test_fail(path, marker, expected_error, lineno) + expression, _, marker = target_line.partition(" # E: ") + expected_error = errors[lineno].strip() + marker = marker.strip() + _test_fail(path, expression, marker, expected_error, lineno) else: - pytest.fail(f"Unexpected mypy output\n\n{errors[lineno]}") + pytest.fail( + f"Unexpected mypy output at line {lineno}\n\n{errors[lineno]}" + ) _FAIL_MSG1 = """Extra error at line {} +Expression: {} Extra error: {!r} """ _FAIL_MSG2 = """Error mismatch at line {} +Expression: {} Expected error: {!r} Observed error: {!r} """ -def _test_fail(path: str, error: str, expected_error: Optional[str], lineno: int) -> None: +def _test_fail( + path: str, + expression: str, + error: str, + expected_error: None | str, + lineno: int, +) -> None: if expected_error is None: - raise AssertionError(_FAIL_MSG1.format(lineno, error)) + raise AssertionError(_FAIL_MSG1.format(lineno, expression, error)) elif error not in expected_error: - raise AssertionError(_FAIL_MSG2.format(lineno, expected_error, error)) + raise AssertionError(_FAIL_MSG2.format( + lineno, expression, expected_error, error + )) + + +def _construct_ctypes_dict() -> dict[str, str]: + dct = { + "ubyte": "c_ubyte", + "ushort": "c_ushort", + "uintc": "c_uint", + "uint": "c_ulong", + "ulonglong": "c_ulonglong", + "byte": "c_byte", + "short": "c_short", + "intc": "c_int", + "int_": "c_long", + "longlong": "c_longlong", + "single": "c_float", + "double": "c_double", + "longdouble": "c_longdouble", + } + + # Match `ctypes` names to the first ctypes type with a given kind and + # precision, e.g. {"c_double": "c_double", "c_longdouble": "c_double"} + # if both types represent 64-bit floats. + # In this context "first" is defined by the order of `dct` + ret = {} + visited: dict[tuple[str, int], str] = {} + for np_name, ct_name in dct.items(): + np_scalar = getattr(np, np_name)() + + # Find the first `ctypes` type for a given `kind`/`itemsize` combo + key = (np_scalar.dtype.kind, np_scalar.dtype.itemsize) + ret[ct_name] = visited.setdefault(key, f"ctypes.{ct_name}") + return ret -def _construct_format_dict(): +def _construct_format_dict() -> dict[str, str]: dct = {k.split(".")[-1]: v.replace("numpy", "numpy.typing") for k, v in _PRECISION_DICT.items()} @@ -193,12 +246,18 @@ def _construct_format_dict(): "float96": "numpy.floating[numpy.typing._96Bit]", "float128": "numpy.floating[numpy.typing._128Bit]", "float256": "numpy.floating[numpy.typing._256Bit]", - "complex64": "numpy.complexfloating[numpy.typing._32Bit, numpy.typing._32Bit]", - "complex128": "numpy.complexfloating[numpy.typing._64Bit, numpy.typing._64Bit]", - "complex160": "numpy.complexfloating[numpy.typing._80Bit, numpy.typing._80Bit]", - "complex192": "numpy.complexfloating[numpy.typing._96Bit, numpy.typing._96Bit]", - "complex256": "numpy.complexfloating[numpy.typing._128Bit, numpy.typing._128Bit]", - "complex512": "numpy.complexfloating[numpy.typing._256Bit, numpy.typing._256Bit]", + "complex64": ("numpy.complexfloating" + "[numpy.typing._32Bit, numpy.typing._32Bit]"), + "complex128": ("numpy.complexfloating" + "[numpy.typing._64Bit, numpy.typing._64Bit]"), + "complex160": ("numpy.complexfloating" + "[numpy.typing._80Bit, numpy.typing._80Bit]"), + "complex192": ("numpy.complexfloating" + "[numpy.typing._96Bit, numpy.typing._96Bit]"), + "complex256": ("numpy.complexfloating" + "[numpy.typing._128Bit, numpy.typing._128Bit]"), + "complex512": ("numpy.complexfloating" + "[numpy.typing._256Bit, numpy.typing._256Bit]"), "ubyte": f"numpy.unsignedinteger[{dct['_NBitByte']}]", "ushort": f"numpy.unsignedinteger[{dct['_NBitShort']}]", @@ -217,9 +276,14 @@ def _construct_format_dict(): "single": f"numpy.floating[{dct['_NBitSingle']}]", "double": f"numpy.floating[{dct['_NBitDouble']}]", "longdouble": f"numpy.floating[{dct['_NBitLongDouble']}]", - "csingle": f"numpy.complexfloating[{dct['_NBitSingle']}, {dct['_NBitSingle']}]", - "cdouble": f"numpy.complexfloating[{dct['_NBitDouble']}, {dct['_NBitDouble']}]", - "clongdouble": f"numpy.complexfloating[{dct['_NBitLongDouble']}, {dct['_NBitLongDouble']}]", + "csingle": ("numpy.complexfloating" + f"[{dct['_NBitSingle']}, {dct['_NBitSingle']}]"), + "cdouble": ("numpy.complexfloating" + f"[{dct['_NBitDouble']}, {dct['_NBitDouble']}]"), + "clongdouble": ( + "numpy.complexfloating" + f"[{dct['_NBitLongDouble']}, {dct['_NBitLongDouble']}]" + ), # numpy.typing "_NBitInt": dct['_NBitInt'], @@ -231,40 +295,49 @@ def _construct_format_dict(): #: A dictionary with all supported format keys (as keys) #: and matching values -FORMAT_DICT: Dict[str, str] = _construct_format_dict() +FORMAT_DICT: dict[str, str] = _construct_format_dict() +FORMAT_DICT.update(_construct_ctypes_dict()) -def _parse_reveals(file: IO[str]) -> List[str]: - """Extract and parse all ``" # E: "`` comments from the passed file-like object. +def _parse_reveals(file: IO[str]) -> tuple[npt.NDArray[np.str_], list[str]]: + """Extract and parse all ``" # E: "`` comments from the passed + file-like object. - All format keys will be substituted for their respective value from `FORMAT_DICT`, - *e.g.* ``"{float64}"`` becomes ``"numpy.floating[numpy.typing._64Bit]"``. + All format keys will be substituted for their respective value + from `FORMAT_DICT`, *e.g.* ``"{float64}"`` becomes + ``"numpy.floating[numpy.typing._64Bit]"``. """ string = file.read().replace("*", "") - # Grab all `# E:`-based comments - comments_array = np.char.partition(string.split("\n"), sep=" # E: ")[:, 2] + # Grab all `# E:`-based comments and matching expressions + expression_array, _, comments_array = np.char.partition( + string.split("\n"), sep=" # E: " + ).T comments = "/n".join(comments_array) - # Only search for the `{*}` pattern within comments, - # otherwise there is the risk of accidently grabbing dictionaries and sets + # Only search for the `{*}` pattern within comments, otherwise + # there is the risk of accidentally grabbing dictionaries and sets key_set = set(re.findall(r"\{(.*?)\}", comments)) kwargs = { - k: FORMAT_DICT.get(k, f"<UNRECOGNIZED FORMAT KEY {k!r}>") for k in key_set + k: FORMAT_DICT.get(k, f"<UNRECOGNIZED FORMAT KEY {k!r}>") for + k in key_set } fmt_str = comments.format(**kwargs) - return fmt_str.split("/n") + return expression_array, fmt_str.split("/n") @pytest.mark.slow @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed") @pytest.mark.parametrize("path", get_test_cases(REVEAL_DIR)) -def test_reveal(path): +def test_reveal(path: str) -> None: + """Validate that mypy correctly infers the return-types of + the expressions in `path`. + """ __tracebackhide__ = True with open(path) as fin: - lines = _parse_reveals(fin) + expression_array, reveal_list = _parse_reveals(fin) output_mypy = OUTPUT_MYPY assert path in output_mypy @@ -279,29 +352,47 @@ def test_reveal(path): lineno = int(match.group('lineno')) - 1 assert "Revealed type is" in error_line - marker = lines[lineno] - _test_reveal(path, marker, error_line, 1 + lineno) + marker = reveal_list[lineno] + expression = expression_array[lineno] + _test_reveal(path, expression, marker, error_line, 1 + lineno) _REVEAL_MSG = """Reveal mismatch at line {} +Expression: {} Expected reveal: {!r} Observed reveal: {!r} """ -def _test_reveal(path: str, reveal: str, expected_reveal: str, lineno: int) -> None: +def _test_reveal( + path: str, + expression: str, + reveal: str, + expected_reveal: str, + lineno: int, +) -> None: + """Error-reporting helper function for `test_reveal`.""" if reveal not in expected_reveal: - raise AssertionError(_REVEAL_MSG.format(lineno, expected_reveal, reveal)) + raise AssertionError( + _REVEAL_MSG.format(lineno, expression, expected_reveal, reveal) + ) @pytest.mark.slow @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed") @pytest.mark.parametrize("path", get_test_cases(PASS_DIR)) -def test_code_runs(path): +def test_code_runs(path: str) -> None: + """Validate that the code in `path` properly during runtime.""" path_without_extension, _ = os.path.splitext(path) dirname, filename = path.split(os.sep)[-2:] - spec = importlib.util.spec_from_file_location(f"{dirname}.{filename}", path) + + spec = importlib.util.spec_from_file_location( + f"{dirname}.{filename}", path + ) + assert spec is not None + assert spec.loader is not None + test_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(test_module) @@ -325,15 +416,19 @@ LINENO_MAPPING = { @pytest.mark.slow @pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed") def test_extended_precision() -> None: - path = os.path.join(MISC_DIR, "extended_precision.py") + path = os.path.join(MISC_DIR, "extended_precision.pyi") output_mypy = OUTPUT_MYPY assert path in output_mypy + with open(path, "r") as f: + expression_list = f.readlines() + for _msg in output_mypy[path]: *_, _lineno, msg_typ, msg = _msg.split(":") msg = _strip_filename(msg) lineno = int(_lineno) + expression = expression_list[lineno - 1].rstrip("\n") msg_typ = msg_typ.strip() assert msg_typ in {"error", "note"} @@ -342,8 +437,8 @@ def test_extended_precision() -> None: raise ValueError(f"Unexpected reveal line format: {lineno}") else: marker = FORMAT_DICT[LINENO_MAPPING[lineno]] - _test_reveal(path, marker, msg, lineno) + _test_reveal(path, expression, marker, msg, lineno) else: if msg_typ == "error": marker = "Module has no attribute" - _test_fail(path, marker, msg, lineno) + _test_fail(path, expression, marker, msg, lineno) |