# -*- coding: utf-8 -*- """ Tests for thread usage in lxml.etree. """ from __future__ import absolute_import import re import sys import unittest import threading from .common_imports import etree, HelperTestCase, BytesIO, _bytes try: from Queue import Queue except ImportError: from queue import Queue # Py3 class ThreadingTestCase(HelperTestCase): """Threading tests""" etree = etree def _run_thread(self, func): thread = threading.Thread(target=func) thread.start() thread.join() def _run_threads(self, count, func, main_func=None): sync = threading.Event() lock = threading.Lock() counter = dict(started=0, finished=0, failed=0) def sync_start(func): with lock: started = counter['started'] + 1 counter['started'] = started if started < count + (main_func is not None): sync.wait(4) # wait until the other threads have started up assert sync.is_set() sync.set() # all waiting => go! try: func() except: with lock: counter['failed'] += 1 raise else: with lock: counter['finished'] += 1 threads = [threading.Thread(target=sync_start, args=(func,)) for _ in range(count)] for thread in threads: thread.start() if main_func is not None: sync_start(main_func) for thread in threads: thread.join() self.assertEqual(0, counter['failed']) self.assertEqual(counter['finished'], counter['started']) def test_subtree_copy_thread(self): tostring = self.etree.tostring XML = self.etree.XML xml = _bytes("") main_root = XML(_bytes("")) def run_thread(): thread_root = XML(xml) main_root.append(thread_root[0]) del thread_root self._run_thread(run_thread) self.assertEqual(xml, tostring(main_root)) def test_main_xslt_in_thread(self): XML = self.etree.XML style = XML(_bytes('''\ ''')) st = etree.XSLT(style) result = [] def run_thread(): root = XML(_bytes('BC')) result.append( st(root) ) self._run_thread(run_thread) self.assertEqual('''\ B ''', str(result[0])) def test_thread_xslt(self): XML = self.etree.XML tostring = self.etree.tostring root = XML(_bytes('BC')) def run_thread(): style = XML(_bytes('''\ ''')) st = etree.XSLT(style) root.append( st(root).getroot() ) self._run_thread(run_thread) self.assertEqual(_bytes('BCB'), tostring(root)) def test_thread_xslt_parsing_error_log(self): style = self.parse('''\ ''' + '\n'.join('' % i for i in range(200)) + ''' ''') self.assertRaises(etree.XSLTParseError, etree.XSLT, style) error_logs = [] def run_thread(): try: etree.XSLT(style) except etree.XSLTParseError as e: error_logs.append(e.error_log) else: self.assertFalse(True, "XSLT parsing should have failed but didn't") self._run_threads(16, run_thread) self.assertEqual(16, len(error_logs)) last_log = None for log in error_logs: self.assertTrue(len(log)) if last_log is not None: self.assertEqual(len(last_log), len(log)) self.assertTrue(len(log) >= 2, len(log)) for error in log: self.assertTrue(':ERROR:XSLT:' in str(error), str(error)) self.assertTrue(any('UnExpectedElement' in str(error) for error in log), log) last_log = log def test_thread_xslt_apply_error_log(self): tree = self.parse('') style = self.parse('''\ FAIL ''' + '\n'.join(' ' % (i, i, i-1) for i in range(1, 256)) + ''' ''') self.assertRaises(etree.XSLTApplyError, etree.XSLT(style), tree) error_logs = [] def run_thread(): transform = etree.XSLT(style) try: transform(tree) except etree.XSLTApplyError: error_logs.append(transform.error_log) else: self.assertFalse(True, "XSLT parsing should have failed but didn't") self._run_threads(16, run_thread) self.assertEqual(16, len(error_logs)) last_log = None for log in error_logs: self.assertTrue(len(log)) if last_log is not None: self.assertEqual(len(last_log), len(log)) self.assertEqual(1, len(log)) for error in log: self.assertTrue(':ERROR:XSLT:' in str(error)) last_log = log def test_thread_xslt_attr_replace(self): # this is the only case in XSLT where the result tree can be # modified in-place XML = self.etree.XML tostring = self.etree.tostring style = self.etree.XSLT(XML(_bytes('''\ xyz '''))) result = [] def run_thread(): root = XML(_bytes('')) result.append( style(root).getroot() ) self._run_thread(run_thread) self.assertEqual(_bytes(''), tostring(result[0])) def test_thread_create_xslt(self): XML = self.etree.XML tostring = self.etree.tostring root = XML(_bytes('BC')) stylesheets = [] def run_thread(): style = XML(_bytes('''\
''')) stylesheets.append( etree.XSLT(style) ) self._run_thread(run_thread) st = stylesheets[0] result = tostring( st(root) ) self.assertEqual(_bytes('
BC
'), result) def test_thread_error_log(self): XML = self.etree.XML expected_error = [self.etree.ErrorTypes.ERR_TAG_NAME_MISMATCH] children = "test" * 100 def parse_error_test(thread_no): tag = "tag%d" % thread_no xml = "<%s>%s" % (tag, children, tag.upper()) parser = self.etree.XMLParser() for _ in range(10): errors = None try: XML(xml, parser) except self.etree.ParseError: e = sys.exc_info()[1] errors = e.error_log.filter_types(expected_error) self.assertTrue(errors, "Expected error not found") for error in errors: self.assertTrue( tag in error.message and tag.upper() in error.message, "%s and %s not found in '%s'" % ( tag, tag.upper(), error.message)) self.etree.clear_error_log() threads = [] for thread_no in range(1, 10): t = threading.Thread(target=parse_error_test, args=(thread_no,)) threads.append(t) t.start() parse_error_test(0) for t in threads: t.join() def test_thread_mix(self): XML = self.etree.XML Element = self.etree.Element SubElement = self.etree.SubElement tostring = self.etree.tostring xml = _bytes('BC') root = XML(xml) fragment = XML(_bytes("")) result = self.etree.Element("{myns}root", att = "someval") def run_XML(): thread_root = XML(xml) result.append(thread_root[0]) result.append(thread_root[-1]) def run_parse(): thread_root = self.etree.parse(BytesIO(xml)).getroot() result.append(thread_root[0]) result.append(thread_root[-1]) def run_move_main(): result.append(fragment[0]) def run_build(): result.append( Element("{myns}foo", attrib={'{test}attr':'val'})) SubElement(result, "{otherns}tasty") def run_xslt(): style = XML(_bytes('''\ ''')) st = etree.XSLT(style) result.append( st(root).getroot() ) for test in (run_XML, run_parse, run_move_main, run_xslt, run_build): tostring(result) self._run_thread(test) self.assertEqual( _bytes('B' 'CBC' 'B' '' ''), tostring(result)) def strip_first(): root = Element("newroot") root.append(result[0]) while len(result): self._run_thread(strip_first) self.assertEqual( _bytes(''), tostring(result)) def test_concurrent_attribute_names_in_dicts(self): SubElement = self.etree.SubElement names = list('abcdefghijklmnop') runs_per_name = range(50) result_matches = re.compile( br'' br'(?:<[a-p]{5} thread_attr_[a-p]="value" thread_attr2_[a-p]="value2"\s?/>)+' br'').match def testrun(): for _ in range(3): root = self.etree.Element('thread_root') for name in names: tag_name = name * 5 new = [] for _ in runs_per_name: el = SubElement(root, tag_name, {'thread_attr_' + name: 'value'}) new.append(el) for el in new: el.set('thread_attr2_' + name, 'value2') s = etree.tostring(root) self.assertTrue(result_matches(s)) # first, run only in sub-threads self._run_threads(10, testrun) # then, additionally include the main thread (and its parent dict) self._run_threads(10, testrun, main_func=testrun) def test_concurrent_proxies(self): XML = self.etree.XML root = XML(_bytes('AB')) child_count = len(root) def testrun(): for i in range(10000): el = root[i%child_count] del el self._run_threads(10, testrun) def test_concurrent_class_lookup(self): XML = self.etree.XML class TestElement(etree.ElementBase): pass class MyLookup(etree.CustomElementClassLookup): repeat = range(100) def lookup(self, t, d, ns, name): count = 0 for i in self.repeat: # allow other threads to run count += 1 return TestElement parser = self.etree.XMLParser() parser.set_element_class_lookup(MyLookup()) root = XML(_bytes('AB'), parser) child_count = len(root) def testrun(): for i in range(1000): el = root[i%child_count] del el self._run_threads(10, testrun) class ThreadPipelineTestCase(HelperTestCase): """Threading tests based on a thread worker pipeline. """ etree = etree item_count = 40 class Worker(threading.Thread): def __init__(self, in_queue, in_count, **kwargs): threading.Thread.__init__(self) self.in_queue = in_queue self.in_count = in_count self.out_queue = Queue(in_count) self.__dict__.update(kwargs) def run(self): get, put = self.in_queue.get, self.out_queue.put handle = self.handle for _ in range(self.in_count): put(handle(get())) def handle(self, data): raise NotImplementedError() class ParseWorker(Worker): def handle(self, xml, _fromstring=etree.fromstring): return _fromstring(xml) class RotateWorker(Worker): def handle(self, element): first = element[0] element[:] = element[1:] element.append(first) return element class ReverseWorker(Worker): def handle(self, element): element[:] = element[::-1] return element class ParseAndExtendWorker(Worker): def handle(self, element, _fromstring=etree.fromstring): element.extend(_fromstring(self.xml)) return element class ParseAndInjectWorker(Worker): def handle(self, element, _fromstring=etree.fromstring): root = _fromstring(self.xml) root.extend(element) return root class Validate(Worker): def handle(self, element): element.getroottree().docinfo.internalDTD.assertValid(element) return element class SerialiseWorker(Worker): def handle(self, element): return etree.tostring(element) xml = (b'''\ ]> ''' + (b'''
''') * 20 + b'''
''') def _build_pipeline(self, item_count, *classes, **kwargs): in_queue = Queue(item_count) start = last = classes[0](in_queue, item_count, **kwargs) start.setDaemon(True) for worker_class in classes[1:]: last = worker_class(last.out_queue, item_count, **kwargs) last.setDaemon(True) last.start() return in_queue, start, last def test_thread_pipeline_thread_parse(self): item_count = self.item_count xml = self.xml.replace(b'thread', b'THREAD') # use fresh tag names # build and start the pipeline in_queue, start, last = self._build_pipeline( item_count, self.ParseWorker, self.RotateWorker, self.ReverseWorker, self.ParseAndExtendWorker, self.Validate, self.ParseAndInjectWorker, self.SerialiseWorker, xml=xml) # fill the queue put = start.in_queue.put for _ in range(item_count): put(xml) # start the first thread and thus everything start.start() # make sure the last thread has terminated last.join(60) # time out after 60 seconds self.assertEqual(item_count, last.out_queue.qsize()) # read the results get = last.out_queue.get results = [get() for _ in range(item_count)] comparison = results[0] for i, result in enumerate(results[1:]): self.assertEqual(comparison, result) def test_thread_pipeline_global_parse(self): item_count = self.item_count xml = self.xml.replace(b'thread', b'GLOBAL') # use fresh tag names XML = self.etree.XML # build and start the pipeline in_queue, start, last = self._build_pipeline( item_count, self.RotateWorker, self.ReverseWorker, self.ParseAndExtendWorker, self.Validate, self.SerialiseWorker, xml=xml) # fill the queue put = start.in_queue.put for _ in range(item_count): put(XML(xml)) # start the first thread and thus everything start.start() # make sure the last thread has terminated last.join(60) # time out after 90 seconds self.assertEqual(item_count, last.out_queue.qsize()) # read the results get = last.out_queue.get results = [get() for _ in range(item_count)] comparison = results[0] for i, result in enumerate(results[1:]): self.assertEqual(comparison, result) def test_suite(): suite = unittest.TestSuite() suite.addTests([unittest.makeSuite(ThreadingTestCase)]) suite.addTests([unittest.makeSuite(ThreadPipelineTestCase)]) return suite if __name__ == '__main__': print('to test use test.py %s' % __file__)