summaryrefslogtreecommitdiff
path: root/python/subunit
diff options
context:
space:
mode:
Diffstat (limited to 'python/subunit')
-rw-r--r--python/subunit/v2.py43
1 files changed, 26 insertions, 17 deletions
diff --git a/python/subunit/v2.py b/python/subunit/v2.py
index c2c63f6..e8a31d6 100644
--- a/python/subunit/v2.py
+++ b/python/subunit/v2.py
@@ -72,13 +72,32 @@ def has_nul(buffer_or_bytes):
return NUL_ELEMENT in buffer_or_bytes
+def read_exactly(stream, size):
+ """Read exactly size bytes from stream.
+
+ :param stream: A file like object to read bytes from. Must support
+ read(<count>) and return bytes.
+ :param size: The number of bytes to retrieve.
+ """
+ data = b''
+ remaining = size
+ while remaining:
+ read = stream.read(remaining)
+ if len(read) == 0:
+ raise ParseError('Short read - got %d bytes, wanted %d bytes' % (
+ len(data), size))
+ data += read
+ remaining -= len(read)
+ return data
+
+
class ParseError(Exception):
"""Used to pass error messages within the parser."""
class StreamResultToBytes(object):
"""Convert StreamResult API calls to bytes.
-
+
The StreamResult API is defined by testtools.StreamResult.
"""
@@ -276,7 +295,7 @@ class ByteStreamToStreamResult(object):
def run(self, result):
"""Parse source and emit events to result.
-
+
This is a blocking call: it will run until EOF is detected on source.
"""
self.codec.reset()
@@ -406,21 +425,12 @@ class ByteStreamToStreamResult(object):
def _parse(self, packet, result):
# 2 bytes flags, at most 3 bytes length.
- packet.append(self.source.read(5))
- if len(packet[-1]) != 5:
- raise ParseError(
- 'Short read - got %d bytes, wanted 5' % len(packet[-1]))
-
- flag_bytes = packet[-1][:2]
- flags = struct.unpack(FMT_16, flag_bytes)[0]
- length, consumed = self._parse_varint(
- packet[-1], 2, max_3_bytes=True)
- remainder = self.source.read(length - 6)
- if len(remainder) != length - 6:
- raise ParseError(
- 'Short read - got %d bytes, wanted %d bytes' % (
- len(remainder), length - 6))
+ header = read_exactly(self.source, 5)
+ packet.append(header)
+ flags = struct.unpack(FMT_16, header[:2])[0]
+ length, consumed = self._parse_varint(header, 2, max_3_bytes=True)
+ remainder = read_exactly(self.source, length - 6)
if consumed != 3:
# Avoid having to parse torn values
packet[-1] += remainder
@@ -533,4 +543,3 @@ class ByteStreamToStreamResult(object):
return utf8, length+pos
except UnicodeDecodeError:
raise ParseError('UTF8 string at offset %d is not UTF8' % (pos-2,))
-