summaryrefslogtreecommitdiff
path: root/kafka/protocol
diff options
context:
space:
mode:
authorDana Powers <dana.powers@gmail.com>2017-08-15 13:00:02 -0700
committerGitHub <noreply@github.com>2017-08-15 13:00:02 -0700
commitba7afd9bc9362055ec0bedcf53eb6f8909dc22d2 (patch)
treef68b4dc2653df1e379da7b497e0fa76a19d6c5a5 /kafka/protocol
parentcbc6fdc4b973a6a94953c9ce9c33e54e415e45bf (diff)
downloadkafka-python-ba7afd9bc9362055ec0bedcf53eb6f8909dc22d2.tar.gz
BrokerConnection receive bytes pipe (#1032)
Diffstat (limited to 'kafka/protocol')
-rw-r--r--kafka/protocol/frame.py30
-rw-r--r--kafka/protocol/message.py7
2 files changed, 34 insertions, 3 deletions
diff --git a/kafka/protocol/frame.py b/kafka/protocol/frame.py
new file mode 100644
index 0000000..7b4a32b
--- /dev/null
+++ b/kafka/protocol/frame.py
@@ -0,0 +1,30 @@
+class KafkaBytes(bytearray):
+ def __init__(self, size):
+ super(KafkaBytes, self).__init__(size)
+ self._idx = 0
+
+ def read(self, nbytes=None):
+ if nbytes is None:
+ nbytes = len(self) - self._idx
+ start = self._idx
+ self._idx += nbytes
+ if self._idx > len(self):
+ self._idx = len(self)
+ return bytes(self[start:self._idx])
+
+ def write(self, data):
+ start = self._idx
+ self._idx += len(data)
+ self[start:self._idx] = data
+
+ def seek(self, idx):
+ self._idx = idx
+
+ def tell(self):
+ return self._idx
+
+ def __str__(self):
+ return 'KafkaBytes(%d)' % len(self)
+
+ def __repr__(self):
+ return str(self)
diff --git a/kafka/protocol/message.py b/kafka/protocol/message.py
index efdf4fc..70d5b36 100644
--- a/kafka/protocol/message.py
+++ b/kafka/protocol/message.py
@@ -6,6 +6,7 @@ import time
from ..codec import (has_gzip, has_snappy, has_lz4,
gzip_decode, snappy_decode,
lz4_decode, lz4_decode_old_kafka)
+from .frame import KafkaBytes
from .struct import Struct
from .types import (
Int8, Int32, Int64, Bytes, Schema, AbstractType
@@ -155,10 +156,10 @@ class MessageSet(AbstractType):
@classmethod
def encode(cls, items):
# RecordAccumulator encodes messagesets internally
- if isinstance(items, io.BytesIO):
+ if isinstance(items, (io.BytesIO, KafkaBytes)):
size = Int32.decode(items)
# rewind and return all the bytes
- items.seek(-4, 1)
+ items.seek(items.tell() - 4)
return items.read(size + 4)
encoded_values = []
@@ -198,7 +199,7 @@ class MessageSet(AbstractType):
@classmethod
def repr(cls, messages):
- if isinstance(messages, io.BytesIO):
+ if isinstance(messages, (KafkaBytes, io.BytesIO)):
offset = messages.tell()
decoded = cls.decode(messages)
messages.seek(offset)