"""
Common helpers and adaptations for Py2/3.
To be used in tests.
"""
# Forces a garbage collection run after every test (see
# HelperTestCase.tearDown), which slows down test runs considerably.
# Enable to debug proxy handling issues.
DEBUG_PROXY_ISSUES = False # True
import gc
import os
import os.path
import re
import sys
import tempfile
import unittest
from contextlib import contextmanager
try:
import urlparse
except ImportError:
import urllib.parse as urlparse
try:
from urllib import pathname2url
except:
from urllib.request import pathname2url
from lxml import etree, html
def make_version_tuple(version_string):
    """Split a version string into a tuple of its components.

    Runs of ASCII digits become ``int`` items, every other run of
    characters (dots excluded) is kept as a string, e.g.
    ``"2.0beta1"`` -> ``(2, 0, 'beta', 1)``.
    """
    components = []
    for token in re.findall('([0-9]+|[^0-9.]+)', version_string):
        if token.isdigit():
            components.append(int(token))
        else:
            components.append(token)
    return tuple(components)
# True when running on PyPy.  ``sys.implementation`` (where available) is a
# namespace object, so its ``name`` attribute has to be compared, not the
# object itself.  ``sys.pypy_version_info`` is checked as well for
# interpreters that do not provide ``sys.implementation``.
IS_PYPY = (
    getattr(getattr(sys, 'implementation', None), 'name', None) == 'pypy'
    or getattr(sys, 'pypy_version_info', None) is not None
)

# Major-version switches used to select the Py2 or Py3 code paths.
IS_PYTHON3 = sys.version_info[0] >= 3
IS_PYTHON2 = sys.version_info[0] < 3
from xml.etree import ElementTree

# Version of the standard library ElementTree as a comparable tuple;
# (0, 0, 0) when the module does not expose a VERSION attribute.
ET_VERSION = (
    make_version_tuple(ElementTree.VERSION)
    if hasattr(ElementTree, 'VERSION') else (0, 0, 0)
)

# cElementTree is only imported on Python 2; on Python 3 the name is bound
# to None and its version stays at (0, 0, 0).
cElementTree = None
CET_VERSION = (0, 0, 0)
if IS_PYTHON2:
    from xml.etree import cElementTree
    if hasattr(cElementTree, 'VERSION'):
        CET_VERSION = make_version_tuple(cElementTree.VERSION)
def filter_by_version(test_class, version_dict, current_version):
"""Remove test methods that do not work with the current lib version.
"""
find_required_version = version_dict.get
def dummy_test_method(self):
pass
for name in dir(test_class):
expected_version = find_required_version(name, (0,0,0))
if expected_version > current_version:
setattr(test_class, name, dummy_test_method)
def needs_libxml(*version):
    """Return a decorator that skips a test on older libxml2 versions.

    *version* is given as separate integers, e.g. ``needs_libxml(2, 9, 1)``.
    The test is skipped when ``etree.LIBXML_VERSION`` is lower than that
    tuple.  For the skip message the version is padded with zeros and cut
    to three components.
    """
    too_old = etree.LIBXML_VERSION < version
    displayed = (version + (0, 0, 0))[:3]
    reason = "needs libxml2 >= %s.%s.%s" % displayed
    return unittest.skipIf(too_old, reason)
import doctest
try:
    import pytest
except ImportError:
    class skipif(object):
        """No-op stand-in for ``pytest.mark.skipif`` when pytest is missing.

        Using a class because a function would bind into a method when
        used in classes.  The decorated function is returned unchanged,
        i.e. nothing is ever skipped.
        """
        def __init__(self, *args, **kwargs):
            # Accept keyword arguments as well (e.g. ``reason=...``), so
            # that call sites written for pytest.mark.skipif keep working.
            pass

        def __call__(self, func, *args):
            return func
else:
    skipif = pytest.mark.skipif
def _get_caller_relative_path(filename, frame_depth=2):
    """Resolve *filename* relative to the directory of a calling module.

    *frame_depth* selects the stack frame whose module is used; frame 0
    is this function itself, so the default of 2 refers to the caller of
    the function that called this helper.  A module without a
    ``__file__`` attribute is treated as having an empty directory.
    The result is normalised with ``os.path.normpath``.
    """
    caller_globals = sys._getframe(frame_depth).f_globals
    caller_module = sys.modules[caller_globals['__name__']]
    module_dir = os.path.dirname(getattr(caller_module, '__file__', ''))
    return os.path.normpath(os.path.join(module_dir, filename))
from io import StringIO
# Matches the literal escape sequences ``\uXXXX`` and ``\UXXXXXXXX``.
unichr_escape = re.compile(r'\\u[0-9a-fA-F]{4}|\\U[0-9a-fA-F]{8}')

# The same parser setup is used on both Python versions.
doctest_parser = doctest.DocTestParser()

