summaryrefslogtreecommitdiff
path: root/numpy/typing/tests/test_typing.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/typing/tests/test_typing.py')
-rw-r--r--numpy/typing/tests/test_typing.py207
1 files changed, 151 insertions, 56 deletions
diff --git a/numpy/typing/tests/test_typing.py b/numpy/typing/tests/test_typing.py
index 35558c880..2dcfd6082 100644
--- a/numpy/typing/tests/test_typing.py
+++ b/numpy/typing/tests/test_typing.py
@@ -1,13 +1,17 @@
+from __future__ import annotations
+
import importlib.util
import itertools
import os
import re
import shutil
from collections import defaultdict
-from typing import Optional, IO, Dict, List
+from collections.abc import Iterator
+from typing import IO, TYPE_CHECKING
import pytest
import numpy as np
+import numpy.typing as npt
from numpy.typing.mypy_plugin import (
_PRECISION_DICT,
_EXTENDED_PRECISION_LIST,
@@ -21,6 +25,10 @@ except ImportError:
else:
NO_MYPY = False
+if TYPE_CHECKING:
+ # We need this as annotation, but it's located in a private namespace.
+ # As a compromise, do *not* import it during runtime
+ from _pytest.mark.structures import ParameterSet
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
PASS_DIR = os.path.join(DATA_DIR, "pass")
@@ -32,11 +40,11 @@ CACHE_DIR = os.path.join(DATA_DIR, ".mypy_cache")
#: A dictionary with file names as keys and lists of the mypy stdout as values.
#: To-be populated by `run_mypy`.
-OUTPUT_MYPY: Dict[str, List[str]] = {}
+OUTPUT_MYPY: dict[str, list[str]] = {}
def _key_func(key: str) -> str:
- """Split at the first occurance of the ``:`` character.
+ """Split at the first occurrence of the ``:`` character.
Windows drive-letters (*e.g.* ``C:``) are ignored herein.
"""
@@ -62,7 +70,10 @@ def run_mypy() -> None:
NUMPY_TYPING_TEST_CLEAR_CACHE=0 pytest numpy/typing/tests
"""
- if os.path.isdir(CACHE_DIR) and bool(os.environ.get("NUMPY_TYPING_TEST_CLEAR_CACHE", True)):
+ if (
+ os.path.isdir(CACHE_DIR)
+ and bool(os.environ.get("NUMPY_TYPING_TEST_CLEAR_CACHE", True))
+ ):
shutil.rmtree(CACHE_DIR)
for directory in (PASS_DIR, REVEAL_DIR, FAIL_DIR, MISC_DIR):
@@ -85,25 +96,19 @@ def run_mypy() -> None:
OUTPUT_MYPY.update((k, list(v)) for k, v in iterator if k)
-def get_test_cases(directory):
+def get_test_cases(directory: str) -> Iterator[ParameterSet]:
for root, _, files in os.walk(directory):
for fname in files:
- if os.path.splitext(fname)[-1] == ".py":
+ short_fname, ext = os.path.splitext(fname)
+ if ext in (".pyi", ".py"):
fullpath = os.path.join(root, fname)
- # Use relative path for nice py.test name
- relpath = os.path.relpath(fullpath, start=directory)
-
- yield pytest.param(
- fullpath,
- # Manually specify a name for the test
- id=relpath,
- )
+ yield pytest.param(fullpath, id=short_fname)
@pytest.mark.slow
@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
@pytest.mark.parametrize("path", get_test_cases(PASS_DIR))
-def test_success(path):
+def test_success(path) -> None:
# Alias `OUTPUT_MYPY` so that it appears in the local namespace
output_mypy = OUTPUT_MYPY
if path in output_mypy:
@@ -115,7 +120,7 @@ def test_success(path):
@pytest.mark.slow
@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
@pytest.mark.parametrize("path", get_test_cases(FAIL_DIR))
-def test_fail(path):
+def test_fail(path: str) -> None:
__tracebackhide__ = True
with open(path) as fin:
@@ -138,38 +143,86 @@ def test_fail(path):
for i, line in enumerate(lines):
lineno = i + 1
- if line.startswith('#') or (" E:" not in line and lineno not in errors):
+ if (
+ line.startswith('#')
+ or (" E:" not in line and lineno not in errors)
+ ):
continue
target_line = lines[lineno - 1]
if "# E:" in target_line:
- marker = target_line.split("# E:")[-1].strip()
- expected_error = errors.get(lineno)
- _test_fail(path, marker, expected_error, lineno)
+ expression, _, marker = target_line.partition(" # E: ")
+ expected_error = errors[lineno].strip()
+ marker = marker.strip()
+ _test_fail(path, expression, marker, expected_error, lineno)
else:
- pytest.fail(f"Unexpected mypy output\n\n{errors[lineno]}")
+ pytest.fail(
+ f"Unexpected mypy output at line {lineno}\n\n{errors[lineno]}"
+ )
_FAIL_MSG1 = """Extra error at line {}
+Expression: {}
Extra error: {!r}
"""
_FAIL_MSG2 = """Error mismatch at line {}
+Expression: {}
Expected error: {!r}
Observed error: {!r}
"""
-def _test_fail(path: str, error: str, expected_error: Optional[str], lineno: int) -> None:
+def _test_fail(
+ path: str,
+ expression: str,
+ error: str,
+ expected_error: None | str,
+ lineno: int,
+) -> None:
if expected_error is None:
- raise AssertionError(_FAIL_MSG1.format(lineno, error))
+ raise AssertionError(_FAIL_MSG1.format(lineno, expression, error))
elif error not in expected_error:
- raise AssertionError(_FAIL_MSG2.format(lineno, expected_error, error))
+ raise AssertionError(_FAIL_MSG2.format(
+ lineno, expression, expected_error, error
+ ))
+
+
+def _construct_ctypes_dict() -> dict[str, str]:
+ dct = {
+ "ubyte": "c_ubyte",
+ "ushort": "c_ushort",
+ "uintc": "c_uint",
+ "uint": "c_ulong",
+ "ulonglong": "c_ulonglong",
+ "byte": "c_byte",
+ "short": "c_short",
+ "intc": "c_int",
+ "int_": "c_long",
+ "longlong": "c_longlong",
+ "single": "c_float",
+ "double": "c_double",
+ "longdouble": "c_longdouble",
+ }
+
+ # Match `ctypes` names to the first ctypes type with a given kind and
+ # precision, e.g. {"c_double": "c_double", "c_longdouble": "c_double"}
+ # if both types represent 64-bit floats.
+ # In this context "first" is defined by the order of `dct`
+ ret = {}
+ visited: dict[tuple[str, int], str] = {}
+ for np_name, ct_name in dct.items():
+ np_scalar = getattr(np, np_name)()
+
+ # Find the first `ctypes` type for a given `kind`/`itemsize` combo
+ key = (np_scalar.dtype.kind, np_scalar.dtype.itemsize)
+ ret[ct_name] = visited.setdefault(key, f"ctypes.{ct_name}")
+ return ret
-def _construct_format_dict():
+def _construct_format_dict() -> dict[str, str]:
dct = {k.split(".")[-1]: v.replace("numpy", "numpy.typing") for
k, v in _PRECISION_DICT.items()}
@@ -193,12 +246,18 @@ def _construct_format_dict():
"float96": "numpy.floating[numpy.typing._96Bit]",
"float128": "numpy.floating[numpy.typing._128Bit]",
"float256": "numpy.floating[numpy.typing._256Bit]",
- "complex64": "numpy.complexfloating[numpy.typing._32Bit, numpy.typing._32Bit]",
- "complex128": "numpy.complexfloating[numpy.typing._64Bit, numpy.typing._64Bit]",
- "complex160": "numpy.complexfloating[numpy.typing._80Bit, numpy.typing._80Bit]",
- "complex192": "numpy.complexfloating[numpy.typing._96Bit, numpy.typing._96Bit]",
- "complex256": "numpy.complexfloating[numpy.typing._128Bit, numpy.typing._128Bit]",
- "complex512": "numpy.complexfloating[numpy.typing._256Bit, numpy.typing._256Bit]",
+ "complex64": ("numpy.complexfloating"
+ "[numpy.typing._32Bit, numpy.typing._32Bit]"),
+ "complex128": ("numpy.complexfloating"
+ "[numpy.typing._64Bit, numpy.typing._64Bit]"),
+ "complex160": ("numpy.complexfloating"
+ "[numpy.typing._80Bit, numpy.typing._80Bit]"),
+ "complex192": ("numpy.complexfloating"
+ "[numpy.typing._96Bit, numpy.typing._96Bit]"),
+ "complex256": ("numpy.complexfloating"
+ "[numpy.typing._128Bit, numpy.typing._128Bit]"),
+ "complex512": ("numpy.complexfloating"
+ "[numpy.typing._256Bit, numpy.typing._256Bit]"),
"ubyte": f"numpy.unsignedinteger[{dct['_NBitByte']}]",
"ushort": f"numpy.unsignedinteger[{dct['_NBitShort']}]",
@@ -217,9 +276,14 @@ def _construct_format_dict():
"single": f"numpy.floating[{dct['_NBitSingle']}]",
"double": f"numpy.floating[{dct['_NBitDouble']}]",
"longdouble": f"numpy.floating[{dct['_NBitLongDouble']}]",
- "csingle": f"numpy.complexfloating[{dct['_NBitSingle']}, {dct['_NBitSingle']}]",
- "cdouble": f"numpy.complexfloating[{dct['_NBitDouble']}, {dct['_NBitDouble']}]",
- "clongdouble": f"numpy.complexfloating[{dct['_NBitLongDouble']}, {dct['_NBitLongDouble']}]",
+ "csingle": ("numpy.complexfloating"
+ f"[{dct['_NBitSingle']}, {dct['_NBitSingle']}]"),
+ "cdouble": ("numpy.complexfloating"
+ f"[{dct['_NBitDouble']}, {dct['_NBitDouble']}]"),
+ "clongdouble": (
+ "numpy.complexfloating"
+ f"[{dct['_NBitLongDouble']}, {dct['_NBitLongDouble']}]"
+ ),
# numpy.typing
"_NBitInt": dct['_NBitInt'],
@@ -231,40 +295,49 @@ def _construct_format_dict():
#: A dictionary with all supported format keys (as keys)
#: and matching values
-FORMAT_DICT: Dict[str, str] = _construct_format_dict()
+FORMAT_DICT: dict[str, str] = _construct_format_dict()
+FORMAT_DICT.update(_construct_ctypes_dict())
-def _parse_reveals(file: IO[str]) -> List[str]:
- """Extract and parse all ``" # E: "`` comments from the passed file-like object.
+def _parse_reveals(file: IO[str]) -> tuple[npt.NDArray[np.str_], list[str]]:
+ """Extract and parse all ``" # E: "`` comments from the passed
+ file-like object.
- All format keys will be substituted for their respective value from `FORMAT_DICT`,
- *e.g.* ``"{float64}"`` becomes ``"numpy.floating[numpy.typing._64Bit]"``.
+ All format keys will be substituted for their respective value
+ from `FORMAT_DICT`, *e.g.* ``"{float64}"`` becomes
+ ``"numpy.floating[numpy.typing._64Bit]"``.
"""
string = file.read().replace("*", "")
- # Grab all `# E:`-based comments
- comments_array = np.char.partition(string.split("\n"), sep=" # E: ")[:, 2]
+ # Grab all `# E:`-based comments and matching expressions
+ expression_array, _, comments_array = np.char.partition(
+ string.split("\n"), sep=" # E: "
+ ).T
comments = "/n".join(comments_array)
- # Only search for the `{*}` pattern within comments,
- # otherwise there is the risk of accidently grabbing dictionaries and sets
+ # Only search for the `{*}` pattern within comments, otherwise
+ # there is the risk of accidentally grabbing dictionaries and sets
key_set = set(re.findall(r"\{(.*?)\}", comments))
kwargs = {
- k: FORMAT_DICT.get(k, f"<UNRECOGNIZED FORMAT KEY {k!r}>") for k in key_set
+ k: FORMAT_DICT.get(k, f"<UNRECOGNIZED FORMAT KEY {k!r}>") for
+ k in key_set
}
fmt_str = comments.format(**kwargs)
- return fmt_str.split("/n")
+ return expression_array, fmt_str.split("/n")
@pytest.mark.slow
@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
@pytest.mark.parametrize("path", get_test_cases(REVEAL_DIR))
-def test_reveal(path):
+def test_reveal(path: str) -> None:
+ """Validate that mypy correctly infers the return-types of
+ the expressions in `path`.
+ """
__tracebackhide__ = True
with open(path) as fin:
- lines = _parse_reveals(fin)
+ expression_array, reveal_list = _parse_reveals(fin)
output_mypy = OUTPUT_MYPY
assert path in output_mypy
@@ -279,29 +352,47 @@ def test_reveal(path):
lineno = int(match.group('lineno')) - 1
assert "Revealed type is" in error_line
- marker = lines[lineno]
- _test_reveal(path, marker, error_line, 1 + lineno)
+ marker = reveal_list[lineno]
+ expression = expression_array[lineno]
+ _test_reveal(path, expression, marker, error_line, 1 + lineno)
_REVEAL_MSG = """Reveal mismatch at line {}
+Expression: {}
Expected reveal: {!r}
Observed reveal: {!r}
"""
-def _test_reveal(path: str, reveal: str, expected_reveal: str, lineno: int) -> None:
+def _test_reveal(
+ path: str,
+ expression: str,
+ reveal: str,
+ expected_reveal: str,
+ lineno: int,
+) -> None:
+ """Error-reporting helper function for `test_reveal`."""
if reveal not in expected_reveal:
- raise AssertionError(_REVEAL_MSG.format(lineno, expected_reveal, reveal))
+ raise AssertionError(
+ _REVEAL_MSG.format(lineno, expression, expected_reveal, reveal)
+ )
@pytest.mark.slow
@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
@pytest.mark.parametrize("path", get_test_cases(PASS_DIR))
-def test_code_runs(path):
+def test_code_runs(path: str) -> None:
+ """Validate that the code in `path` properly during runtime."""
path_without_extension, _ = os.path.splitext(path)
dirname, filename = path.split(os.sep)[-2:]
- spec = importlib.util.spec_from_file_location(f"{dirname}.{filename}", path)
+
+ spec = importlib.util.spec_from_file_location(
+ f"{dirname}.{filename}", path
+ )
+ assert spec is not None
+ assert spec.loader is not None
+
test_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(test_module)
@@ -325,15 +416,19 @@ LINENO_MAPPING = {
@pytest.mark.slow
@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
def test_extended_precision() -> None:
- path = os.path.join(MISC_DIR, "extended_precision.py")
+ path = os.path.join(MISC_DIR, "extended_precision.pyi")
output_mypy = OUTPUT_MYPY
assert path in output_mypy
+ with open(path, "r") as f:
+ expression_list = f.readlines()
+
for _msg in output_mypy[path]:
*_, _lineno, msg_typ, msg = _msg.split(":")
msg = _strip_filename(msg)
lineno = int(_lineno)
+ expression = expression_list[lineno - 1].rstrip("\n")
msg_typ = msg_typ.strip()
assert msg_typ in {"error", "note"}
@@ -342,8 +437,8 @@ def test_extended_precision() -> None:
raise ValueError(f"Unexpected reveal line format: {lineno}")
else:
marker = FORMAT_DICT[LINENO_MAPPING[lineno]]
- _test_reveal(path, marker, msg, lineno)
+ _test_reveal(path, expression, marker, msg, lineno)
else:
if msg_typ == "error":
marker = "Module has no attribute"
- _test_fail(path, marker, msg, lineno)
+ _test_fail(path, expression, marker, msg, lineno)