summaryrefslogtreecommitdiff
path: root/numpy/tests/test_typing.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/tests/test_typing.py')
-rw-r--r--numpy/tests/test_typing.py127
1 files changed, 127 insertions, 0 deletions
diff --git a/numpy/tests/test_typing.py b/numpy/tests/test_typing.py
new file mode 100644
index 000000000..757ea0b52
--- /dev/null
+++ b/numpy/tests/test_typing.py
@@ -0,0 +1,127 @@
+import importlib.util
+import itertools
+import os
+import re
+from collections import defaultdict
+
+import pytest
+try:
+ from mypy import api
+except ImportError:
+ NO_MYPY = True
+else:
+ NO_MYPY = False
+
+TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
+PASS_DIR = os.path.join(TESTS_DIR, "pass")
+FAIL_DIR = os.path.join(TESTS_DIR, "fail")
+REVEAL_DIR = os.path.join(TESTS_DIR, "reveal")
+MYPY_INI = os.path.join(TESTS_DIR, "mypy.ini")
+CACHE_DIR = os.path.join(TESTS_DIR, ".mypy_cache")
+
+
+def get_test_cases(directory):
+ for root, _, files in os.walk(directory):
+ for fname in files:
+ if os.path.splitext(fname)[-1] == ".py":
+ fullpath = os.path.join(root, fname)
+ # Use relative path for nice py.test name
+ relpath = os.path.relpath(fullpath, start=directory)
+
+ yield pytest.param(
+ fullpath,
+ # Manually specify a name for the test
+ id=relpath,
+ )
+
+
+@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
+@pytest.mark.parametrize("path", get_test_cases(PASS_DIR))
+def test_success(path):
+ stdout, stderr, exitcode = api.run([
+ "--config-file",
+ MYPY_INI,
+ "--cache-dir",
+ CACHE_DIR,
+ path,
+ ])
+ assert exitcode == 0, stdout
+ assert re.match(r"Success: no issues found in \d+ source files?", stdout.strip())
+
+
+@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
+@pytest.mark.parametrize("path", get_test_cases(FAIL_DIR))
+def test_fail(path):
+ stdout, stderr, exitcode = api.run([
+ "--config-file",
+ MYPY_INI,
+ "--cache-dir",
+ CACHE_DIR,
+ path,
+ ])
+ assert exitcode != 0
+
+ with open(path) as fin:
+ lines = fin.readlines()
+
+ errors = defaultdict(lambda: "")
+ error_lines = stdout.rstrip("\n").split("\n")
+ assert re.match(
+ r"Found \d+ errors? in \d+ files? \(checked \d+ source files?\)",
+ error_lines[-1].strip(),
+ )
+ for error_line in error_lines[:-1]:
+ error_line = error_line.strip()
+ if not error_line:
+ continue
+
+ lineno = int(error_line.split(":")[1])
+ errors[lineno] += error_line
+
+ for i, line in enumerate(lines):
+ lineno = i + 1
+ if " E:" not in line and lineno not in errors:
+ continue
+
+ target_line = lines[lineno - 1]
+ if "# E:" in target_line:
+ marker = target_line.split("# E:")[-1].strip()
+ assert lineno in errors, f'Extra error "{marker}"'
+ assert marker in errors[lineno]
+ else:
+ pytest.fail(f"Error {repr(errors[lineno])} not found")
+
+
+@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
+@pytest.mark.parametrize("path", get_test_cases(REVEAL_DIR))
+def test_reveal(path):
+ stdout, stderr, exitcode = api.run([
+ "--config-file",
+ MYPY_INI,
+ "--cache-dir",
+ CACHE_DIR,
+ path,
+ ])
+
+ with open(path) as fin:
+ lines = fin.readlines()
+
+ for error_line in stdout.split("\n"):
+ error_line = error_line.strip()
+ if not error_line:
+ continue
+
+ lineno = int(error_line.split(":")[1])
+ assert "Revealed type is" in error_line
+ marker = lines[lineno - 1].split("# E:")[-1].strip()
+ assert marker in error_line
+
+
+@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
+@pytest.mark.parametrize("path", get_test_cases(PASS_DIR))
+def test_code_runs(path):
+ path_without_extension, _ = os.path.splitext(path)
+ dirname, filename = path.split(os.sep)[-2:]
+ spec = importlib.util.spec_from_file_location(f"{dirname}.{filename}", path)
+ test_module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(test_module)