summaryrefslogtreecommitdiff
path: root/kafka/protocol/message.py
blob: cd5d2743926863f6ee924203101da4a49164db70 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import io

from .struct import Struct
from .types import (
    Int8, Int32, Int64, Bytes, Schema, AbstractType
)
from ..util import crc32


class Message(Struct):
    """A single Kafka message.

    Wire layout follows ``SCHEMA``: crc, magic, attributes, key, value.
    The crc field (Int32) covers every byte that follows it.
    """
    SCHEMA = Schema(
        ('crc', Int32),
        ('magic', Int8),
        ('attributes', Int8),
        ('key', Bytes),
        ('value', Bytes)
    )

    def __init__(self, value, key=None, magic=0, attributes=0, crc=0):
        self.crc = crc
        self.magic = magic
        self.attributes = attributes
        self.key = key
        self.value = value
        # Bind ``encode`` on the instance so ``message.encode()`` serializes
        # this instance via _encode_self; an instance attribute takes
        # precedence over any class-level ``encode`` (e.g. one inherited from
        # Struct -- not visible in this file).
        self.encode = self._encode_self

    def _encode_self(self, recalc_crc=True):
        """Serialize this message to bytes.

        Arguments:
            recalc_crc (bool): if True (default), recompute the crc over the
                bytes following the crc field, store it on ``self.crc`` and
                emit it. If False, emit the current ``self.crc`` unchanged.

        Returns:
            bytes: the encoded message.
        """
        # Use self.SCHEMA for both the body and the crc field encoder below
        # so the two always come from the same schema.
        message = self.SCHEMA.encode(
            (self.crc, self.magic, self.attributes, self.key, self.value)
        )
        if not recalc_crc:
            return message
        # Skip the 4-byte Int32 crc field; the checksum covers the rest.
        payload = message[4:]
        self.crc = crc32(payload)
        return self.SCHEMA.fields[0].encode(self.crc) + payload

    @classmethod
    def decode(cls, data):
        """Decode one message.

        Arguments:
            data: raw bytes, or a stream accepted by the schema field
                decoders (e.g. ``io.BytesIO``) positioned at the crc field.

        Returns:
            Message: a new instance. The decoded crc is stored as-is and is
            not verified against the payload.
        """
        if isinstance(data, bytes):
            data = io.BytesIO(data)
        fields = [field.decode(data) for field in cls.SCHEMA.fields]
        # SCHEMA order: crc, magic, attributes, key, value.
        return cls(fields[4], key=fields[3],
                   magic=fields[1], attributes=fields[2], crc=fields[0])


class PartialMessage(bytes):
    """Leftover bytes at the end of a message set that did not form a
    complete message (produced by MessageSet.decode)."""

    def __repr__(self):
        # Format an explicit plain-bytes copy with %r in a 1-tuple:
        # '%s' % <bytes> goes through str() on bytes, which warns (or raises
        # under ``python -bb``) with BytesWarning, and repr(self) would
        # recurse back into this method.
        return 'PartialMessage(%r)' % (bytes(self),)


class MessageSet(AbstractType):
    """A sequence of (offset, message_size, message) items, optionally
    prefixed with its total Int32 byte length."""
    ITEM = Schema(
        ('offset', Int64),
        ('message_size', Int32),
        ('message', Message.SCHEMA)
    )

    @classmethod
    def encode(cls, items, size=True, recalc_message_size=True):
        """Encode ``items`` as a message set.

        Arguments:
            items: iterable of (offset, message_size, message) tuples, where
                message is a Message instance (encoded via its own
                ``encode()``, which recomputes and stores its crc) or a value
                accepted by the ITEM schema's message field encoder.
            size (bool): if True (default), prefix the result with its total
                length as an Int32.
            recalc_message_size (bool): if True (default), replace each item's
                message_size with the length of its encoded message.

        Returns:
            bytes: the encoded message set.
        """
        offset_field, size_field, message_field = cls.ITEM.fields
        encoded_values = []
        for (offset, message_size, message) in items:
            if isinstance(message, Message):
                encoded_message = message.encode()
            else:
                encoded_message = message_field.encode(message)
            if recalc_message_size:
                message_size = len(encoded_message)
            encoded_values.append(offset_field.encode(offset))
            encoded_values.append(size_field.encode(message_size))
            encoded_values.append(encoded_message)
        encoded = b''.join(encoded_values)
        if not size:
            return encoded
        return Int32.encode(len(encoded)) + encoded

    @classmethod
    def decode(cls, data):
        """Decode a size-prefixed message set.

        Arguments:
            data: raw bytes, or a stream accepted by the field decoders
                (e.g. ``io.BytesIO``) positioned at the Int32 size prefix.

        Returns:
            list of (offset, message_size, Message) tuples. If the set ends
            in an incomplete message, a final (None, None, PartialMessage)
            item holds the remaining bytes of the set (any offset/size header
            of that partial message has already been consumed).
        """
        # Accept raw bytes, consistent with Message.decode.
        if isinstance(data, bytes):
            data = io.BytesIO(data)
        bytes_to_read = Int32.decode(data)
        items = []

        # We need at least 8 + 4 + 14 bytes to read offset + message size + message
        # (14 bytes is a message w/ null key and null value)
        while bytes_to_read >= 26:
            offset = Int64.decode(data)
            bytes_to_read -= 8

            message_size = Int32.decode(data)
            bytes_to_read -= 4

            # if FetchRequest max_bytes is smaller than the available message set
            # the server returns partial data for the final message
            if message_size > bytes_to_read:
                break

            message = Message.decode(data)
            bytes_to_read -= message_size

            items.append((offset, message_size, message))

        # If any bytes are left over, clear them from the buffer
        # and append a PartialMessage to signal that max_bytes may be too small.
        # Guard with > 0: a negative count (corrupt size prefix) would make
        # read() consume the rest of the stream.
        if bytes_to_read > 0:
            items.append((None, None, PartialMessage(data.read(bytes_to_read))))

        return items

    @classmethod
    def repr(cls, messages):
        """Return a bracketed, comma-separated repr of decoded items, each
        rendered by the ITEM schema."""
        return '[' + ', '.join([cls.ITEM.repr(m) for m in messages]) + ']'