if sys.version_info[0] >= 3:
    # Python 3
    from builtins import str as unicode
    from codecs import unicode_escape_decode
    from io import BytesIO as _BytesIO

    _chr = chr

    def _str(s, encoding="UTF-8"):
        """Return *s* with literal ``\\uXXXX``/``\\UXXXXXXXX`` escapes decoded.

        *encoding* is unused here; it mirrors the Python 2 signature.
        """
        def _decode_escape(match):
            return unicode_escape_decode(match.group(0))[0]
        return unichr_escape.sub(_decode_escape, s)

    def _bytes(s, encoding="UTF-8"):
        """Encode the text *s* to bytes using *encoding*."""
        return s.encode(encoding)

    def BytesIO(*args):
        """Create an ``io.BytesIO``, accepting text as initial content.

        A ``str`` first argument is encoded as UTF-8 and any further
        arguments are dropped; other arguments are passed on unchanged.
        """
        if args:
            initial = args[0]
            if isinstance(initial, str):
                args = (initial.encode("UTF-8"),)
        return _BytesIO(*args)

    # Source rewrites that turn Py2 style doctests into Py3 syntax:
    # drop the ``u`` string prefix and use ``except X as e``.
    _fix_unicode = re.compile(r'(\s+)u(["\'])').sub
    _fix_exceptions = re.compile(r'(.*except [^(]*),\s*(.*:)').sub

    def make_doctest(filename):
        """Build a ``doctest.DocTestCase`` from a file next to the caller.

        *filename* is resolved relative to the module calling this
        function, read, adapted to Python 3 syntax and parsed.
        """
        # Must be called directly from here: the helper's default frame
        # depth refers to the caller of this function.
        path = _get_caller_relative_path(filename)
        source = _fix_unicode(r'\1\2', read_file(path))
        source = _fix_exceptions(r'\1 as \2', source)
        parsed = doctest_parser.get_doctest(
            source, {}, os.path.basename(path), path, 0)
        return doctest.DocTestCase(parsed)
else:
    # Python 2
    from __builtin__ import unicode
    from io import BytesIO

    _chr = unichr

    def _str(s, encoding="UTF-8"):
        """Decode *s* with *encoding* and expand literal unicode escapes."""
        def _decode_escape(match):
            return match.group(0).decode('unicode-escape')
        return unichr_escape.sub(_decode_escape, unicode(s, encoding=encoding))

    def _bytes(s, encoding="UTF-8"):
        """Return *s* unchanged; *encoding* is unused on Python 2."""
        return s

    # Source rewrites that turn Py3 style doctests into Py2 syntax:
    # strip module prefixes from exception names in tracebacks, use
    # ``except X, e`` and drop the ``b`` string prefix.
    _fix_traceback = re.compile(r'^(\s*)(?:\w+\.)+(\w*(?:Error|Exception|Invalid):)', re.M).sub
    _fix_exceptions = re.compile(r'(.*except [^(]*)\s+as\s+(.*:)').sub
    _fix_bytes = re.compile(r'(\s+)b(["\'])').sub

    def make_doctest(filename):
        """Build a ``doctest.DocTestCase`` from a file next to the caller.

        *filename* is resolved relative to the module calling this
        function, read, adapted to Python 2 syntax and parsed.
        """
        # Must be called directly from here: the helper's default frame
        # depth refers to the caller of this function.
        path = _get_caller_relative_path(filename)
        source = _fix_traceback(r'\1\2', read_file(path))
        source = _fix_exceptions(r'\1, \2', source)
        source = _fix_bytes(r'\1\2', source)
        parsed = doctest_parser.get_doctest(
            source, {}, os.path.basename(path), path, 0)
        return doctest.DocTestCase(parsed)
try:
    skipIf = unittest.skipIf
except AttributeError:
    # Fallback for unittest versions that do not provide skipIf().
    def skipIf(condition, why):
        """Return a decorator that removes the decorated test if *condition* holds.

        Classes are replaced by an empty class of the same name, anything
        else by None.  *why* is not used by this fallback.
        """
        if not condition:
            return lambda thing: thing

        def _skip(thing):
            import types
            if not isinstance(thing, (type, types.ClassType)):
                return None
            return type(thing.__name__, (object,), {})

        return _skip
class HelperTestCase(unittest.TestCase):
    """Test case base class with parsing and serialisation helpers."""

    def tearDown(self):
        """Run the garbage collector after the test if DEBUG_PROXY_ISSUES is set."""
        if not DEBUG_PROXY_ISSUES:
            return
        gc.collect()

    def parse(self, text, parser=None):
        """Parse *text* (bytes or text string) with ``etree.parse`` and return the tree."""
        if isinstance(text, bytes):
            source = BytesIO(text)
        else:
            source = StringIO(text)
        return etree.parse(source, parser=parser)

    def _rootstring(self, tree):
        """Serialise the root of *tree* and strip all spaces and newlines."""
        serialised = etree.tostring(tree.getroot())
        for unwanted in (' ', '\n'):
            serialised = serialised.replace(_bytes(unwanted), _bytes(''))
        return serialised
