diff options
| author | Robert Collins <robertc@robertcollins.net> | 2011-04-25 16:25:24 +1200 |
|---|---|---|
| committer | Robert Collins <robertc@robertcollins.net> | 2011-04-25 16:25:24 +1200 |
| commit | 3b6f97b3c5ac79aefc74047b9292c4acee7ff09d (patch) | |
| tree | 10dbc28291a3d86f66b77bb5aaa58dc79899f162 /python/subunit | |
| parent | 63d596a30103017c7cac6af37d46c91a9cf7c325 (diff) | |
| download | subunit-3b6f97b3c5ac79aefc74047b9292c4acee7ff09d.tar.gz | |
Nearly done.
Diffstat (limited to 'python/subunit')
| -rw-r--r-- | python/subunit/__init__.py | 72 | ||||
| -rw-r--r-- | python/subunit/chunked.py | 41 | ||||
| -rw-r--r-- | python/subunit/details.py | 13 | ||||
| -rw-r--r-- | python/subunit/iso8601.py | 22 | ||||
| -rw-r--r-- | python/subunit/tests/test_chunked.py | 98 | ||||
| -rw-r--r-- | python/subunit/tests/test_details.py | 16 | ||||
| -rw-r--r-- | python/subunit/tests/test_subunit_filter.py | 12 | ||||
| -rw-r--r-- | python/subunit/tests/test_test_protocol.py | 79 |
8 files changed, 192 insertions, 161 deletions
diff --git a/python/subunit/__init__.py b/python/subunit/__init__.py index 6e1668e..8fb3ab7 100644 --- a/python/subunit/__init__.py +++ b/python/subunit/__init__.py @@ -129,9 +129,11 @@ try: RemoteException = _StringException # For testing: different pythons have different str() implementations. if sys.version_info > (3, 0): - _remote_exception_str = 'testtools.testresult.real._StringException' + _remote_exception_str = "testtools.testresult.real._StringException" + _remote_exception_str_chunked = "34\r\n" + _remote_exception_str else: - _remote_exception_str = '_StringException' + _remote_exception_str = "_StringException" + _remote_exception_str_chunked = "1A\r\n" + _remote_exception_str except ImportError: raise ImportError ("testtools.testresult.real does not contain " "_StringException, check your version.") @@ -559,7 +561,8 @@ class TestProtocolClient(testresult.TestResult): # Get a TestSuite or TestCase to run suite = make_suite() - # Create a stream (any object with a 'write' method) + # Create a stream (any object with a 'write' method). This should accept + # bytes not strings: subunit is a byte orientated protocol. stream = file('tests.log', 'wb') # Create a subunit result object which will output to the stream result = subunit.TestProtocolClient(stream) @@ -576,6 +579,14 @@ class TestProtocolClient(testresult.TestResult): testresult.TestResult.__init__(self) self._stream = stream _make_stream_binary(stream) + self._progress_fmt = _b("progress: ") + self._bytes_eol = _b("\n") + self._progress_plus = _b("+") + self._progress_push = _b("push") + self._progress_pop = _b("pop") + self._empty_bytes = _b("") + self._start_simple = _b(" [\n") + self._end_simple = _b("]\n") def addError(self, test, error=None, details=None): """Report an error in test test. @@ -637,42 +648,42 @@ class TestProtocolClient(testresult.TestResult): :param details: New Testing-in-python drafted API; a dict from string to subunit.Content objects. """ - self._stream.write("%s: %s" % (outcome, test.id())) + self._stream.write(_b("%s: %s" % (outcome, test.id()))) if error is None and details is None: raise ValueError if error is not None: - self._stream.write(" [\n") + self._stream.write(self._start_simple) # XXX: this needs to be made much stricter, along the lines of # Martin[gz]'s work in testtools. Perhaps subunit can use that? for line in self._exc_info_to_unicode(error, test).splitlines(): self._stream.write(("%s\n" % line).encode('utf8')) else: self._write_details(details) - self._stream.write("]\n") + self._stream.write(self._end_simple) def addSkip(self, test, reason=None, details=None): """Report a skipped test.""" if reason is None: self._addOutcome("skip", test, error=None, details=details) else: - self._stream.write("skip: %s [\n" % test.id()) - self._stream.write("%s\n" % reason) - self._stream.write("]\n") + self._stream.write(_b("skip: %s [\n" % test.id())) + self._stream.write(_b("%s\n" % reason)) + self._stream.write(self._end_simple) def addSuccess(self, test, details=None): """Report a success in a test.""" - self._stream.write("successful: %s" % test.id()) + self._stream.write(_b("successful: %s" % test.id())) if not details: - self._stream.write("\n") + self._stream.write(_b("\n")) else: self._write_details(details) - self._stream.write("]\n") + self._stream.write(self._end_simple) addUnexpectedSuccess = addSuccess def startTest(self, test): """Mark a test as starting its test run.""" super(TestProtocolClient, self).startTest(test) - self._stream.write("test: %s\n" % test.id()) + self._stream.write(_b("test: %s\n" % test.id())) self._stream.flush() def stopTest(self, test): @@ -690,16 +701,19 @@ class TestProtocolClient(testresult.TestResult): PROGRESS_POP. """ if whence == PROGRESS_CUR and offset > -1: - prefix = "+" + prefix = self._progress_plus + offset = _b(str(offset)) elif whence == PROGRESS_PUSH: - prefix = "" - offset = "push" + prefix = self._empty_bytes + offset = self._progress_push elif whence == PROGRESS_POP: - prefix = "" - offset = "pop" + prefix = self._empty_bytes + offset = self._progress_pop else: - prefix = "" - self._stream.write("progress: %s%s\n" % (prefix, offset)) + prefix = self._empty_bytes + offset = _b(str(offset)) + self._stream.write(self._progress_fmt + prefix + offset + + self._bytes_eol) def time(self, a_datetime): """Inform the client of the time. @@ -707,29 +721,29 @@ class TestProtocolClient(testresult.TestResult): ":param datetime: A datetime.datetime object. """ time = a_datetime.astimezone(iso8601.Utc()) - self._stream.write("time: %04d-%02d-%02d %02d:%02d:%02d.%06dZ\n" % ( + self._stream.write(_b("time: %04d-%02d-%02d %02d:%02d:%02d.%06dZ\n" % ( time.year, time.month, time.day, time.hour, time.minute, - time.second, time.microsecond)) + time.second, time.microsecond))) def _write_details(self, details): """Output details to the stream. :param details: An extended details dict for a test outcome. """ - self._stream.write(" [ multipart\n") + self._stream.write(_b(" [ multipart\n")) for name, content in sorted(details.items()): - self._stream.write("Content-Type: %s/%s" % - (content.content_type.type, content.content_type.subtype)) + self._stream.write(_b("Content-Type: %s/%s" % + (content.content_type.type, content.content_type.subtype))) parameters = content.content_type.parameters if parameters: - self._stream.write(";") + self._stream.write(_b(";")) param_strs = [] for param, value in parameters.items(): param_strs.append("%s=%s" % (param, value)) - self._stream.write(",".join(param_strs)) - self._stream.write("\n%s\n" % name) + self._stream.write(_b(",".join(param_strs))) + self._stream.write(_b("\n%s\n" % name)) encoder = chunked.Encoder(self._stream) - map(encoder.write, content.iter_bytes()) + list(map(encoder.write, content.iter_bytes())) encoder.close() def done(self): diff --git a/python/subunit/chunked.py b/python/subunit/chunked.py index 5f8c6f1..b992129 100644 --- a/python/subunit/chunked.py +++ b/python/subunit/chunked.py @@ -17,6 +17,10 @@ """Encoder/decoder for http style chunked encoding.""" +from testtools.compat import _b + +empty = _b('') + class Decoder(object): """Decode chunked content to a byte stream.""" @@ -25,11 +29,11 @@ class Decoder(object): :param output: A file-like object. Bytes written to the Decoder are decoded to strip off the chunking and written to the output. - Up to a full write worth of data or a single control line may be + Up to a full write worth of data or a single control line may be buffered (whichever is larger). The close method should be called when no more data is available, to detect short streams; the write method will return none-None when the end of a stream is - detected. + detected. The output object must accept bytes objects. :param strict: If True (the default), the decoder will not knowingly accept input that is not conformant to the HTTP specification. @@ -42,6 +46,11 @@ class Decoder(object): self.state = self._read_length self.body_length = 0 self.strict = strict + self._match_chars = _b("0123456789abcdefABCDEF\r\n") + self._slash_n = _b('\n') + self._slash_r = _b('\r') + self._slash_rn = _b('\r\n') + self._slash_nr = _b('\n\r') def close(self): """Close the decoder. @@ -56,7 +65,7 @@ class Decoder(object): if self.buffered_bytes: buffered_bytes = self.buffered_bytes self.buffered_bytes = [] - return ''.join(buffered_bytes) + return empty.join(buffered_bytes) else: raise ValueError("stream is finished") @@ -80,26 +89,26 @@ class Decoder(object): def _read_length(self): """Try to decode a length from the bytes.""" - match_chars = "0123456789abcdefABCDEF\r\n" count_chars = [] for bytes in self.buffered_bytes: - for byte in bytes: - if byte not in match_chars: + for pos in range(len(bytes)): + byte = bytes[pos:pos+1] + if byte not in self._match_chars: break count_chars.append(byte) - if byte == '\n': + if byte == self._slash_n: break if not count_chars: return - if count_chars[-1][-1] != '\n': + if count_chars[-1] != self._slash_n: return - count_str = ''.join(count_chars) + count_str = empty.join(count_chars) if self.strict: - if count_str[-2:] != '\r\n': + if count_str[-2:] != self._slash_rn: raise ValueError("chunk header invalid: %r" % count_str) - if '\r' in count_str[:-2]: + if self._slash_r in count_str[:-2]: raise ValueError("too many CRs in chunk header %r" % count_str) - self.body_length = int(count_str.rstrip('\n\r'), 16) + self.body_length = int(count_str.rstrip(self._slash_nr), 16) excess_bytes = len(count_str) while excess_bytes: if excess_bytes >= len(self.buffered_bytes[0]): @@ -112,7 +121,7 @@ class Decoder(object): self.state = self._finished if not self.buffered_bytes: # May not call into self._finished with no buffered data. - return '' + return empty else: self.state = self._read_body return self.state() @@ -155,9 +164,9 @@ class Encoder(object): buffer_size = self.buffer_size self.buffered_bytes = [] self.buffer_size = 0 - self.output.write("%X\r\n" % (buffer_size + extra_len)) + self.output.write(_b("%X\r\n" % (buffer_size + extra_len))) if buffer_size: - self.output.write(''.join(buffered_bytes)) + self.output.write(empty.join(buffered_bytes)) return True def write(self, bytes): @@ -173,4 +182,4 @@ class Encoder(object): def close(self): """Finish the stream. This does not close the output stream.""" self.flush() - self.output.write("0\r\n") + self.output.write(_b("0\r\n")) diff --git a/python/subunit/details.py b/python/subunit/details.py index 9790543..35bc88e 100644 --- a/python/subunit/details.py +++ b/python/subunit/details.py @@ -17,12 +17,13 @@ """Handlers for outcome details.""" from testtools import content, content_type -from testtools.compat import _b, StringIO +from testtools.compat import _b, BytesIO from subunit import chunked end_marker = _b("]\n") quoted_marker = _b(" ]") +empty = _b('') class DetailsParser(object): @@ -79,18 +80,18 @@ class MultipartDetailsParser(DetailsParser): self._parse_state = self._look_for_content def _look_for_content(self, line): - if line == "]\n": + if line == end_marker: self._state.endDetails() return # TODO error handling - field, value = line[:-1].split(' ', 1) + field, value = line[:-1].decode('utf8').split(' ', 1) main, sub = value.split('/') self._content_type = content_type.ContentType(main, sub) self._parse_state = self._get_name def _get_name(self, line): - self._name = line[:-1] - self._body = StringIO() + self._name = line[:-1].decode('utf8') + self._body = BytesIO() self._chunk_parser = chunked.Decoder(self._body) self._parse_state = self._feed_chunks @@ -98,7 +99,7 @@ class MultipartDetailsParser(DetailsParser): residue = self._chunk_parser.write(line) if residue is not None: # Line based use always ends on no residue. - assert residue == '', 'residue: %r' % (residue,) + assert residue == empty, 'residue: %r' % (residue,) body = self._body self._details[self._name] = content.Content( self._content_type, lambda:[body.getvalue()]) diff --git a/python/subunit/iso8601.py b/python/subunit/iso8601.py index 93c92fb..cbe9a3b 100644 --- a/python/subunit/iso8601.py +++ b/python/subunit/iso8601.py @@ -31,15 +31,25 @@ datetime.datetime(2007, 1, 25, 12, 0, tzinfo=<iso8601.iso8601.Utc ...>) from datetime import datetime, timedelta, tzinfo import re +import sys __all__ = ["parse_date", "ParseError"] # Adapted from http://delete.me.uk/2005/03/iso8601.html -ISO8601_REGEX = re.compile(r"(?P<year>[0-9]{4})(-(?P<month>[0-9]{1,2})(-(?P<day>[0-9]{1,2})" +ISO8601_REGEX_PATTERN = (r"(?P<year>[0-9]{4})(-(?P<month>[0-9]{1,2})(-(?P<day>[0-9]{1,2})" r"((?P<separator>.)(?P<hour>[0-9]{2}):(?P<minute>[0-9]{2})(:(?P<second>[0-9]{2})(\.(?P<fraction>[0-9]+))?)?" r"(?P<timezone>Z|(([-+])([0-9]{2}):([0-9]{2})))?)?)?)?" ) -TIMEZONE_REGEX = re.compile("(?P<prefix>[+-])(?P<hours>[0-9]{2}).(?P<minutes>[0-9]{2})") +TIMEZONE_REGEX_PATTERN = "(?P<prefix>[+-])(?P<hours>[0-9]{2}).(?P<minutes>[0-9]{2})" +ISO8601_REGEX = re.compile(ISO8601_REGEX_PATTERN.encode('utf8')) +TIMEZONE_REGEX = re.compile(TIMEZONE_REGEX_PATTERN.encode('utf8')) + +zulu = "Z".encode('latin-1') +minus = "-".encode('latin-1') + +if sys.version_info < (3, 0): + bytes = str + class ParseError(Exception): """Raised when there is a problem parsing a date string""" @@ -84,7 +94,7 @@ def parse_timezone(tzstring, default_timezone=UTC): """Parses ISO 8601 time zone specs into tzinfo offsets """ - if tzstring == "Z": + if tzstring == zulu: return default_timezone # This isn't strictly correct, but it's common to encounter dates without # timezones so I'll assume the default (which defaults to UTC). @@ -94,7 +104,7 @@ def parse_timezone(tzstring, default_timezone=UTC): m = TIMEZONE_REGEX.match(tzstring) prefix, hours, minutes = m.groups() hours, minutes = int(hours), int(minutes) - if prefix == "-": + if prefix == minus: hours = -hours minutes = -minutes return FixedOffset(hours, minutes, tzstring) @@ -107,8 +117,8 @@ def parse_date(datestring, default_timezone=UTC): default timezone specified in default_timezone is used. This is UTC by default. """ - if not isinstance(datestring, basestring): - raise ParseError("Expecting a string %r" % datestring) + if not isinstance(datestring, bytes): + raise ParseError("Expecting bytes %r" % datestring) m = ISO8601_REGEX.match(datestring) if not m: raise ParseError("Unable to parse date string %r" % datestring) diff --git a/python/subunit/tests/test_chunked.py b/python/subunit/tests/test_chunked.py index 6323b02..e0742f1 100644 --- a/python/subunit/tests/test_chunked.py +++ b/python/subunit/tests/test_chunked.py @@ -17,7 +17,7 @@ import unittest -from testtools.compat import StringIO +from testtools.compat import _b, BytesIO import subunit.chunked @@ -32,121 +32,121 @@ class TestDecode(unittest.TestCase): def setUp(self): unittest.TestCase.setUp(self) - self.output = StringIO() + self.output = BytesIO() self.decoder = subunit.chunked.Decoder(self.output) def test_close_read_length_short_errors(self): self.assertRaises(ValueError, self.decoder.close) def test_close_body_short_errors(self): - self.assertEqual(None, self.decoder.write('2\r\na')) + self.assertEqual(None, self.decoder.write(_b('2\r\na'))) self.assertRaises(ValueError, self.decoder.close) def test_close_body_buffered_data_errors(self): - self.assertEqual(None, self.decoder.write('2\r')) + self.assertEqual(None, self.decoder.write(_b('2\r'))) self.assertRaises(ValueError, self.decoder.close) def test_close_after_finished_stream_safe(self): - self.assertEqual(None, self.decoder.write('2\r\nab')) - self.assertEqual('', self.decoder.write('0\r\n')) + self.assertEqual(None, self.decoder.write(_b('2\r\nab'))) + self.assertEqual(_b(''), self.decoder.write(_b('0\r\n'))) self.decoder.close() def test_decode_nothing(self): - self.assertEqual('', self.decoder.write('0\r\n')) - self.assertEqual('', self.output.getvalue()) + self.assertEqual(_b(''), self.decoder.write(_b('0\r\n'))) + self.assertEqual(_b(''), self.output.getvalue()) def test_decode_serialised_form(self): - self.assertEqual(None, self.decoder.write("F\r\n")) - self.assertEqual(None, self.decoder.write("serialised\n")) - self.assertEqual('', self.decoder.write("form0\r\n")) + self.assertEqual(None, self.decoder.write(_b("F\r\n"))) + self.assertEqual(None, self.decoder.write(_b("serialised\n"))) + self.assertEqual(_b(''), self.decoder.write(_b("form0\r\n"))) def test_decode_short(self): - self.assertEqual('', self.decoder.write('3\r\nabc0\r\n')) - self.assertEqual('abc', self.output.getvalue()) + self.assertEqual(_b(''), self.decoder.write(_b('3\r\nabc0\r\n'))) + self.assertEqual(_b('abc'), self.output.getvalue()) def test_decode_combines_short(self): - self.assertEqual('', self.decoder.write('6\r\nabcdef0\r\n')) - self.assertEqual('abcdef', self.output.getvalue()) + self.assertEqual(_b(''), self.decoder.write(_b('6\r\nabcdef0\r\n'))) + self.assertEqual(_b('abcdef'), self.output.getvalue()) def test_decode_excess_bytes_from_write(self): - self.assertEqual('1234', self.decoder.write('3\r\nabc0\r\n1234')) - self.assertEqual('abc', self.output.getvalue()) + self.assertEqual(_b('1234'), self.decoder.write(_b('3\r\nabc0\r\n1234'))) + self.assertEqual(_b('abc'), self.output.getvalue()) def test_decode_write_after_finished_errors(self): - self.assertEqual('1234', self.decoder.write('3\r\nabc0\r\n1234')) - self.assertRaises(ValueError, self.decoder.write, '') + self.assertEqual(_b('1234'), self.decoder.write(_b('3\r\nabc0\r\n1234'))) + self.assertRaises(ValueError, self.decoder.write, _b('')) def test_decode_hex(self): - self.assertEqual('', self.decoder.write('A\r\n12345678900\r\n')) - self.assertEqual('1234567890', self.output.getvalue()) + self.assertEqual(_b(''), self.decoder.write(_b('A\r\n12345678900\r\n'))) + self.assertEqual(_b('1234567890'), self.output.getvalue()) def test_decode_long_ranges(self): - self.assertEqual(None, self.decoder.write('10000\r\n')) - self.assertEqual(None, self.decoder.write('1' * 65536)) - self.assertEqual(None, self.decoder.write('10000\r\n')) - self.assertEqual(None, self.decoder.write('2' * 65536)) - self.assertEqual('', self.decoder.write('0\r\n')) - self.assertEqual('1' * 65536 + '2' * 65536, self.output.getvalue()) + self.assertEqual(None, self.decoder.write(_b('10000\r\n'))) + self.assertEqual(None, self.decoder.write(_b('1' * 65536))) + self.assertEqual(None, self.decoder.write(_b('10000\r\n'))) + self.assertEqual(None, self.decoder.write(_b('2' * 65536))) + self.assertEqual(_b(''), self.decoder.write(_b('0\r\n'))) + self.assertEqual(_b('1' * 65536 + '2' * 65536), self.output.getvalue()) def test_decode_newline_nonstrict(self): """Tolerate chunk markers with no CR character.""" # From <http://pad.lv/505078> self.decoder = subunit.chunked.Decoder(self.output, strict=False) - self.assertEqual(None, self.decoder.write('a\n')) - self.assertEqual(None, self.decoder.write('abcdeabcde')) - self.assertEqual('', self.decoder.write('0\n')) - self.assertEqual('abcdeabcde', self.output.getvalue()) + self.assertEqual(None, self.decoder.write(_b('a\n'))) + self.assertEqual(None, self.decoder.write(_b('abcdeabcde'))) + self.assertEqual(_b(''), self.decoder.write(_b('0\n'))) + self.assertEqual(_b('abcdeabcde'), self.output.getvalue()) def test_decode_strict_newline_only(self): """Reject chunk markers with no CR character in strict mode.""" # From <http://pad.lv/505078> self.assertRaises(ValueError, - self.decoder.write, 'a\n') + self.decoder.write, _b('a\n')) def test_decode_strict_multiple_crs(self): self.assertRaises(ValueError, - self.decoder.write, 'a\r\r\n') + self.decoder.write, _b('a\r\r\n')) def test_decode_short_header(self): self.assertRaises(ValueError, - self.decoder.write, '\n') + self.decoder.write, _b('\n')) class TestEncode(unittest.TestCase): def setUp(self): unittest.TestCase.setUp(self) - self.output = StringIO() + self.output = BytesIO() self.encoder = subunit.chunked.Encoder(self.output) def test_encode_nothing(self): self.encoder.close() - self.assertEqual('0\r\n', self.output.getvalue()) + self.assertEqual(_b('0\r\n'), self.output.getvalue()) def test_encode_empty(self): - self.encoder.write('') + self.encoder.write(_b('')) self.encoder.close() - self.assertEqual('0\r\n', self.output.getvalue()) + self.assertEqual(_b('0\r\n'), self.output.getvalue()) def test_encode_short(self): - self.encoder.write('abc') + self.encoder.write(_b('abc')) self.encoder.close() - self.assertEqual('3\r\nabc0\r\n', self.output.getvalue()) + self.assertEqual(_b('3\r\nabc0\r\n'), self.output.getvalue()) def test_encode_combines_short(self): - self.encoder.write('abc') - self.encoder.write('def') + self.encoder.write(_b('abc')) + self.encoder.write(_b('def')) self.encoder.close() - self.assertEqual('6\r\nabcdef0\r\n', self.output.getvalue()) + self.assertEqual(_b('6\r\nabcdef0\r\n'), self.output.getvalue()) def test_encode_over_9_is_in_hex(self): - self.encoder.write('1234567890') + self.encoder.write(_b('1234567890')) self.encoder.close() - self.assertEqual('A\r\n12345678900\r\n', self.output.getvalue()) + self.assertEqual(_b('A\r\n12345678900\r\n'), self.output.getvalue()) def test_encode_long_ranges_not_combined(self): - self.encoder.write('1' * 65536) - self.encoder.write('2' * 65536) + self.encoder.write(_b('1' * 65536)) + self.encoder.write(_b('2' * 65536)) self.encoder.close() - self.assertEqual('10000\r\n' + '1' * 65536 + '10000\r\n' + - '2' * 65536 + '0\r\n', self.output.getvalue()) + self.assertEqual(_b('10000\r\n' + '1' * 65536 + '10000\r\n' + + '2' * 65536 + '0\r\n'), self.output.getvalue()) diff --git a/python/subunit/tests/test_details.py b/python/subunit/tests/test_details.py index 49010d2..746aa04 100644 --- a/python/subunit/tests/test_details.py +++ b/python/subunit/tests/test_details.py @@ -95,18 +95,18 @@ class TestMultipartDetails(unittest.TestCase): def test_parts(self): parser = details.MultipartDetailsParser(None) - parser.lineReceived("Content-Type: text/plain\n") - parser.lineReceived("something\n") - parser.lineReceived("F\r\n") - parser.lineReceived("serialised\n") - parser.lineReceived("form0\r\n") + parser.lineReceived(_b("Content-Type: text/plain\n")) + parser.lineReceived(_b("something\n")) + parser.lineReceived(_b("F\r\n")) + parser.lineReceived(_b("serialised\n")) + parser.lineReceived(_b("form0\r\n")) expected = {} expected['something'] = content.Content( content_type.ContentType("text", "plain"), - lambda:["serialised\nform"]) + lambda:[_b("serialised\nform")]) found = parser.get_details() self.assertEqual(expected.keys(), found.keys()) self.assertEqual(expected['something'].content_type, found['something'].content_type) - self.assertEqual(''.join(expected['something'].iter_bytes()), - ''.join(found['something'].iter_bytes())) + self.assertEqual(_b('').join(expected['something'].iter_bytes()), + _b('').join(found['something'].iter_bytes())) diff --git a/python/subunit/tests/test_subunit_filter.py b/python/subunit/tests/test_subunit_filter.py index 682f726..fb6ffcd 100644 --- a/python/subunit/tests/test_subunit_filter.py +++ b/python/subunit/tests/test_subunit_filter.py @@ -21,7 +21,7 @@ from subunit import iso8601 import unittest from testtools import TestCase -from testtools.compat import StringIO +from testtools.compat import _b, BytesIO, StringIO from testtools.testresult.doubles import ExtendedTestResult import subunit @@ -35,7 +35,7 @@ class TestTestResultFilter(TestCase): # is an easy pithy way of getting a series of test objects to call into # the TestResult, and as TestResultFilter is intended for use with subunit # also has the benefit of detecting any interface skew issues. - example_subunit_stream = """\ + example_subunit_stream = _b("""\ tags: global test passed success passed @@ -50,7 +50,7 @@ test skipped skip skipped test todo xfail todo -""" +""") def run_tests(self, result_filter, input_stream=None): """Run tests through the given filter. @@ -61,7 +61,7 @@ xfail todo """ if input_stream is None: input_stream = self.example_subunit_stream - test = subunit.ProtocolTestCase(StringIO(input_stream)) + test = subunit.ProtocolTestCase(BytesIO(input_stream)) test.run(result_filter) def test_default(self): @@ -139,13 +139,13 @@ xfail todo date_a = datetime(year=2000, month=1, day=1, tzinfo=iso8601.UTC) date_b = datetime(year=2000, month=1, day=2, tzinfo=iso8601.UTC) date_c = datetime(year=2000, month=1, day=3, tzinfo=iso8601.UTC) - subunit_stream = '\n'.join([ + subunit_stream = _b('\n'.join([ "time: %s", "test: foo", "time: %s", "error: foo", "time: %s", - ""]) % (date_a, date_b, date_c) + ""]) % (date_a, date_b, date_c)) result = ExtendedTestResult() result_filter = TestResultFilter(result) self.run_tests(result_filter, subunit_stream) diff --git a/python/subunit/tests/test_test_protocol.py b/python/subunit/tests/test_test_protocol.py index 7778fcc..7ec7758 100644 --- a/python/subunit/tests/test_test_protocol.py +++ b/python/subunit/tests/test_test_protocol.py @@ -28,7 +28,7 @@ from testtools.tests.helpers import ( ) import subunit -from subunit import _remote_exception_str +from subunit import _remote_exception_str, _remote_exception_str_chunked import subunit.iso8601 as iso8601 @@ -994,11 +994,11 @@ class TestIsolatedTestSuite(unittest.TestCase): class TestTestProtocolClient(unittest.TestCase): def setUp(self): - self.io = StringIO() + self.io = BytesIO() self.protocol = subunit.TestProtocolClient(self.io) self.test = TestTestProtocolClient("test_start_test") self.sample_details = {'something':Content( - ContentType('text', 'plain'), lambda:['serialised\nform'])} + ContentType('text', 'plain'), lambda:[_b('serialised\nform')])} self.sample_tb_details = dict(self.sample_details) self.sample_tb_details['traceback'] = TracebackContent( subunit.RemoteError(_u("boo qux")), self.test) @@ -1006,27 +1006,27 @@ class TestTestProtocolClient(unittest.TestCase): def test_start_test(self): """Test startTest on a TestProtocolClient.""" self.protocol.startTest(self.test) - self.assertEqual(self.io.getvalue(), "test: %s\n" % self.test.id()) + self.assertEqual(self.io.getvalue(), _b("test: %s\n" % self.test.id())) def test_stop_test(self): # stopTest doesn't output anything. self.protocol.stopTest(self.test) - self.assertEqual(self.io.getvalue(), "") + self.assertEqual(self.io.getvalue(), _b("")) def test_add_success(self): """Test addSuccess on a TestProtocolClient.""" self.protocol.addSuccess(self.test) self.assertEqual( - self.io.getvalue(), "successful: %s\n" % self.test.id()) + self.io.getvalue(), _b("successful: %s\n" % self.test.id())) def test_add_success_details(self): """Test addSuccess on a TestProtocolClient with details.""" self.protocol.addSuccess(self.test, details=self.sample_details) self.assertEqual( - self.io.getvalue(), "successful: %s [ multipart\n" + self.io.getvalue(), _b("successful: %s [ multipart\n" "Content-Type: text/plain\n" "something\n" - "F\r\nserialised\nform0\r\n]\n" % self.test.id()) + "F\r\nserialised\nform0\r\n]\n" % self.test.id())) def test_add_failure(self): """Test addFailure on a TestProtocolClient.""" @@ -1034,8 +1034,8 @@ class TestTestProtocolClient(unittest.TestCase): self.test, subunit.RemoteError(_u("boo qux"))) self.assertEqual( self.io.getvalue(), - ('failure: %s [\n' + _remote_exception_str + ': boo qux\n]\n') - % self.test.id()) + _b(('failure: %s [\n' + _remote_exception_str + ': boo qux\n]\n') + % self.test.id())) def test_add_failure_details(self): """Test addFailure on a TestProtocolClient with details.""" @@ -1043,14 +1043,13 @@ class TestTestProtocolClient(unittest.TestCase): self.test, details=self.sample_tb_details) self.assertEqual( self.io.getvalue(), - ("failure: %s [ multipart\n" + _b(("failure: %s [ multipart\n" "Content-Type: text/plain\n" "something\n" "F\r\nserialised\nform0\r\n" "Content-Type: text/x-traceback;charset=utf8,language=python\n" - "traceback\n" - "1A\r\n" + _remote_exception_str + ": boo qux\n0\r\n" - "]\n") % self.test.id()) + "traceback\n" + _remote_exception_str_chunked + ": boo qux\n0\r\n" + "]\n") % self.test.id())) def test_add_error(self): """Test stopTest on a TestProtocolClient.""" @@ -1058,9 +1057,9 @@ class TestTestProtocolClient(unittest.TestCase): self.test, subunit.RemoteError(_u("phwoar crikey"))) self.assertEqual( self.io.getvalue(), - ('error: %s [\n' + + _b(('error: %s [\n' + _remote_exception_str + ": phwoar crikey\n" - "]\n") % self.test.id()) + "]\n") % self.test.id())) def test_add_error_details(self): """Test stopTest on a TestProtocolClient with details.""" @@ -1068,14 +1067,13 @@ class TestTestProtocolClient(unittest.TestCase): self.test, details=self.sample_tb_details) self.assertEqual( self.io.getvalue(), - ("error: %s [ multipart\n" + _b(("error: %s [ multipart\n" "Content-Type: text/plain\n" "something\n" "F\r\nserialised\nform0\r\n" "Content-Type: text/x-traceback;charset=utf8,language=python\n" - "traceback\n" - "1A\r\n" + _remote_exception_str + ": boo qux\n0\r\n" - "]\n") % self.test.id()) + "traceback\n" + _remote_exception_str_chunked + ": boo qux\n0\r\n" + "]\n") % self.test.id())) def test_add_expected_failure(self): """Test addExpectedFailure on a TestProtocolClient.""" @@ -1083,9 +1081,9 @@ class TestTestProtocolClient(unittest.TestCase): self.test, subunit.RemoteError(_u("phwoar crikey"))) self.assertEqual( self.io.getvalue(), - ('xfail: %s [\n' + + _b(('xfail: %s [\n' + _remote_exception_str + ": phwoar crikey\n" - "]\n") % self.test.id()) + "]\n") % self.test.id())) def test_add_expected_failure_details(self): """Test addExpectedFailure on a TestProtocolClient with details.""" @@ -1093,14 +1091,14 @@ class TestTestProtocolClient(unittest.TestCase): self.test, details=self.sample_tb_details) self.assertEqual( self.io.getvalue(), - ("xfail: %s [ multipart\n" + _b(("xfail: %s [ multipart\n" "Content-Type: text/plain\n" "something\n" "F\r\nserialised\nform0\r\n" "Content-Type: text/x-traceback;charset=utf8,language=python\n" - "traceback\n" - "1A\r\n"+ _remote_exception_str + ": boo qux\n0\r\n" - "]\n") % self.test.id()) + "traceback\n" + _remote_exception_str_chunked + ": boo qux\n0\r\n" + "]\n") % self.test.id())) + def test_add_skip(self): """Test addSkip on a TestProtocolClient.""" @@ -1108,64 +1106,63 @@ class TestTestProtocolClient(unittest.TestCase): self.test, "Has it really?") self.assertEqual( self.io.getvalue(), - 'skip: %s [\nHas it really?\n]\n' % self.test.id()) + _b('skip: %s [\nHas it really?\n]\n' % self.test.id())) def test_add_skip_details(self): """Test addSkip on a TestProtocolClient with details.""" details = {'reason':Content( - ContentType('text', 'plain'), lambda:['Has it really?'])} - self.protocol.addSkip( - self.test, details=details) + ContentType('text', 'plain'), lambda:[_b('Has it really?')])} + self.protocol.addSkip(self.test, details=details) self.assertEqual( self.io.getvalue(), - "skip: %s [ multipart\n" + _b("skip: %s [ multipart\n" "Content-Type: text/plain\n" "reason\n" "E\r\nHas it really?0\r\n" - "]\n" % self.test.id()) + "]\n" % self.test.id())) def test_progress_set(self): self.protocol.progress(23, subunit.PROGRESS_SET) - self.assertEqual(self.io.getvalue(), 'progress: 23\n') + self.assertEqual(self.io.getvalue(), _b('progress: 23\n')) def test_progress_neg_cur(self): self.protocol.progress(-23, subunit.PROGRESS_CUR) - self.assertEqual(self.io.getvalue(), 'progress: -23\n') + self.assertEqual(self.io.getvalue(), _b('progress: -23\n')) def test_progress_pos_cur(self): self.protocol.progress(23, subunit.PROGRESS_CUR) - self.assertEqual(self.io.getvalue(), 'progress: +23\n') + self.assertEqual(self.io.getvalue(), _b('progress: +23\n')) def test_progress_pop(self): self.protocol.progress(1234, subunit.PROGRESS_POP) - self.assertEqual(self.io.getvalue(), 'progress: pop\n') + self.assertEqual(self.io.getvalue(), _b('progress: pop\n')) def test_progress_push(self): self.protocol.progress(1234, subunit.PROGRESS_PUSH) - self.assertEqual(self.io.getvalue(), 'progress: push\n') + self.assertEqual(self.io.getvalue(), _b('progress: push\n')) def test_time(self): # Calling time() outputs a time signal immediately. self.protocol.time( datetime.datetime(2009,10,11,12,13,14,15, iso8601.Utc())) self.assertEqual( - "time: 2009-10-11 12:13:14.000015Z\n", + _b("time: 2009-10-11 12:13:14.000015Z\n"), self.io.getvalue()) def test_add_unexpected_success(self): """Test addUnexpectedSuccess on a TestProtocolClient.""" self.protocol.addUnexpectedSuccess(self.test) self.assertEqual( - self.io.getvalue(), "successful: %s\n" % self.test.id()) + self.io.getvalue(), _b("successful: %s\n" % self.test.id())) def test_add_unexpected_success_details(self): """Test addUnexpectedSuccess on a TestProtocolClient with details.""" self.protocol.addUnexpectedSuccess(self.test, details=self.sample_details) self.assertEqual( - self.io.getvalue(), "successful: %s [ multipart\n" + self.io.getvalue(), _b("successful: %s [ multipart\n" "Content-Type: text/plain\n" "something\n" - "F\r\nserialised\nform0\r\n]\n" % self.test.id()) + "F\r\nserialised\nform0\r\n]\n" % self.test.id())) def test_suite(): |
