diff options
Diffstat (limited to 'kafka/protocol/message.py')
-rw-r--r-- | kafka/protocol/message.py | 144 |
1 files changed, 144 insertions, 0 deletions
diff --git a/kafka/protocol/message.py b/kafka/protocol/message.py new file mode 100644 index 0000000..2648e24 --- /dev/null +++ b/kafka/protocol/message.py @@ -0,0 +1,144 @@ +import io + +from ..codec import gzip_decode, snappy_decode +from . import pickle +from .struct import Struct +from .types import ( + Int8, Int32, Int64, Bytes, Schema, AbstractType +) +from ..util import crc32 + + +class Message(Struct): + SCHEMA = Schema( + ('crc', Int32), + ('magic', Int8), + ('attributes', Int8), + ('key', Bytes), + ('value', Bytes) + ) + CODEC_MASK = 0x03 + CODEC_GZIP = 0x01 + CODEC_SNAPPY = 0x02 + + def __init__(self, value, key=None, magic=0, attributes=0, crc=0): + assert value is None or isinstance(value, bytes), 'value must be bytes' + assert key is None or isinstance(key, bytes), 'key must be bytes' + self.crc = crc + self.magic = magic + self.attributes = attributes + self.key = key + self.value = value + self.encode = self._encode_self + + def _encode_self(self, recalc_crc=True): + message = Message.SCHEMA.encode( + (self.crc, self.magic, self.attributes, self.key, self.value) + ) + if not recalc_crc: + return message + self.crc = crc32(message[4:]) + return self.SCHEMA.fields[0].encode(self.crc) + message[4:] + + @classmethod + def decode(cls, data): + if isinstance(data, bytes): + data = io.BytesIO(data) + fields = [field.decode(data) for field in cls.SCHEMA.fields] + return cls(fields[4], key=fields[3], + magic=fields[1], attributes=fields[2], crc=fields[0]) + + def validate_crc(self): + raw_msg = self._encode_self(recalc_crc=False) + crc = crc32(raw_msg[4:]) + if crc == self.crc: + return True + return False + + def is_compressed(self): + return self.attributes & self.CODEC_MASK != 0 + + def decompress(self): + codec = self.attributes & self.CODEC_MASK + assert codec in (self.CODEC_GZIP, self.CODEC_SNAPPY) + if codec == self.CODEC_GZIP: + raw_bytes = gzip_decode(self.value) + else: + raw_bytes = snappy_decode(self.value) + + return MessageSet.decode(raw_bytes, bytes_to_read=len(raw_bytes)) + + def __hash__(self): + return hash(self._encode_self(recalc_crc=False)) + + +class PartialMessage(bytes): + def __repr__(self): + return 'PartialMessage(%s)' % self + + +class MessageSet(AbstractType): + ITEM = Schema( + ('offset', Int64), + ('message_size', Int32), + ('message', Message.SCHEMA) + ) + + @classmethod + def encode(cls, items, size=True, recalc_message_size=True): + encoded_values = [] + for (offset, message_size, message) in items: + if isinstance(message, Message): + encoded_message = message.encode() + else: + encoded_message = cls.ITEM.fields[2].encode(message) + if recalc_message_size: + message_size = len(encoded_message) + encoded_values.append(cls.ITEM.fields[0].encode(offset)) + encoded_values.append(cls.ITEM.fields[1].encode(message_size)) + encoded_values.append(encoded_message) + encoded = b''.join(encoded_values) + if not size: + return encoded + return Int32.encode(len(encoded)) + encoded + + @classmethod + def decode(cls, data, bytes_to_read=None): + """Compressed messages should pass in bytes_to_read (via message size) + otherwise, we decode from data as Int32 + """ + if isinstance(data, bytes): + data = io.BytesIO(data) + if bytes_to_read is None: + bytes_to_read = Int32.decode(data) + items = [] + + # We need at least 8 + 4 + 14 bytes to read offset + message size + message + # (14 bytes is a message w/ null key and null value) + while bytes_to_read >= 26: + offset = Int64.decode(data) + bytes_to_read -= 8 + + message_size = Int32.decode(data) + bytes_to_read -= 4 + + # if FetchRequest max_bytes is smaller than the available message set + # the server returns partial data for the final message + if message_size > bytes_to_read: + break + + message = Message.decode(data) + bytes_to_read -= message_size + + items.append((offset, message_size, message)) + + # If any bytes are left over, clear them from the buffer + # and append a PartialMessage to signal that max_bytes may be too small + if bytes_to_read: + items.append((None, None, PartialMessage(data.read(bytes_to_read)))) + + return items + + @classmethod + def repr(cls, messages): + return '[' + ', '.join([cls.ITEM.repr(m) for m in messages]) + ']' |