Diffstat (limited to 'kafka/protocol/message.py')
-rw-r--r--  kafka/protocol/message.py  144
1 file changed, 144 insertions(+), 0 deletions(-)
diff --git a/kafka/protocol/message.py b/kafka/protocol/message.py
new file mode 100644
index 0000000..2648e24
--- /dev/null
+++ b/kafka/protocol/message.py
@@ -0,0 +1,144 @@
+import io
+
+from ..codec import gzip_decode, snappy_decode
+from . import pickle  # imported for its side effect: registers pickling support for bound methods
+from .struct import Struct
+from .types import (
+    Int8, Int32, Int64, Bytes, Schema, AbstractType
+)
+from ..util import crc32
+
+
+class Message(Struct):
+    SCHEMA = Schema(
+        ('crc', Int32),
+        ('magic', Int8),
+        ('attributes', Int8),
+        ('key', Bytes),
+        ('value', Bytes)
+    )
+    CODEC_MASK = 0x03
+    CODEC_GZIP = 0x01
+    CODEC_SNAPPY = 0x02
+
+    def __init__(self, value, key=None, magic=0, attributes=0, crc=0):
+        assert value is None or isinstance(value, bytes), 'value must be bytes'
+        assert key is None or isinstance(key, bytes), 'key must be bytes'
+        self.crc = crc
+        self.magic = magic
+        self.attributes = attributes
+        self.key = key
+        self.value = value
+        self.encode = self._encode_self  # shadow the class-level Struct.encode with a bound encoder
+
+    def _encode_self(self, recalc_crc=True):
+        message = Message.SCHEMA.encode(
+            (self.crc, self.magic, self.attributes, self.key, self.value)
+        )
+        if not recalc_crc:
+            return message
+        self.crc = crc32(message[4:])  # crc covers everything after the 4-byte crc field itself
+        return self.SCHEMA.fields[0].encode(self.crc) + message[4:]
+
+    @classmethod
+    def decode(cls, data):
+        if isinstance(data, bytes):
+            data = io.BytesIO(data)
+        fields = [field.decode(data) for field in cls.SCHEMA.fields]
+        return cls(fields[4], key=fields[3],
+                   magic=fields[1], attributes=fields[2], crc=fields[0])
+
+    def validate_crc(self):
+        raw_msg = self._encode_self(recalc_crc=False)
+        crc = crc32(raw_msg[4:])
+        if crc == self.crc:
+            return True
+        return False
+
+    def is_compressed(self):
+        return (self.attributes & self.CODEC_MASK) != 0
+
+    def decompress(self):
+        codec = self.attributes & self.CODEC_MASK
+        assert codec in (self.CODEC_GZIP, self.CODEC_SNAPPY)
+        if codec == self.CODEC_GZIP:
+            raw_bytes = gzip_decode(self.value)
+        else:
+            raw_bytes = snappy_decode(self.value)
+
+        return MessageSet.decode(raw_bytes, bytes_to_read=len(raw_bytes))
+
+    def __hash__(self):
+        return hash(self._encode_self(recalc_crc=False))
+
+
+class PartialMessage(bytes):
+    def __repr__(self):
+        return 'PartialMessage(%s)' % self
+
+
+class MessageSet(AbstractType):
+    ITEM = Schema(
+        ('offset', Int64),
+        ('message_size', Int32),
+        ('message', Message.SCHEMA)
+    )
+
+    @classmethod
+    def encode(cls, items, size=True, recalc_message_size=True):
+        encoded_values = []
+        for (offset, message_size, message) in items:
+            if isinstance(message, Message):
+                encoded_message = message.encode()
+            else:
+                encoded_message = cls.ITEM.fields[2].encode(message)
+            if recalc_message_size:
+                message_size = len(encoded_message)
+            encoded_values.append(cls.ITEM.fields[0].encode(offset))
+            encoded_values.append(cls.ITEM.fields[1].encode(message_size))
+            encoded_values.append(encoded_message)
+        encoded = b''.join(encoded_values)
+        if not size:
+            return encoded
+        return Int32.encode(len(encoded)) + encoded
+
+    @classmethod
+    def decode(cls, data, bytes_to_read=None):
+        """Compressed messages should pass their decompressed payload size
+        as bytes_to_read; otherwise the message set size is read from
+        data as an Int32 prefix."""
+        if isinstance(data, bytes):
+            data = io.BytesIO(data)
+        if bytes_to_read is None:
+            bytes_to_read = Int32.decode(data)
+        items = []
+
+        # We need at least 8 + 4 + 14 bytes to read offset + message size + message
+        # (14 bytes is a message with a null key and a null value)
+        while bytes_to_read >= 26:
+            offset = Int64.decode(data)
+            bytes_to_read -= 8
+
+            message_size = Int32.decode(data)
+            bytes_to_read -= 4
+
+            # if FetchRequest max_bytes is smaller than the available message set
+            # the server returns partial data for the final message
+            if message_size > bytes_to_read:
+                break
+
+            message = Message.decode(data)
+            bytes_to_read -= message_size
+
+            items.append((offset, message_size, message))
+
+        # If any bytes are left over, clear them from the buffer
+        # and append a PartialMessage to signal that max_bytes may be too small
+        if bytes_to_read:
+            items.append((None, None, PartialMessage(data.read(bytes_to_read))))
+
+        return items
+
+    @classmethod
+    def repr(cls, messages):
+        return '[' + ', '.join([cls.ITEM.repr(m) for m in messages]) + ']'
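
For orientation, a minimal usage sketch of the classes added by this diff. It assumes the module is importable as kafka.protocol.message and that kafka.codec exposes gzip_encode alongside the gzip_decode imported above; treat it as an illustration of the encode/decode round trip under those assumptions, not as part of the change itself.

# Round-trip sketch (assumptions: package importable as kafka.protocol.message,
# and kafka.codec provides gzip_encode as the counterpart of gzip_decode).
from kafka.protocol.message import Message, MessageSet
from kafka.codec import gzip_encode

# Plain message: the crc is recomputed by the bound encode() method.
msg = Message(b'payload', key=b'key')
decoded = Message.decode(msg.encode())
assert decoded.validate_crc()
assert (decoded.key, decoded.value) == (b'key', b'payload')

# Message set: items are (offset, message_size, message) tuples;
# message_size is recalculated during encode by default.
raw = MessageSet.encode([(0, 0, msg), (1, 0, Message(b'more'))])
for offset, size, message in MessageSet.decode(raw):
    print(offset, size, message.value)

# Compressed wrapper: the value is a gzip'd inner message set encoded
# WITHOUT the Int32 size prefix, which is what decompress() expects.
inner = MessageSet.encode([(0, 0, Message(b'inner'))], size=False)
wrapper = Message(gzip_encode(inner), attributes=Message.CODEC_GZIP)
assert wrapper.is_compressed()
for offset, size, inner_msg in wrapper.decompress():
    print(offset, inner_msg.value)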