 kafka/consumer/fetcher.py | 34 +++++++++++++++-------------------
 kafka/protocol/message.py | 27 +++++++++++++++++++++++++-
 2 files changed, 40 insertions(+), 21 deletions(-)
diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py
index 8a48575..fc03d7a 100644
--- a/kafka/consumer/fetcher.py
+++ b/kafka/consumer/fetcher.py
@@ -299,15 +299,7 @@ class Fetcher(object):
                               " and update consumed position to %s", tp, next_offset)
                     self._subscriptions.assignment[tp].consumed = next_offset
 
-                    # TODO: handle compressed messages
-                    for offset, size, msg in messages:
-                        if msg.attributes:
-                            raise Errors.KafkaError('Compressed messages not supported yet')
-                        elif self.config['check_crcs'] and not msg.validate_crc():
-                            raise Errors.InvalidMessageError(msg)
-
-                        key, value = self._deserialize(msg)
-                        record = ConsumerRecord(tp.topic, tp.partition, offset, key, value)
+                    for record in self._unpack_message_set(tp, messages):
                         drained[tp].append(record)
                 else:
                     # these records aren't next in line based on the last consumed
@@ -316,6 +308,17 @@
                               tp, fetch_offset)
         return dict(drained)
 
+    def _unpack_message_set(self, tp, messages):
+        for offset, size, msg in messages:
+            if self.config['check_crcs'] and not msg.validate_crc():
+                raise Errors.InvalidMessageError(msg)
+            elif msg.is_compressed():
+                for record in self._unpack_message_set(tp, msg.decompress()):
+                    yield record
+            else:
+                key, value = self._deserialize(msg)
+                yield ConsumerRecord(tp.topic, tp.partition, offset, key, value)
+
     def __iter__(self):
         """Iterate over fetched_records"""
         if self._subscriptions.needs_partition_assignment:
@@ -349,16 +352,9 @@
                     self._subscriptions.assignment[tp].fetched = consumed
 
                 elif fetch_offset == consumed:
-                    # TODO: handle compressed messages
-                    for offset, size, msg in messages:
-                        if msg.attributes:
-                            raise Errors.KafkaError('Compressed messages not supported yet')
-                        elif self.config['check_crcs'] and not msg.validate_crc():
-                            raise Errors.InvalidMessageError(msg)
-
-                        self._subscriptions.assignment[tp].consumed = offset + 1
-                        key, value = self._deserialize(msg)
-                        yield ConsumerRecord(tp.topic, tp.partition, offset, key, value)
+                    for msg in self._unpack_message_set(tp, messages):
+                        self._subscriptions.assignment[tp].consumed = msg.offset + 1
+                        yield msg
                 else:
                     # these records aren't next in line based on the last consumed
                     # position, ignore them they must be from an obsolete request
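
The fetcher change replaces the two duplicated TODO loops with one recursive generator: a compressed message's value is itself a serialized message set, so _unpack_message_set recurses into msg.decompress() until it reaches plain messages, yielding records in order. A minimal, runnable sketch of the same pattern follows; the Msg/Record namedtuples and the decompress stub are hypothetical stand-ins for illustration, not kafka-python's actual types:

    from collections import namedtuple

    # Hypothetical stand-ins for kafka-python's Message and ConsumerRecord,
    # with just enough structure to show the recursion.
    Msg = namedtuple('Msg', ['attributes', 'value'])
    Record = namedtuple('Record', ['offset', 'value'])

    CODEC_MASK = 0x03  # the low two attribute bits select the codec

    def decompress(msg):
        # Stub: the real Message.decompress() gzip/snappy-decodes msg.value
        # back into a list of (offset, size, message) tuples.
        return msg.value

    def unpack_message_set(messages):
        """Yield leaf records depth-first, recursing into compressed wrappers."""
        for offset, size, msg in messages:
            if msg.attributes & CODEC_MASK:
                for record in unpack_message_set(decompress(msg)):
                    yield record
            else:
                yield Record(offset, msg.value)

    inner = [(0, 20, Msg(0, b'a')), (1, 20, Msg(0, b'b'))]
    outer = [(1, 64, Msg(0x01, inner)),   # "compressed" wrapper holding two messages
             (2, 20, Msg(0, b'c'))]
    assert list(unpack_message_set(outer)) == [
        Record(0, b'a'), Record(1, b'b'), Record(2, b'c')]

Because unpacking is a generator, __iter__ above can advance the consumed position one record at a time as it yields, rather than once per fetch batch.
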
diff --git a/kafka/protocol/message.py b/kafka/protocol/message.py
index f6cbb33..f893912 100644
--- a/kafka/protocol/message.py
+++ b/kafka/protocol/message.py
@@ -1,5 +1,6 @@
 import io
 
+from ..codec import gzip_decode, snappy_decode
 from . import pickle
 from .struct import Struct
 from .types import (
@@ -16,6 +17,9 @@ class Message(Struct):
         ('key', Bytes),
         ('value', Bytes)
     )
+    CODEC_MASK = 0x03
+    CODEC_GZIP = 0x01
+    CODEC_SNAPPY = 0x02
 
     def __init__(self, value, key=None, magic=0, attributes=0, crc=0):
         self.crc = crc
@@ -49,6 +53,19 @@ class Message(Struct):
             return True
         return False
 
+    def is_compressed(self):
+        return self.attributes & self.CODEC_MASK != 0
+
+    def decompress(self):
+        codec = self.attributes & self.CODEC_MASK
+        assert codec in (self.CODEC_GZIP, self.CODEC_SNAPPY)
+        if codec == self.CODEC_GZIP:
+            raw_bytes = gzip_decode(self.value)
+        else:
+            raw_bytes = snappy_decode(self.value)
+
+        return MessageSet.decode(raw_bytes, bytes_to_read=len(raw_bytes))
+
 
 class PartialMessage(bytes):
     def __repr__(self):
@@ -81,8 +98,14 @@
         return Int32.encode(len(encoded)) + encoded
 
     @classmethod
-    def decode(cls, data):
-        bytes_to_read = Int32.decode(data)
+    def decode(cls, data, bytes_to_read=None):
+        """Compressed messages should pass in bytes_to_read (via message size)
+        otherwise, we decode from data as Int32
+        """
+        if isinstance(data, bytes):
+            data = io.BytesIO(data)
+        if bytes_to_read is None:
+            bytes_to_read = Int32.decode(data)
         items = []
 
         # We need at least 8 + 4 + 14 bytes to read offset + message size + message
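
On the protocol side, the low two attribute bits select the codec, and the decompressed payload is decoded with an explicit bytes_to_read because an inner message set carries no leading Int32 size prefix, unlike a fetch response body. A rough stdlib-only sketch of the gzip path follows; kafka-python's real helpers are gzip_decode/snappy_decode in kafka.codec, and the snappy branch (which needs the third-party python-snappy package) is stubbed out here:

    import gzip

    CODEC_MASK = 0x03
    CODEC_GZIP = 0x01
    CODEC_SNAPPY = 0x02

    def decompress(attributes, value):
        # Mirrors the dispatch in Message.decompress().
        codec = attributes & CODEC_MASK
        assert codec in (CODEC_GZIP, CODEC_SNAPPY)
        if codec == CODEC_GZIP:
            return gzip.decompress(value)  # stand-in for kafka.codec.gzip_decode
        raise NotImplementedError('snappy requires the python-snappy package')

    payload = gzip.compress(b'serialized inner message set')
    raw_bytes = decompress(CODEC_GZIP, payload)
    assert raw_bytes == b'serialized inner message set'
    # The caller then decodes the inner set with an explicit size:
    # MessageSet.decode(raw_bytes, bytes_to_read=len(raw_bytes))
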