class SillyFileLike:
    """Minimal file-like object that hands out a byte string via read()."""

    def __init__(self, xml_data=_bytes('')):
        self.xml_data = xml_data

    def read(self, amount=None):
        """Return up to *amount* of the remaining data and consume it.

        A false *amount* (None or 0) returns everything that is left.
        Once the data is used up, an empty byte string is returned.
        """
        pending = self.xml_data
        if not pending:
            return _bytes('')
        if amount:
            chunk, self.xml_data = pending[:amount], pending[amount:]
        else:
            chunk, self.xml_data = pending, _bytes('')
        return chunk
class LargeFileLike:
    """File-like object that lazily generates a large nested byte stream.

    The content comes from the recursive generator iterelements() and is
    only produced as far as read() requires it.
    """

    def __init__(self, charlen=100, depth=4, children=5):
        self.data = BytesIO()
        self.chars = _bytes('a') * charlen
        self.children = range(children)
        self.more = self.iterelements(depth)

    def iterelements(self, depth):
        """Yield the chunks of one level, recursing while depth allows it."""
        yield _bytes('')
        remaining = depth - 1
        if remaining <= 0:
            yield self.chars
        else:
            for _ in self.children:
                for chunk in self.iterelements(remaining):
                    yield chunk
                yield self.chars
        yield _bytes('')

    def read(self, amount=None):
        """Return up to *amount* items of content, or all of it if *amount* is false.

        Chunks are pulled from the generator into the buffer until enough
        has been collected; the surplus beyond *amount* is kept in the
        buffer for the next call.
        """
        buf = self.data
        for piece in self.more:
            buf.write(piece)
            if amount and buf.tell() >= amount:
                break
        collected = buf.getvalue()
        # Empty the buffer before storing any left-over content.
        buf.seek(0)
        buf.truncate()
        if not amount:
            return collected
        buf.write(collected[amount:])
        return collected[:amount]
class LargeFileLikeUnicode(LargeFileLike):
    """Variant of LargeFileLike that produces text instead of bytes."""

    def __init__(self, charlen=100, depth=4, children=5):
        LargeFileLike.__init__(self, charlen, depth, children)
        # Replace the byte based state set up by the base class.
        self.data = StringIO()
        self.chars = _str('a') * charlen
        self.more = self.iterelements(depth)

    def iterelements(self, depth):
        """Yield the text chunks of one level, recursing while depth allows it."""
        yield _str('')
        remaining = depth - 1
        if remaining <= 0:
            yield self.chars
        else:
            for _ in self.children:
                for chunk in self.iterelements(remaining):
                    yield chunk
                yield self.chars
        yield _str('')
class SimpleFSPath(object):
    """Simple path-like object: wraps a path and exposes it via __fspath__()."""

    def __init__(self, path):
        # Stored unchanged and handed back by __fspath__().
        self.path = path

    def __fspath__(self):
        """Return the wrapped path."""
        wrapped = self.path
        return wrapped
def fileInTestDir(name):
    """Return the path of *name* inside the directory containing this module."""
    return os.path.join(os.path.dirname(__file__), name)
def path2url(path):
    """Convert a file system path into a ``file:`` URL."""
    url_path = pathname2url(path)
    return urlparse.urljoin('file:', url_path)
def fileUrlInTestDir(name):
    """Return a ``file:`` URL for *name* inside the test directory."""
    local_path = fileInTestDir(name)
    return path2url(local_path)
def read_file(name, mode='r'):
    """Open the file *name* with *mode* and return its complete content."""
    with open(name, mode) as source:
        return source.read()
def write_to_file(name, data, mode='w'):
    """Open the file *name* with *mode* and write *data* to it."""
    with open(name, mode) as target:
        target.write(data)
def readFileInTestDir(name, mode='r'):
    """Read the file *name* from the test directory using *mode*."""
    full_path = fileInTestDir(name)
    return read_file(full_path, mode)
def canonicalize(xml):
    """Parse *xml* (bytes or text) and return its C14N serialisation as bytes."""
    if isinstance(xml, bytes):
        source = BytesIO(xml)
    else:
        source = StringIO(xml)
    output = BytesIO()
    etree.parse(source).write_c14n(output)
    return output.getvalue()
@contextmanager
def tmpfile(**kwargs):
    """Context manager providing the name of a fresh temporary file.

    The keyword arguments are passed on to ``tempfile.mkstemp()``.  The
    file descriptor stays open while the block runs; on exit it is
    closed and the file is removed.
    """
    fd, path = tempfile.mkstemp(**kwargs)
    try:
        yield path
    finally:
        os.close(fd)
        os.remove(path)