summaryrefslogtreecommitdiff
path: root/Cython/TestUtils.py
diff options
context:
space:
mode:
Diffstat (limited to 'Cython/TestUtils.py')
-rw-r--r--Cython/TestUtils.py227
1 files changed, 188 insertions, 39 deletions
diff --git a/Cython/TestUtils.py b/Cython/TestUtils.py
index 9d6eb67fc..d3a34e741 100644
--- a/Cython/TestUtils.py
+++ b/Cython/TestUtils.py
@@ -1,8 +1,14 @@
from __future__ import absolute_import
import os
+import re
import unittest
+import shlex
+import sys
import tempfile
+import textwrap
+from io import open
+from functools import partial
from .Compiler import Errors
from .CodeWriter import CodeWriter
@@ -47,13 +53,10 @@ def treetypes(root):
class CythonTest(unittest.TestCase):
def setUp(self):
- self.listing_file = Errors.listing_file
- self.echo_file = Errors.echo_file
- Errors.listing_file = Errors.echo_file = None
+ Errors.init_thread()
def tearDown(self):
- Errors.listing_file = self.listing_file
- Errors.echo_file = self.echo_file
+ Errors.init_thread()
def assertLines(self, expected, result):
"Checks that the given strings or lists of strings are equal line by line"
@@ -160,11 +163,87 @@ class TransformTest(CythonTest):
return tree
+# For the test C code validation, we have to take care that the test directives (and thus
+# the match strings) do not just appear in (multiline) C code comments containing the original
+# Cython source code. Thus, we discard the comments before matching.
+# This seems a prime case for re.VERBOSE, but it seems to match some of the whitespace.
+_strip_c_comments = partial(re.compile(
+ re.sub(r'\s+', '', r'''
+ /[*] (
+ (?: [^*\n] | [*][^/] )*
+ [\n]
+ (?: [^*] | [*][^/] )*
+ ) [*]/
+ ''')
+).sub, '')
+
+_strip_cython_code_from_html = partial(re.compile(
+ re.sub(r'\s\s+', '', r'''
+ (?:
+ <pre class=["'][^"']*cython\s+line[^"']*["']\s*>
+ (?:[^<]|<(?!/pre))+
+ </pre>
+ )|(?:
+ <style[^>]*>
+ (?:[^<]|<(?!/style))+
+ </style>
+ )
+ ''')
+).sub, '')
+
+
class TreeAssertVisitor(VisitorTransform):
# actually, a TreeVisitor would be enough, but this needs to run
# as part of the compiler pipeline
- def visit_CompilerDirectivesNode(self, node):
+ def __init__(self):
+ super(TreeAssertVisitor, self).__init__()
+ self._module_pos = None
+ self._c_patterns = []
+ self._c_antipatterns = []
+
+ def create_c_file_validator(self):
+ patterns, antipatterns = self._c_patterns, self._c_antipatterns
+
+ def fail(pos, pattern, found, file_path):
+ Errors.error(pos, "Pattern '%s' %s found in %s" %(
+ pattern,
+ 'was' if found else 'was not',
+ file_path,
+ ))
+
+ def validate_file_content(file_path, content):
+ for pattern in patterns:
+ #print("Searching pattern '%s'" % pattern)
+ if not re.search(pattern, content):
+ fail(self._module_pos, pattern, found=False, file_path=file_path)
+
+ for antipattern in antipatterns:
+ #print("Searching antipattern '%s'" % antipattern)
+ if re.search(antipattern, content):
+ fail(self._module_pos, antipattern, found=True, file_path=file_path)
+
+ def validate_c_file(result):
+ c_file = result.c_file
+ if not (patterns or antipatterns):
+ #print("No patterns defined for %s" % c_file)
+ return result
+
+ with open(c_file, encoding='utf8') as f:
+ content = f.read()
+ content = _strip_c_comments(content)
+ validate_file_content(c_file, content)
+
+ html_file = os.path.splitext(c_file)[0] + ".html"
+ if os.path.exists(html_file) and os.path.getmtime(c_file) <= os.path.getmtime(html_file):
+ with open(html_file, encoding='utf8') as f:
+ content = f.read()
+ content = _strip_cython_code_from_html(content)
+ validate_file_content(html_file, content)
+
+ return validate_c_file
+
+ def _check_directives(self, node):
directives = node.directives
if 'test_assert_path_exists' in directives:
for path in directives['test_assert_path_exists']:
@@ -174,44 +253,114 @@ class TreeAssertVisitor(VisitorTransform):
"Expected path '%s' not found in result tree" % path)
if 'test_fail_if_path_exists' in directives:
for path in directives['test_fail_if_path_exists']:
- if TreePath.find_first(node, path) is not None:
+ first_node = TreePath.find_first(node, path)
+ if first_node is not None:
Errors.error(
- node.pos,
- "Unexpected path '%s' found in result tree" % path)
+ first_node.pos,
+ "Unexpected path '%s' found in result tree" % path)
+ if 'test_assert_c_code_has' in directives:
+ self._c_patterns.extend(directives['test_assert_c_code_has'])
+ if 'test_fail_if_c_code_has' in directives:
+ self._c_antipatterns.extend(directives['test_fail_if_c_code_has'])
+
+ def visit_ModuleNode(self, node):
+ self._module_pos = node.pos
+ self._check_directives(node)
+ self.visitchildren(node)
+ return node
+
+ def visit_CompilerDirectivesNode(self, node):
+ self._check_directives(node)
self.visitchildren(node)
return node
visit_Node = VisitorTransform.recurse_to_children
-def unpack_source_tree(tree_file, dir=None):
- if dir is None:
- dir = tempfile.mkdtemp()
- header = []
- cur_file = None
- f = open(tree_file)
- try:
- lines = f.readlines()
- finally:
- f.close()
- del f
+def unpack_source_tree(tree_file, workdir, cython_root):
+ programs = {
+ 'PYTHON': [sys.executable],
+ 'CYTHON': [sys.executable, os.path.join(cython_root, 'cython.py')],
+ 'CYTHONIZE': [sys.executable, os.path.join(cython_root, 'cythonize.py')]
+ }
+
+ if workdir is None:
+ workdir = tempfile.mkdtemp()
+ header, cur_file = [], None
+ with open(tree_file, 'rb') as f:
+ try:
+ for line in f:
+ if line[:5] == b'#####':
+ filename = line.strip().strip(b'#').strip().decode('utf8').replace('/', os.path.sep)
+ path = os.path.join(workdir, filename)
+ if not os.path.exists(os.path.dirname(path)):
+ os.makedirs(os.path.dirname(path))
+ if cur_file is not None:
+ to_close, cur_file = cur_file, None
+ to_close.close()
+ cur_file = open(path, 'wb')
+ elif cur_file is not None:
+ cur_file.write(line)
+ elif line.strip() and not line.lstrip().startswith(b'#'):
+ if line.strip() not in (b'"""', b"'''"):
+ command = shlex.split(line.decode('utf8'))
+ if not command: continue
+ # In Python 3: prog, *args = command
+ prog, args = command[0], command[1:]
+ try:
+ header.append(programs[prog]+args)
+ except KeyError:
+ header.append(command)
+ finally:
+ if cur_file is not None:
+ cur_file.close()
+ return workdir, header
+
+
+def write_file(file_path, content, dedent=False, encoding=None):
+ r"""Write some content (text or bytes) to the file
+ at `file_path` without translating `'\n'` into `os.linesep`.
+
+ The default encoding is `'utf-8'`.
+ """
+ if isinstance(content, bytes):
+ mode = "wb"
+
+ # binary mode doesn't take an encoding and newline arguments
+ newline = None
+ default_encoding = None
+ else:
+ mode = "w"
+
+ # any "\n" characters written are not translated
+ # to the system default line separator, os.linesep
+ newline = "\n"
+ default_encoding = "utf-8"
+
+ if encoding is None:
+ encoding = default_encoding
+
+ if dedent:
+ content = textwrap.dedent(content)
+
+ with open(file_path, mode=mode, encoding=encoding, newline=newline) as f:
+ f.write(content)
+
+
+def write_newer_file(file_path, newer_than, content, dedent=False, encoding=None):
+ r"""
+ Write `content` to the file `file_path` without translating `'\n'`
+ into `os.linesep` and make sure it is newer than the file `newer_than`.
+
+ The default encoding is `'utf-8'` (same as for `write_file`).
+ """
+ write_file(file_path, content, dedent=dedent, encoding=encoding)
+
try:
- for line in lines:
- if line[:5] == '#####':
- filename = line.strip().strip('#').strip().replace('/', os.path.sep)
- path = os.path.join(dir, filename)
- if not os.path.exists(os.path.dirname(path)):
- os.makedirs(os.path.dirname(path))
- if cur_file is not None:
- f, cur_file = cur_file, None
- f.close()
- cur_file = open(path, 'w')
- elif cur_file is not None:
- cur_file.write(line)
- elif line.strip() and not line.lstrip().startswith('#'):
- if line.strip() not in ('"""', "'''"):
- header.append(line)
- finally:
- if cur_file is not None:
- cur_file.close()
- return dir, ''.join(header)
+ other_time = os.path.getmtime(newer_than)
+ except OSError:
+ # Support writing a fresh file (which is always newer than a non-existant one)
+ other_time = None
+
+ while other_time is None or other_time >= os.path.getmtime(file_path):
+ write_file(file_path, content, dedent=dedent, encoding=encoding)