summaryrefslogtreecommitdiff
path: root/python/subunit
diff options
context:
space:
mode:
authorRobert Collins <robertc@robertcollins.net>2011-04-25 16:25:24 +1200
committerRobert Collins <robertc@robertcollins.net>2011-04-25 16:25:24 +1200
commit3b6f97b3c5ac79aefc74047b9292c4acee7ff09d (patch)
tree10dbc28291a3d86f66b77bb5aaa58dc79899f162 /python/subunit
parent63d596a30103017c7cac6af37d46c91a9cf7c325 (diff)
downloadsubunit-3b6f97b3c5ac79aefc74047b9292c4acee7ff09d.tar.gz
Nearly done.
Diffstat (limited to 'python/subunit')
-rw-r--r--python/subunit/__init__.py72
-rw-r--r--python/subunit/chunked.py41
-rw-r--r--python/subunit/details.py13
-rw-r--r--python/subunit/iso8601.py22
-rw-r--r--python/subunit/tests/test_chunked.py98
-rw-r--r--python/subunit/tests/test_details.py16
-rw-r--r--python/subunit/tests/test_subunit_filter.py12
-rw-r--r--python/subunit/tests/test_test_protocol.py79
8 files changed, 192 insertions, 161 deletions
diff --git a/python/subunit/__init__.py b/python/subunit/__init__.py
index 6e1668e..8fb3ab7 100644
--- a/python/subunit/__init__.py
+++ b/python/subunit/__init__.py
@@ -129,9 +129,11 @@ try:
RemoteException = _StringException
# For testing: different pythons have different str() implementations.
if sys.version_info > (3, 0):
- _remote_exception_str = 'testtools.testresult.real._StringException'
+ _remote_exception_str = "testtools.testresult.real._StringException"
+ _remote_exception_str_chunked = "34\r\n" + _remote_exception_str
else:
- _remote_exception_str = '_StringException'
+ _remote_exception_str = "_StringException"
+ _remote_exception_str_chunked = "1A\r\n" + _remote_exception_str
except ImportError:
raise ImportError ("testtools.testresult.real does not contain "
"_StringException, check your version.")
@@ -559,7 +561,8 @@ class TestProtocolClient(testresult.TestResult):
# Get a TestSuite or TestCase to run
suite = make_suite()
- # Create a stream (any object with a 'write' method)
+ # Create a stream (any object with a 'write' method). This should accept
+ # bytes not strings: subunit is a byte orientated protocol.
stream = file('tests.log', 'wb')
# Create a subunit result object which will output to the stream
result = subunit.TestProtocolClient(stream)
@@ -576,6 +579,14 @@ class TestProtocolClient(testresult.TestResult):
testresult.TestResult.__init__(self)
self._stream = stream
_make_stream_binary(stream)
+ self._progress_fmt = _b("progress: ")
+ self._bytes_eol = _b("\n")
+ self._progress_plus = _b("+")
+ self._progress_push = _b("push")
+ self._progress_pop = _b("pop")
+ self._empty_bytes = _b("")
+ self._start_simple = _b(" [\n")
+ self._end_simple = _b("]\n")
def addError(self, test, error=None, details=None):
"""Report an error in test test.
@@ -637,42 +648,42 @@ class TestProtocolClient(testresult.TestResult):
:param details: New Testing-in-python drafted API; a dict from string
to subunit.Content objects.
"""
- self._stream.write("%s: %s" % (outcome, test.id()))
+ self._stream.write(_b("%s: %s" % (outcome, test.id())))
if error is None and details is None:
raise ValueError
if error is not None:
- self._stream.write(" [\n")
+ self._stream.write(self._start_simple)
# XXX: this needs to be made much stricter, along the lines of
# Martin[gz]'s work in testtools. Perhaps subunit can use that?
for line in self._exc_info_to_unicode(error, test).splitlines():
self._stream.write(("%s\n" % line).encode('utf8'))
else:
self._write_details(details)
- self._stream.write("]\n")
+ self._stream.write(self._end_simple)
def addSkip(self, test, reason=None, details=None):
"""Report a skipped test."""
if reason is None:
self._addOutcome("skip", test, error=None, details=details)
else:
- self._stream.write("skip: %s [\n" % test.id())
- self._stream.write("%s\n" % reason)
- self._stream.write("]\n")
+ self._stream.write(_b("skip: %s [\n" % test.id()))
+ self._stream.write(_b("%s\n" % reason))
+ self._stream.write(self._end_simple)
def addSuccess(self, test, details=None):
"""Report a success in a test."""
- self._stream.write("successful: %s" % test.id())
+ self._stream.write(_b("successful: %s" % test.id()))
if not details:
- self._stream.write("\n")
+ self._stream.write(_b("\n"))
else:
self._write_details(details)
- self._stream.write("]\n")
+ self._stream.write(self._end_simple)
addUnexpectedSuccess = addSuccess
def startTest(self, test):
"""Mark a test as starting its test run."""
super(TestProtocolClient, self).startTest(test)
- self._stream.write("test: %s\n" % test.id())
+ self._stream.write(_b("test: %s\n" % test.id()))
self._stream.flush()
def stopTest(self, test):
@@ -690,16 +701,19 @@ class TestProtocolClient(testresult.TestResult):
PROGRESS_POP.
"""
if whence == PROGRESS_CUR and offset > -1:
- prefix = "+"
+ prefix = self._progress_plus
+ offset = _b(str(offset))
elif whence == PROGRESS_PUSH:
- prefix = ""
- offset = "push"
+ prefix = self._empty_bytes
+ offset = self._progress_push
elif whence == PROGRESS_POP:
- prefix = ""
- offset = "pop"
+ prefix = self._empty_bytes
+ offset = self._progress_pop
else:
- prefix = ""
- self._stream.write("progress: %s%s\n" % (prefix, offset))
+ prefix = self._empty_bytes
+ offset = _b(str(offset))
+ self._stream.write(self._progress_fmt + prefix + offset +
+ self._bytes_eol)
def time(self, a_datetime):
"""Inform the client of the time.
@@ -707,29 +721,29 @@ class TestProtocolClient(testresult.TestResult):
":param datetime: A datetime.datetime object.
"""
time = a_datetime.astimezone(iso8601.Utc())
- self._stream.write("time: %04d-%02d-%02d %02d:%02d:%02d.%06dZ\n" % (
+ self._stream.write(_b("time: %04d-%02d-%02d %02d:%02d:%02d.%06dZ\n" % (
time.year, time.month, time.day, time.hour, time.minute,
- time.second, time.microsecond))
+ time.second, time.microsecond)))
def _write_details(self, details):
"""Output details to the stream.
:param details: An extended details dict for a test outcome.
"""
- self._stream.write(" [ multipart\n")
+ self._stream.write(_b(" [ multipart\n"))
for name, content in sorted(details.items()):
- self._stream.write("Content-Type: %s/%s" %
- (content.content_type.type, content.content_type.subtype))
+ self._stream.write(_b("Content-Type: %s/%s" %
+ (content.content_type.type, content.content_type.subtype)))
parameters = content.content_type.parameters
if parameters:
- self._stream.write(";")
+ self._stream.write(_b(";"))
param_strs = []
for param, value in parameters.items():
param_strs.append("%s=%s" % (param, value))
- self._stream.write(",".join(param_strs))
- self._stream.write("\n%s\n" % name)
+ self._stream.write(_b(",".join(param_strs)))
+ self._stream.write(_b("\n%s\n" % name))
encoder = chunked.Encoder(self._stream)
- map(encoder.write, content.iter_bytes())
+ list(map(encoder.write, content.iter_bytes()))
encoder.close()
def done(self):
diff --git a/python/subunit/chunked.py b/python/subunit/chunked.py
index 5f8c6f1..b992129 100644
--- a/python/subunit/chunked.py
+++ b/python/subunit/chunked.py
@@ -17,6 +17,10 @@
"""Encoder/decoder for http style chunked encoding."""
+from testtools.compat import _b
+
+empty = _b('')
+
class Decoder(object):
"""Decode chunked content to a byte stream."""
@@ -25,11 +29,11 @@ class Decoder(object):
:param output: A file-like object. Bytes written to the Decoder are
decoded to strip off the chunking and written to the output.
- Up to a full write worth of data or a single control line may be
+ Up to a full write worth of data or a single control line may be
buffered (whichever is larger). The close method should be called
when no more data is available, to detect short streams; the
write method will return none-None when the end of a stream is
- detected.
+ detected. The output object must accept bytes objects.
:param strict: If True (the default), the decoder will not knowingly
accept input that is not conformant to the HTTP specification.
@@ -42,6 +46,11 @@ class Decoder(object):
self.state = self._read_length
self.body_length = 0
self.strict = strict
+ self._match_chars = _b("0123456789abcdefABCDEF\r\n")
+ self._slash_n = _b('\n')
+ self._slash_r = _b('\r')
+ self._slash_rn = _b('\r\n')
+ self._slash_nr = _b('\n\r')
def close(self):
"""Close the decoder.
@@ -56,7 +65,7 @@ class Decoder(object):
if self.buffered_bytes:
buffered_bytes = self.buffered_bytes
self.buffered_bytes = []
- return ''.join(buffered_bytes)
+ return empty.join(buffered_bytes)
else:
raise ValueError("stream is finished")
@@ -80,26 +89,26 @@ class Decoder(object):
def _read_length(self):
"""Try to decode a length from the bytes."""
- match_chars = "0123456789abcdefABCDEF\r\n"
count_chars = []
for bytes in self.buffered_bytes:
- for byte in bytes:
- if byte not in match_chars:
+ for pos in range(len(bytes)):
+ byte = bytes[pos:pos+1]
+ if byte not in self._match_chars:
break
count_chars.append(byte)
- if byte == '\n':
+ if byte == self._slash_n:
break
if not count_chars:
return
- if count_chars[-1][-1] != '\n':
+ if count_chars[-1] != self._slash_n:
return
- count_str = ''.join(count_chars)
+ count_str = empty.join(count_chars)
if self.strict:
- if count_str[-2:] != '\r\n':
+ if count_str[-2:] != self._slash_rn:
raise ValueError("chunk header invalid: %r" % count_str)
- if '\r' in count_str[:-2]:
+ if self._slash_r in count_str[:-2]:
raise ValueError("too many CRs in chunk header %r" % count_str)
- self.body_length = int(count_str.rstrip('\n\r'), 16)
+ self.body_length = int(count_str.rstrip(self._slash_nr), 16)
excess_bytes = len(count_str)
while excess_bytes:
if excess_bytes >= len(self.buffered_bytes[0]):
@@ -112,7 +121,7 @@ class Decoder(object):
self.state = self._finished
if not self.buffered_bytes:
# May not call into self._finished with no buffered data.
- return ''
+ return empty
else:
self.state = self._read_body
return self.state()
@@ -155,9 +164,9 @@ class Encoder(object):
buffer_size = self.buffer_size
self.buffered_bytes = []
self.buffer_size = 0
- self.output.write("%X\r\n" % (buffer_size + extra_len))
+ self.output.write(_b("%X\r\n" % (buffer_size + extra_len)))
if buffer_size:
- self.output.write(''.join(buffered_bytes))
+ self.output.write(empty.join(buffered_bytes))
return True
def write(self, bytes):
@@ -173,4 +182,4 @@ class Encoder(object):
def close(self):
"""Finish the stream. This does not close the output stream."""
self.flush()
- self.output.write("0\r\n")
+ self.output.write(_b("0\r\n"))
diff --git a/python/subunit/details.py b/python/subunit/details.py
index 9790543..35bc88e 100644
--- a/python/subunit/details.py
+++ b/python/subunit/details.py
@@ -17,12 +17,13 @@
"""Handlers for outcome details."""
from testtools import content, content_type
-from testtools.compat import _b, StringIO
+from testtools.compat import _b, BytesIO
from subunit import chunked
end_marker = _b("]\n")
quoted_marker = _b(" ]")
+empty = _b('')
class DetailsParser(object):
@@ -79,18 +80,18 @@ class MultipartDetailsParser(DetailsParser):
self._parse_state = self._look_for_content
def _look_for_content(self, line):
- if line == "]\n":
+ if line == end_marker:
self._state.endDetails()
return
# TODO error handling
- field, value = line[:-1].split(' ', 1)
+ field, value = line[:-1].decode('utf8').split(' ', 1)
main, sub = value.split('/')
self._content_type = content_type.ContentType(main, sub)
self._parse_state = self._get_name
def _get_name(self, line):
- self._name = line[:-1]
- self._body = StringIO()
+ self._name = line[:-1].decode('utf8')
+ self._body = BytesIO()
self._chunk_parser = chunked.Decoder(self._body)
self._parse_state = self._feed_chunks
@@ -98,7 +99,7 @@ class MultipartDetailsParser(DetailsParser):
residue = self._chunk_parser.write(line)
if residue is not None:
# Line based use always ends on no residue.
- assert residue == '', 'residue: %r' % (residue,)
+ assert residue == empty, 'residue: %r' % (residue,)
body = self._body
self._details[self._name] = content.Content(
self._content_type, lambda:[body.getvalue()])
diff --git a/python/subunit/iso8601.py b/python/subunit/iso8601.py
index 93c92fb..cbe9a3b 100644
--- a/python/subunit/iso8601.py
+++ b/python/subunit/iso8601.py
@@ -31,15 +31,25 @@ datetime.datetime(2007, 1, 25, 12, 0, tzinfo=<iso8601.iso8601.Utc ...>)
from datetime import datetime, timedelta, tzinfo
import re
+import sys
__all__ = ["parse_date", "ParseError"]
# Adapted from http://delete.me.uk/2005/03/iso8601.html
-ISO8601_REGEX = re.compile(r"(?P<year>[0-9]{4})(-(?P<month>[0-9]{1,2})(-(?P<day>[0-9]{1,2})"
+ISO8601_REGEX_PATTERN = (r"(?P<year>[0-9]{4})(-(?P<month>[0-9]{1,2})(-(?P<day>[0-9]{1,2})"
r"((?P<separator>.)(?P<hour>[0-9]{2}):(?P<minute>[0-9]{2})(:(?P<second>[0-9]{2})(\.(?P<fraction>[0-9]+))?)?"
r"(?P<timezone>Z|(([-+])([0-9]{2}):([0-9]{2})))?)?)?)?"
)
-TIMEZONE_REGEX = re.compile("(?P<prefix>[+-])(?P<hours>[0-9]{2}).(?P<minutes>[0-9]{2})")
+TIMEZONE_REGEX_PATTERN = "(?P<prefix>[+-])(?P<hours>[0-9]{2}).(?P<minutes>[0-9]{2})"
+ISO8601_REGEX = re.compile(ISO8601_REGEX_PATTERN.encode('utf8'))
+TIMEZONE_REGEX = re.compile(TIMEZONE_REGEX_PATTERN.encode('utf8'))
+
+zulu = "Z".encode('latin-1')
+minus = "-".encode('latin-1')
+
+if sys.version_info < (3, 0):
+ bytes = str
+
class ParseError(Exception):
"""Raised when there is a problem parsing a date string"""
@@ -84,7 +94,7 @@ def parse_timezone(tzstring, default_timezone=UTC):
"""Parses ISO 8601 time zone specs into tzinfo offsets
"""
- if tzstring == "Z":
+ if tzstring == zulu:
return default_timezone
# This isn't strictly correct, but it's common to encounter dates without
# timezones so I'll assume the default (which defaults to UTC).
@@ -94,7 +104,7 @@ def parse_timezone(tzstring, default_timezone=UTC):
m = TIMEZONE_REGEX.match(tzstring)
prefix, hours, minutes = m.groups()
hours, minutes = int(hours), int(minutes)
- if prefix == "-":
+ if prefix == minus:
hours = -hours
minutes = -minutes
return FixedOffset(hours, minutes, tzstring)
@@ -107,8 +117,8 @@ def parse_date(datestring, default_timezone=UTC):
default timezone specified in default_timezone is used. This is UTC by
default.
"""
- if not isinstance(datestring, basestring):
- raise ParseError("Expecting a string %r" % datestring)
+ if not isinstance(datestring, bytes):
+ raise ParseError("Expecting bytes %r" % datestring)
m = ISO8601_REGEX.match(datestring)
if not m:
raise ParseError("Unable to parse date string %r" % datestring)
diff --git a/python/subunit/tests/test_chunked.py b/python/subunit/tests/test_chunked.py
index 6323b02..e0742f1 100644
--- a/python/subunit/tests/test_chunked.py
+++ b/python/subunit/tests/test_chunked.py
@@ -17,7 +17,7 @@
import unittest
-from testtools.compat import StringIO
+from testtools.compat import _b, BytesIO
import subunit.chunked
@@ -32,121 +32,121 @@ class TestDecode(unittest.TestCase):
def setUp(self):
unittest.TestCase.setUp(self)
- self.output = StringIO()
+ self.output = BytesIO()
self.decoder = subunit.chunked.Decoder(self.output)
def test_close_read_length_short_errors(self):
self.assertRaises(ValueError, self.decoder.close)
def test_close_body_short_errors(self):
- self.assertEqual(None, self.decoder.write('2\r\na'))
+ self.assertEqual(None, self.decoder.write(_b('2\r\na')))
self.assertRaises(ValueError, self.decoder.close)
def test_close_body_buffered_data_errors(self):
- self.assertEqual(None, self.decoder.write('2\r'))
+ self.assertEqual(None, self.decoder.write(_b('2\r')))
self.assertRaises(ValueError, self.decoder.close)
def test_close_after_finished_stream_safe(self):
- self.assertEqual(None, self.decoder.write('2\r\nab'))
- self.assertEqual('', self.decoder.write('0\r\n'))
+ self.assertEqual(None, self.decoder.write(_b('2\r\nab')))
+ self.assertEqual(_b(''), self.decoder.write(_b('0\r\n')))
self.decoder.close()
def test_decode_nothing(self):
- self.assertEqual('', self.decoder.write('0\r\n'))
- self.assertEqual('', self.output.getvalue())
+ self.assertEqual(_b(''), self.decoder.write(_b('0\r\n')))
+ self.assertEqual(_b(''), self.output.getvalue())
def test_decode_serialised_form(self):
- self.assertEqual(None, self.decoder.write("F\r\n"))
- self.assertEqual(None, self.decoder.write("serialised\n"))
- self.assertEqual('', self.decoder.write("form0\r\n"))
+ self.assertEqual(None, self.decoder.write(_b("F\r\n")))
+ self.assertEqual(None, self.decoder.write(_b("serialised\n")))
+ self.assertEqual(_b(''), self.decoder.write(_b("form0\r\n")))
def test_decode_short(self):
- self.assertEqual('', self.decoder.write('3\r\nabc0\r\n'))
- self.assertEqual('abc', self.output.getvalue())
+ self.assertEqual(_b(''), self.decoder.write(_b('3\r\nabc0\r\n')))
+ self.assertEqual(_b('abc'), self.output.getvalue())
def test_decode_combines_short(self):
- self.assertEqual('', self.decoder.write('6\r\nabcdef0\r\n'))
- self.assertEqual('abcdef', self.output.getvalue())
+ self.assertEqual(_b(''), self.decoder.write(_b('6\r\nabcdef0\r\n')))
+ self.assertEqual(_b('abcdef'), self.output.getvalue())
def test_decode_excess_bytes_from_write(self):
- self.assertEqual('1234', self.decoder.write('3\r\nabc0\r\n1234'))
- self.assertEqual('abc', self.output.getvalue())
+ self.assertEqual(_b('1234'), self.decoder.write(_b('3\r\nabc0\r\n1234')))
+ self.assertEqual(_b('abc'), self.output.getvalue())
def test_decode_write_after_finished_errors(self):
- self.assertEqual('1234', self.decoder.write('3\r\nabc0\r\n1234'))
- self.assertRaises(ValueError, self.decoder.write, '')
+ self.assertEqual(_b('1234'), self.decoder.write(_b('3\r\nabc0\r\n1234')))
+ self.assertRaises(ValueError, self.decoder.write, _b(''))
def test_decode_hex(self):
- self.assertEqual('', self.decoder.write('A\r\n12345678900\r\n'))
- self.assertEqual('1234567890', self.output.getvalue())
+ self.assertEqual(_b(''), self.decoder.write(_b('A\r\n12345678900\r\n')))
+ self.assertEqual(_b('1234567890'), self.output.getvalue())
def test_decode_long_ranges(self):
- self.assertEqual(None, self.decoder.write('10000\r\n'))
- self.assertEqual(None, self.decoder.write('1' * 65536))
- self.assertEqual(None, self.decoder.write('10000\r\n'))
- self.assertEqual(None, self.decoder.write('2' * 65536))
- self.assertEqual('', self.decoder.write('0\r\n'))
- self.assertEqual('1' * 65536 + '2' * 65536, self.output.getvalue())
+ self.assertEqual(None, self.decoder.write(_b('10000\r\n')))
+ self.assertEqual(None, self.decoder.write(_b('1' * 65536)))
+ self.assertEqual(None, self.decoder.write(_b('10000\r\n')))
+ self.assertEqual(None, self.decoder.write(_b('2' * 65536)))
+ self.assertEqual(_b(''), self.decoder.write(_b('0\r\n')))
+ self.assertEqual(_b('1' * 65536 + '2' * 65536), self.output.getvalue())
def test_decode_newline_nonstrict(self):
"""Tolerate chunk markers with no CR character."""
# From <http://pad.lv/505078>
self.decoder = subunit.chunked.Decoder(self.output, strict=False)
- self.assertEqual(None, self.decoder.write('a\n'))
- self.assertEqual(None, self.decoder.write('abcdeabcde'))
- self.assertEqual('', self.decoder.write('0\n'))
- self.assertEqual('abcdeabcde', self.output.getvalue())
+ self.assertEqual(None, self.decoder.write(_b('a\n')))
+ self.assertEqual(None, self.decoder.write(_b('abcdeabcde')))
+ self.assertEqual(_b(''), self.decoder.write(_b('0\n')))
+ self.assertEqual(_b('abcdeabcde'), self.output.getvalue())
def test_decode_strict_newline_only(self):
"""Reject chunk markers with no CR character in strict mode."""
# From <http://pad.lv/505078>
self.assertRaises(ValueError,
- self.decoder.write, 'a\n')
+ self.decoder.write, _b('a\n'))
def test_decode_strict_multiple_crs(self):
self.assertRaises(ValueError,
- self.decoder.write, 'a\r\r\n')
+ self.decoder.write, _b('a\r\r\n'))
def test_decode_short_header(self):
self.assertRaises(ValueError,
- self.decoder.write, '\n')
+ self.decoder.write, _b('\n'))
class TestEncode(unittest.TestCase):
def setUp(self):
unittest.TestCase.setUp(self)
- self.output = StringIO()
+ self.output = BytesIO()
self.encoder = subunit.chunked.Encoder(self.output)
def test_encode_nothing(self):
self.encoder.close()
- self.assertEqual('0\r\n', self.output.getvalue())
+ self.assertEqual(_b('0\r\n'), self.output.getvalue())
def test_encode_empty(self):
- self.encoder.write('')
+ self.encoder.write(_b(''))
self.encoder.close()
- self.assertEqual('0\r\n', self.output.getvalue())
+ self.assertEqual(_b('0\r\n'), self.output.getvalue())
def test_encode_short(self):
- self.encoder.write('abc')
+ self.encoder.write(_b('abc'))
self.encoder.close()
- self.assertEqual('3\r\nabc0\r\n', self.output.getvalue())
+ self.assertEqual(_b('3\r\nabc0\r\n'), self.output.getvalue())
def test_encode_combines_short(self):
- self.encoder.write('abc')
- self.encoder.write('def')
+ self.encoder.write(_b('abc'))
+ self.encoder.write(_b('def'))
self.encoder.close()
- self.assertEqual('6\r\nabcdef0\r\n', self.output.getvalue())
+ self.assertEqual(_b('6\r\nabcdef0\r\n'), self.output.getvalue())
def test_encode_over_9_is_in_hex(self):
- self.encoder.write('1234567890')
+ self.encoder.write(_b('1234567890'))
self.encoder.close()
- self.assertEqual('A\r\n12345678900\r\n', self.output.getvalue())
+ self.assertEqual(_b('A\r\n12345678900\r\n'), self.output.getvalue())
def test_encode_long_ranges_not_combined(self):
- self.encoder.write('1' * 65536)
- self.encoder.write('2' * 65536)
+ self.encoder.write(_b('1' * 65536))
+ self.encoder.write(_b('2' * 65536))
self.encoder.close()
- self.assertEqual('10000\r\n' + '1' * 65536 + '10000\r\n' +
- '2' * 65536 + '0\r\n', self.output.getvalue())
+ self.assertEqual(_b('10000\r\n' + '1' * 65536 + '10000\r\n' +
+ '2' * 65536 + '0\r\n'), self.output.getvalue())
diff --git a/python/subunit/tests/test_details.py b/python/subunit/tests/test_details.py
index 49010d2..746aa04 100644
--- a/python/subunit/tests/test_details.py
+++ b/python/subunit/tests/test_details.py
@@ -95,18 +95,18 @@ class TestMultipartDetails(unittest.TestCase):
def test_parts(self):
parser = details.MultipartDetailsParser(None)
- parser.lineReceived("Content-Type: text/plain\n")
- parser.lineReceived("something\n")
- parser.lineReceived("F\r\n")
- parser.lineReceived("serialised\n")
- parser.lineReceived("form0\r\n")
+ parser.lineReceived(_b("Content-Type: text/plain\n"))
+ parser.lineReceived(_b("something\n"))
+ parser.lineReceived(_b("F\r\n"))
+ parser.lineReceived(_b("serialised\n"))
+ parser.lineReceived(_b("form0\r\n"))
expected = {}
expected['something'] = content.Content(
content_type.ContentType("text", "plain"),
- lambda:["serialised\nform"])
+ lambda:[_b("serialised\nform")])
found = parser.get_details()
self.assertEqual(expected.keys(), found.keys())
self.assertEqual(expected['something'].content_type,
found['something'].content_type)
- self.assertEqual(''.join(expected['something'].iter_bytes()),
- ''.join(found['something'].iter_bytes()))
+ self.assertEqual(_b('').join(expected['something'].iter_bytes()),
+ _b('').join(found['something'].iter_bytes()))
diff --git a/python/subunit/tests/test_subunit_filter.py b/python/subunit/tests/test_subunit_filter.py
index 682f726..fb6ffcd 100644
--- a/python/subunit/tests/test_subunit_filter.py
+++ b/python/subunit/tests/test_subunit_filter.py
@@ -21,7 +21,7 @@ from subunit import iso8601
import unittest
from testtools import TestCase
-from testtools.compat import StringIO
+from testtools.compat import _b, BytesIO, StringIO
from testtools.testresult.doubles import ExtendedTestResult
import subunit
@@ -35,7 +35,7 @@ class TestTestResultFilter(TestCase):
# is an easy pithy way of getting a series of test objects to call into
# the TestResult, and as TestResultFilter is intended for use with subunit
# also has the benefit of detecting any interface skew issues.
- example_subunit_stream = """\
+ example_subunit_stream = _b("""\
tags: global
test passed
success passed
@@ -50,7 +50,7 @@ test skipped
skip skipped
test todo
xfail todo
-"""
+""")
def run_tests(self, result_filter, input_stream=None):
"""Run tests through the given filter.
@@ -61,7 +61,7 @@ xfail todo
"""
if input_stream is None:
input_stream = self.example_subunit_stream
- test = subunit.ProtocolTestCase(StringIO(input_stream))
+ test = subunit.ProtocolTestCase(BytesIO(input_stream))
test.run(result_filter)
def test_default(self):
@@ -139,13 +139,13 @@ xfail todo
date_a = datetime(year=2000, month=1, day=1, tzinfo=iso8601.UTC)
date_b = datetime(year=2000, month=1, day=2, tzinfo=iso8601.UTC)
date_c = datetime(year=2000, month=1, day=3, tzinfo=iso8601.UTC)
- subunit_stream = '\n'.join([
+ subunit_stream = _b('\n'.join([
"time: %s",
"test: foo",
"time: %s",
"error: foo",
"time: %s",
- ""]) % (date_a, date_b, date_c)
+ ""]) % (date_a, date_b, date_c))
result = ExtendedTestResult()
result_filter = TestResultFilter(result)
self.run_tests(result_filter, subunit_stream)
diff --git a/python/subunit/tests/test_test_protocol.py b/python/subunit/tests/test_test_protocol.py
index 7778fcc..7ec7758 100644
--- a/python/subunit/tests/test_test_protocol.py
+++ b/python/subunit/tests/test_test_protocol.py
@@ -28,7 +28,7 @@ from testtools.tests.helpers import (
)
import subunit
-from subunit import _remote_exception_str
+from subunit import _remote_exception_str, _remote_exception_str_chunked
import subunit.iso8601 as iso8601
@@ -994,11 +994,11 @@ class TestIsolatedTestSuite(unittest.TestCase):
class TestTestProtocolClient(unittest.TestCase):
def setUp(self):
- self.io = StringIO()
+ self.io = BytesIO()
self.protocol = subunit.TestProtocolClient(self.io)
self.test = TestTestProtocolClient("test_start_test")
self.sample_details = {'something':Content(
- ContentType('text', 'plain'), lambda:['serialised\nform'])}
+ ContentType('text', 'plain'), lambda:[_b('serialised\nform')])}
self.sample_tb_details = dict(self.sample_details)
self.sample_tb_details['traceback'] = TracebackContent(
subunit.RemoteError(_u("boo qux")), self.test)
@@ -1006,27 +1006,27 @@ class TestTestProtocolClient(unittest.TestCase):
def test_start_test(self):
"""Test startTest on a TestProtocolClient."""
self.protocol.startTest(self.test)
- self.assertEqual(self.io.getvalue(), "test: %s\n" % self.test.id())
+ self.assertEqual(self.io.getvalue(), _b("test: %s\n" % self.test.id()))
def test_stop_test(self):
# stopTest doesn't output anything.
self.protocol.stopTest(self.test)
- self.assertEqual(self.io.getvalue(), "")
+ self.assertEqual(self.io.getvalue(), _b(""))
def test_add_success(self):
"""Test addSuccess on a TestProtocolClient."""
self.protocol.addSuccess(self.test)
self.assertEqual(
- self.io.getvalue(), "successful: %s\n" % self.test.id())
+ self.io.getvalue(), _b("successful: %s\n" % self.test.id()))
def test_add_success_details(self):
"""Test addSuccess on a TestProtocolClient with details."""
self.protocol.addSuccess(self.test, details=self.sample_details)
self.assertEqual(
- self.io.getvalue(), "successful: %s [ multipart\n"
+ self.io.getvalue(), _b("successful: %s [ multipart\n"
"Content-Type: text/plain\n"
"something\n"
- "F\r\nserialised\nform0\r\n]\n" % self.test.id())
+ "F\r\nserialised\nform0\r\n]\n" % self.test.id()))
def test_add_failure(self):
"""Test addFailure on a TestProtocolClient."""
@@ -1034,8 +1034,8 @@ class TestTestProtocolClient(unittest.TestCase):
self.test, subunit.RemoteError(_u("boo qux")))
self.assertEqual(
self.io.getvalue(),
- ('failure: %s [\n' + _remote_exception_str + ': boo qux\n]\n')
- % self.test.id())
+ _b(('failure: %s [\n' + _remote_exception_str + ': boo qux\n]\n')
+ % self.test.id()))
def test_add_failure_details(self):
"""Test addFailure on a TestProtocolClient with details."""
@@ -1043,14 +1043,13 @@ class TestTestProtocolClient(unittest.TestCase):
self.test, details=self.sample_tb_details)
self.assertEqual(
self.io.getvalue(),
- ("failure: %s [ multipart\n"
+ _b(("failure: %s [ multipart\n"
"Content-Type: text/plain\n"
"something\n"
"F\r\nserialised\nform0\r\n"
"Content-Type: text/x-traceback;charset=utf8,language=python\n"
- "traceback\n"
- "1A\r\n" + _remote_exception_str + ": boo qux\n0\r\n"
- "]\n") % self.test.id())
+ "traceback\n" + _remote_exception_str_chunked + ": boo qux\n0\r\n"
+ "]\n") % self.test.id()))
def test_add_error(self):
"""Test stopTest on a TestProtocolClient."""
@@ -1058,9 +1057,9 @@ class TestTestProtocolClient(unittest.TestCase):
self.test, subunit.RemoteError(_u("phwoar crikey")))
self.assertEqual(
self.io.getvalue(),
- ('error: %s [\n' +
+ _b(('error: %s [\n' +
_remote_exception_str + ": phwoar crikey\n"
- "]\n") % self.test.id())
+ "]\n") % self.test.id()))
def test_add_error_details(self):
"""Test stopTest on a TestProtocolClient with details."""
@@ -1068,14 +1067,13 @@ class TestTestProtocolClient(unittest.TestCase):
self.test, details=self.sample_tb_details)
self.assertEqual(
self.io.getvalue(),
- ("error: %s [ multipart\n"
+ _b(("error: %s [ multipart\n"
"Content-Type: text/plain\n"
"something\n"
"F\r\nserialised\nform0\r\n"
"Content-Type: text/x-traceback;charset=utf8,language=python\n"
- "traceback\n"
- "1A\r\n" + _remote_exception_str + ": boo qux\n0\r\n"
- "]\n") % self.test.id())
+ "traceback\n" + _remote_exception_str_chunked + ": boo qux\n0\r\n"
+ "]\n") % self.test.id()))
def test_add_expected_failure(self):
"""Test addExpectedFailure on a TestProtocolClient."""
@@ -1083,9 +1081,9 @@ class TestTestProtocolClient(unittest.TestCase):
self.test, subunit.RemoteError(_u("phwoar crikey")))
self.assertEqual(
self.io.getvalue(),
- ('xfail: %s [\n' +
+ _b(('xfail: %s [\n' +
_remote_exception_str + ": phwoar crikey\n"
- "]\n") % self.test.id())
+ "]\n") % self.test.id()))
def test_add_expected_failure_details(self):
"""Test addExpectedFailure on a TestProtocolClient with details."""
@@ -1093,14 +1091,14 @@ class TestTestProtocolClient(unittest.TestCase):
self.test, details=self.sample_tb_details)
self.assertEqual(
self.io.getvalue(),
- ("xfail: %s [ multipart\n"
+ _b(("xfail: %s [ multipart\n"
"Content-Type: text/plain\n"
"something\n"
"F\r\nserialised\nform0\r\n"
"Content-Type: text/x-traceback;charset=utf8,language=python\n"
- "traceback\n"
- "1A\r\n"+ _remote_exception_str + ": boo qux\n0\r\n"
- "]\n") % self.test.id())
+ "traceback\n" + _remote_exception_str_chunked + ": boo qux\n0\r\n"
+ "]\n") % self.test.id()))
+
def test_add_skip(self):
"""Test addSkip on a TestProtocolClient."""
@@ -1108,64 +1106,63 @@ class TestTestProtocolClient(unittest.TestCase):
self.test, "Has it really?")
self.assertEqual(
self.io.getvalue(),
- 'skip: %s [\nHas it really?\n]\n' % self.test.id())
+ _b('skip: %s [\nHas it really?\n]\n' % self.test.id()))
def test_add_skip_details(self):
"""Test addSkip on a TestProtocolClient with details."""
details = {'reason':Content(
- ContentType('text', 'plain'), lambda:['Has it really?'])}
- self.protocol.addSkip(
- self.test, details=details)
+ ContentType('text', 'plain'), lambda:[_b('Has it really?')])}
+ self.protocol.addSkip(self.test, details=details)
self.assertEqual(
self.io.getvalue(),
- "skip: %s [ multipart\n"
+ _b("skip: %s [ multipart\n"
"Content-Type: text/plain\n"
"reason\n"
"E\r\nHas it really?0\r\n"
- "]\n" % self.test.id())
+ "]\n" % self.test.id()))
def test_progress_set(self):
self.protocol.progress(23, subunit.PROGRESS_SET)
- self.assertEqual(self.io.getvalue(), 'progress: 23\n')
+ self.assertEqual(self.io.getvalue(), _b('progress: 23\n'))
def test_progress_neg_cur(self):
self.protocol.progress(-23, subunit.PROGRESS_CUR)
- self.assertEqual(self.io.getvalue(), 'progress: -23\n')
+ self.assertEqual(self.io.getvalue(), _b('progress: -23\n'))
def test_progress_pos_cur(self):
self.protocol.progress(23, subunit.PROGRESS_CUR)
- self.assertEqual(self.io.getvalue(), 'progress: +23\n')
+ self.assertEqual(self.io.getvalue(), _b('progress: +23\n'))
def test_progress_pop(self):
self.protocol.progress(1234, subunit.PROGRESS_POP)
- self.assertEqual(self.io.getvalue(), 'progress: pop\n')
+ self.assertEqual(self.io.getvalue(), _b('progress: pop\n'))
def test_progress_push(self):
self.protocol.progress(1234, subunit.PROGRESS_PUSH)
- self.assertEqual(self.io.getvalue(), 'progress: push\n')
+ self.assertEqual(self.io.getvalue(), _b('progress: push\n'))
def test_time(self):
# Calling time() outputs a time signal immediately.
self.protocol.time(
datetime.datetime(2009,10,11,12,13,14,15, iso8601.Utc()))
self.assertEqual(
- "time: 2009-10-11 12:13:14.000015Z\n",
+ _b("time: 2009-10-11 12:13:14.000015Z\n"),
self.io.getvalue())
def test_add_unexpected_success(self):
"""Test addUnexpectedSuccess on a TestProtocolClient."""
self.protocol.addUnexpectedSuccess(self.test)
self.assertEqual(
- self.io.getvalue(), "successful: %s\n" % self.test.id())
+ self.io.getvalue(), _b("successful: %s\n" % self.test.id()))
def test_add_unexpected_success_details(self):
"""Test addUnexpectedSuccess on a TestProtocolClient with details."""
self.protocol.addUnexpectedSuccess(self.test, details=self.sample_details)
self.assertEqual(
- self.io.getvalue(), "successful: %s [ multipart\n"
+ self.io.getvalue(), _b("successful: %s [ multipart\n"
"Content-Type: text/plain\n"
"something\n"
- "F\r\nserialised\nform0\r\n]\n" % self.test.id())
+ "F\r\nserialised\nform0\r\n]\n" % self.test.id()))
def test_suite():