diff options
Diffstat (limited to 'Cython/TestUtils.py')
-rw-r--r-- | Cython/TestUtils.py | 227 |
1 files changed, 188 insertions, 39 deletions
diff --git a/Cython/TestUtils.py b/Cython/TestUtils.py index 9d6eb67fc..d3a34e741 100644 --- a/Cython/TestUtils.py +++ b/Cython/TestUtils.py @@ -1,8 +1,14 @@ from __future__ import absolute_import import os +import re import unittest +import shlex +import sys import tempfile +import textwrap +from io import open +from functools import partial from .Compiler import Errors from .CodeWriter import CodeWriter @@ -47,13 +53,10 @@ def treetypes(root): class CythonTest(unittest.TestCase): def setUp(self): - self.listing_file = Errors.listing_file - self.echo_file = Errors.echo_file - Errors.listing_file = Errors.echo_file = None + Errors.init_thread() def tearDown(self): - Errors.listing_file = self.listing_file - Errors.echo_file = self.echo_file + Errors.init_thread() def assertLines(self, expected, result): "Checks that the given strings or lists of strings are equal line by line" @@ -160,11 +163,87 @@ class TransformTest(CythonTest): return tree +# For the test C code validation, we have to take care that the test directives (and thus +# the match strings) do not just appear in (multiline) C code comments containing the original +# Cython source code. Thus, we discard the comments before matching. +# This seems a prime case for re.VERBOSE, but it seems to match some of the whitespace. +_strip_c_comments = partial(re.compile( + re.sub(r'\s+', '', r''' + /[*] ( + (?: [^*\n] | [*][^/] )* + [\n] + (?: [^*] | [*][^/] )* + ) [*]/ + ''') +).sub, '') + +_strip_cython_code_from_html = partial(re.compile( + re.sub(r'\s\s+', '', r''' + (?: + <pre class=["'][^"']*cython\s+line[^"']*["']\s*> + (?:[^<]|<(?!/pre))+ + </pre> + )|(?: + <style[^>]*> + (?:[^<]|<(?!/style))+ + </style> + ) + ''') +).sub, '') + + class TreeAssertVisitor(VisitorTransform): # actually, a TreeVisitor would be enough, but this needs to run # as part of the compiler pipeline - def visit_CompilerDirectivesNode(self, node): + def __init__(self): + super(TreeAssertVisitor, self).__init__() + self._module_pos = None + self._c_patterns = [] + self._c_antipatterns = [] + + def create_c_file_validator(self): + patterns, antipatterns = self._c_patterns, self._c_antipatterns + + def fail(pos, pattern, found, file_path): + Errors.error(pos, "Pattern '%s' %s found in %s" %( + pattern, + 'was' if found else 'was not', + file_path, + )) + + def validate_file_content(file_path, content): + for pattern in patterns: + #print("Searching pattern '%s'" % pattern) + if not re.search(pattern, content): + fail(self._module_pos, pattern, found=False, file_path=file_path) + + for antipattern in antipatterns: + #print("Searching antipattern '%s'" % antipattern) + if re.search(antipattern, content): + fail(self._module_pos, antipattern, found=True, file_path=file_path) + + def validate_c_file(result): + c_file = result.c_file + if not (patterns or antipatterns): + #print("No patterns defined for %s" % c_file) + return result + + with open(c_file, encoding='utf8') as f: + content = f.read() + content = _strip_c_comments(content) + validate_file_content(c_file, content) + + html_file = os.path.splitext(c_file)[0] + ".html" + if os.path.exists(html_file) and os.path.getmtime(c_file) <= os.path.getmtime(html_file): + with open(html_file, encoding='utf8') as f: + content = f.read() + content = _strip_cython_code_from_html(content) + validate_file_content(html_file, content) + + return validate_c_file + + def _check_directives(self, node): directives = node.directives if 'test_assert_path_exists' in directives: for path in directives['test_assert_path_exists']: @@ -174,44 +253,114 @@ class TreeAssertVisitor(VisitorTransform): "Expected path '%s' not found in result tree" % path) if 'test_fail_if_path_exists' in directives: for path in directives['test_fail_if_path_exists']: - if TreePath.find_first(node, path) is not None: + first_node = TreePath.find_first(node, path) + if first_node is not None: Errors.error( - node.pos, - "Unexpected path '%s' found in result tree" % path) + first_node.pos, + "Unexpected path '%s' found in result tree" % path) + if 'test_assert_c_code_has' in directives: + self._c_patterns.extend(directives['test_assert_c_code_has']) + if 'test_fail_if_c_code_has' in directives: + self._c_antipatterns.extend(directives['test_fail_if_c_code_has']) + + def visit_ModuleNode(self, node): + self._module_pos = node.pos + self._check_directives(node) + self.visitchildren(node) + return node + + def visit_CompilerDirectivesNode(self, node): + self._check_directives(node) self.visitchildren(node) return node visit_Node = VisitorTransform.recurse_to_children -def unpack_source_tree(tree_file, dir=None): - if dir is None: - dir = tempfile.mkdtemp() - header = [] - cur_file = None - f = open(tree_file) - try: - lines = f.readlines() - finally: - f.close() - del f +def unpack_source_tree(tree_file, workdir, cython_root): + programs = { + 'PYTHON': [sys.executable], + 'CYTHON': [sys.executable, os.path.join(cython_root, 'cython.py')], + 'CYTHONIZE': [sys.executable, os.path.join(cython_root, 'cythonize.py')] + } + + if workdir is None: + workdir = tempfile.mkdtemp() + header, cur_file = [], None + with open(tree_file, 'rb') as f: + try: + for line in f: + if line[:5] == b'#####': + filename = line.strip().strip(b'#').strip().decode('utf8').replace('/', os.path.sep) + path = os.path.join(workdir, filename) + if not os.path.exists(os.path.dirname(path)): + os.makedirs(os.path.dirname(path)) + if cur_file is not None: + to_close, cur_file = cur_file, None + to_close.close() + cur_file = open(path, 'wb') + elif cur_file is not None: + cur_file.write(line) + elif line.strip() and not line.lstrip().startswith(b'#'): + if line.strip() not in (b'"""', b"'''"): + command = shlex.split(line.decode('utf8')) + if not command: continue + # In Python 3: prog, *args = command + prog, args = command[0], command[1:] + try: + header.append(programs[prog]+args) + except KeyError: + header.append(command) + finally: + if cur_file is not None: + cur_file.close() + return workdir, header + + +def write_file(file_path, content, dedent=False, encoding=None): + r"""Write some content (text or bytes) to the file + at `file_path` without translating `'\n'` into `os.linesep`. + + The default encoding is `'utf-8'`. + """ + if isinstance(content, bytes): + mode = "wb" + + # binary mode doesn't take an encoding and newline arguments + newline = None + default_encoding = None + else: + mode = "w" + + # any "\n" characters written are not translated + # to the system default line separator, os.linesep + newline = "\n" + default_encoding = "utf-8" + + if encoding is None: + encoding = default_encoding + + if dedent: + content = textwrap.dedent(content) + + with open(file_path, mode=mode, encoding=encoding, newline=newline) as f: + f.write(content) + + +def write_newer_file(file_path, newer_than, content, dedent=False, encoding=None): + r""" + Write `content` to the file `file_path` without translating `'\n'` + into `os.linesep` and make sure it is newer than the file `newer_than`. + + The default encoding is `'utf-8'` (same as for `write_file`). + """ + write_file(file_path, content, dedent=dedent, encoding=encoding) + try: - for line in lines: - if line[:5] == '#####': - filename = line.strip().strip('#').strip().replace('/', os.path.sep) - path = os.path.join(dir, filename) - if not os.path.exists(os.path.dirname(path)): - os.makedirs(os.path.dirname(path)) - if cur_file is not None: - f, cur_file = cur_file, None - f.close() - cur_file = open(path, 'w') - elif cur_file is not None: - cur_file.write(line) - elif line.strip() and not line.lstrip().startswith('#'): - if line.strip() not in ('"""', "'''"): - header.append(line) - finally: - if cur_file is not None: - cur_file.close() - return dir, ''.join(header) + other_time = os.path.getmtime(newer_than) + except OSError: + # Support writing a fresh file (which is always newer than a non-existant one) + other_time = None + + while other_time is None or other_time >= os.path.getmtime(file_path): + write_file(file_path, content, dedent=dedent, encoding=encoding) |