author     Dana Powers <dana.powers@rd.io>  2015-11-28 19:41:06 +0800
committer  Zack Dever <zack.dever@rd.io>    2015-12-04 11:25:39 -0800
commit     a85e09df89a43de5b659a0fa4ed35bec37c60e04 (patch)
tree       a539af32fe502006c1f35b96d8ae36225292f7a5
parent     e24a4d5f5252d6f97ac586e328b95779ef83f4b6 (diff)
download   kafka-python-a85e09df89a43de5b659a0fa4ed35bec37c60e04.tar.gz
Rework protocol type definition: AbstractType, Schema, Struct
-rw-r--r--  kafka/protocol/abstract.py  |  13
-rw-r--r--  kafka/protocol/api.py       | 309
-rw-r--r--  kafka/protocol/commit.py    | 111
-rw-r--r--  kafka/protocol/fetch.py     |  30
-rw-r--r--  kafka/protocol/message.py   |  67
-rw-r--r--  kafka/protocol/metadata.py  |  28
-rw-r--r--  kafka/protocol/offset.py    |  32
-rw-r--r--  kafka/protocol/produce.py   |  81
-rw-r--r--  kafka/protocol/struct.py    |  52
-rw-r--r--  kafka/protocol/types.py     | 109
10 files changed, 461 insertions(+), 371 deletions(-)
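
This commit replaces the per-request encode() methods with declarative Schema definitions on a common Struct base. A minimal sketch of the new style, using a hypothetical EchoRequest (API_KEY 99 is not a real Kafka API) and assuming the kafka.protocol modules from this commit are importable:

from kafka.protocol.struct import Struct
from kafka.protocol.types import Int32, Schema, String

class EchoRequest(Struct):
    API_KEY = 99       # hypothetical key, for illustration only
    API_VERSION = 0
    SCHEMA = Schema(
        ('correlation_id', Int32),
        ('message', String('utf-8'))
    )

req = EchoRequest(correlation_id=5, message='hello')
data = req.encode()                  # instance encode, driven by SCHEMA
same = EchoRequest.decode(data)      # EchoRequest(correlation_id=5, message='hello')
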
diff --git a/kafka/protocol/abstract.py b/kafka/protocol/abstract.py
new file mode 100644
index 0000000..9c53c8c
--- /dev/null
+++ b/kafka/protocol/abstract.py
@@ -0,0 +1,13 @@
+import abc
+
+
+class AbstractType(object):
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractmethod
+ def encode(cls, value):
+ pass
+
+ @abc.abstractmethod
+ def decode(cls, data):
+ pass
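
The new AbstractType contract is just encode/decode. A sketch of a concrete subclass (Boolean is not part of this commit; it simply follows the same file-like decode convention as the Int types in types.py below):

from struct import pack, unpack

from kafka.protocol.abstract import AbstractType

class Boolean(AbstractType):
    @classmethod
    def encode(cls, value):
        return pack('>?', bool(value))

    @classmethod
    def decode(cls, data):
        # data is a file-like object, matching the Int*/String/Bytes convention
        (value,) = unpack('>?', data.read(1))
        return value
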
diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py
index 8ea820b..0c23437 100644
--- a/kafka/protocol/api.py
+++ b/kafka/protocol/api.py
@@ -1,301 +1,16 @@
-import struct
+from .struct import Struct
+from .types import Int16, Int32, String, Schema
-from .types import (
- Int8, Int16, Int32, Int64, Bytes, String, Array
-)
-from ..util import crc32
+class RequestHeader(Struct):
+ SCHEMA = Schema(
+ ('api_key', Int16),
+ ('api_version', Int16),
+ ('correlation_id', Int32),
+ ('client_id', String('utf-8'))
+ )
-class Message(object):
- MAGIC_BYTE = 0
- __slots__ = ('magic', 'attributes', 'key', 'value')
-
- def __init__(self, value, key=None, magic=0, attributes=0):
- self.magic = magic
- self.attributes = attributes
- self.key = key
- self.value = value
-
- def encode(self):
- message = (
- Int8.encode(self.magic) +
- Int8.encode(self.attributes) +
- Bytes.encode(self.key) +
- Bytes.encode(self.value)
+ def __init__(self, request, correlation_id=0, client_id='kafka-python'):
+ super(RequestHeader, self).__init__(
+ request.API_KEY, request.API_VERSION, correlation_id, client_id
)
- return (
- struct.pack('>I', crc32(message)) +
- message
- )
-
-
-class MessageSet(object):
-
- @staticmethod
- def _encode_one(message):
- encoded = message.encode()
- return (Int64.encode(0) + Int32.encode(len(encoded)) + encoded)
-
- @staticmethod
- def encode(messages):
- return b''.join(map(MessageSet._encode_one, messages))
-
-
-class AbstractRequestResponse(object):
- @classmethod
- def encode(cls, message):
- return Int32.encode(len(message)) + message
-
-
-class AbstractRequest(AbstractRequestResponse):
- @classmethod
- def encode(cls, request, correlation_id=0, client_id='kafka-python'):
- request = (Int16.encode(cls.API_KEY) +
- Int16.encode(cls.API_VERSION) +
- Int32.encode(correlation_id) +
- String.encode(client_id) +
- request)
- return super(AbstractRequest, cls).encode(request)
-
-
-class FetchRequest(AbstractRequest):
- API_KEY = 1
- API_VERSION = 0
- __slots__ = ('replica_id', 'max_wait_time', 'min_bytes', 'topic_partition_offsets')
-
- def __init__(self, topic_partition_offsets,
- max_wait_time=-1, min_bytes=0, replica_id=-1):
- """
- topic_partition_offsets is a dict of dicts of (offset, max_bytes) tuples
- {
- "TopicFoo": {
- 0: (1234, 1048576),
- 1: (1324, 1048576)
- }
- }
- """
- self.topic_partition_offsets = topic_partition_offsets
- self.max_wait_time = max_wait_time
- self.min_bytes = min_bytes
- self.replica_id = replica_id
-
- def encode(self):
- request = (
- Int32.encode(self.replica_id) +
- Int32.encode(self.max_wait_time) +
- Int32.encode(self.min_bytes) +
- Array.encode([(
- String.encode(topic) +
- Array.encode([(
- Int32.encode(partition) +
- Int64.encode(offset) +
- Int32.encode(max_bytes)
- ) for partition, (offset, max_bytes) in partitions.iteritems()])
- ) for topic, partitions in self.topic_partition_offsets.iteritems()]))
- return super(FetchRequest, self).encode(request)
-
-
-class OffsetRequest(AbstractRequest):
- API_KEY = 2
- API_VERSION = 0
- __slots__ = ('replica_id', 'topic_partition_times')
-
- def __init__(self, topic_partition_times, replica_id=-1):
- """
- topic_partition_times is a dict of dicts of (time, max_offsets) tuples
- {
- "TopicFoo": {
- 0: (-1, 1),
- 1: (-1, 1)
- }
- }
- """
- self.topic_partition_times = topic_partition_times
- self.replica_id = replica_id
-
- def encode(self):
- request = (
- Int32.encode(self.replica_id) +
- Array.encode([(
- String.encode(topic) +
- Array.encode([(
- Int32.encode(partition) +
- Int64.encode(time) +
- Int32.encode(max_offsets)
- ) for partition, (time, max_offsets) in partitions.iteritems()])
- ) for topic, partitions in self.topic_partition_times.iteritems()]))
- return super(OffsetRequest, self).encode(request)
-
-
-class MetadataRequest(AbstractRequest):
- API_KEY = 3
- API_VERSION = 0
- __slots__ = ('topics')
-
- def __init__(self, *topics):
- self.topics = topics
-
- def encode(self):
- request = Array.encode(map(String.encode, self.topics))
- return super(MetadataRequest, self).encode(request)
-
-
-# Non-user facing control APIs 4-7
-
-
-class OffsetCommitRequestV0(AbstractRequest):
- API_KEY = 8
- API_VERSION = 0
- __slots__ = ('consumer_group_id', 'offsets')
-
- def __init__(self, consumer_group_id, offsets):
- """
- offsets is a dict of dicts of (offset, metadata) tuples
- {
- "TopicFoo": {
- 0: (1234, ""),
- 1: (1243, "")
- }
- }
- """
- self.consumer_group_id = consumer_group_id
- self.offsets = offsets
-
- def encode(self):
- request = (
- String.encode(self.consumer_group_id) +
- Array.encode([(
- String.encode(topic) +
- Array.encode([(
- Int32.encode(partition) +
- Int64.encode(offset) +
- String.encode(metadata)
- ) for partition, (offset, metadata) in partitions.iteritems()])
- ) for topic, partitions in self.offsets.iteritems()]))
- return super(OffsetCommitRequestV0, self).encode(request)
-
-
-class OffsetCommitRequestV1(AbstractRequest):
- API_KEY = 8
- API_VERSION = 1
- __slots__ = ('consumer_group_id', 'consumer_group_generation_id',
- 'consumer_id', 'offsets')
-
- def __init__(self, consumer_group_id, consumer_group_generation_id,
- consumer_id, offsets):
- """
- offsets is a dict of dicts of (offset, timestamp, metadata) tuples
- {
- "TopicFoo": {
- 0: (1234, 1448198827, ""),
- 1: (1243, 1448198827, "")
- }
- }
- """
- self.consumer_group_id = consumer_group_id
- self.consumer_group_generation_id = consumer_group_generation_id
- self.consumer_id = consumer_id
- self.offsets = offsets
-
- def encode(self):
- request = (
- String.encode(self.consumer_group_id) +
- Int32.encode(self.consumer_group_generation_id) +
- String.encode(self.consumer_id) +
- Array.encode([(
- String.encode(topic) +
- Array.encode([(
- Int32.encode(partition) +
- Int64.encode(offset) +
- Int64.encode(timestamp) +
- String.encode(metadata)
- ) for partition, (offset, timestamp, metadata) in partitions.iteritems()])
- ) for topic, partitions in self.offsets.iteritems()]))
- return super(OffsetCommitRequestV1, self).encode(request)
-
-
-class OffsetCommitRequest(AbstractRequest):
- API_KEY = 8
- API_VERSION = 2
- __slots__ = ('consumer_group_id', 'consumer_group_generation_id',
- 'consumer_id', 'retention_time', 'offsets')
-
- def __init__(self, consumer_group_id, consumer_group_generation_id,
- consumer_id, retention_time, offsets):
- """
- offsets is a dict of dicts of (offset, metadata) tuples
- {
- "TopicFoo": {
- 0: (1234, ""),
- 1: (1243, "")
- }
- }
- """
- self.consumer_group_id = consumer_group_id
- self.consumer_group_generation_id = consumer_group_generation_id
- self.consumer_id = consumer_id
- self.retention_time = retention_time
- self.offsets = offsets
-
- def encode(self):
- request = (
- String.encode(self.consumer_group_id) +
- Int32.encode(self.consumer_group_generation_id) +
- String.encode(self.consumer_id) +
- Int64.encode(self.retention_time) +
- Array.encode([(
- String.encode(topic) +
- Array.encode([(
- Int32.encode(partition) +
- Int64.encode(offset) +
- String.encode(metadata)
- ) for partition, (offset, timestamp, metadata) in partitions.iteritems()])
- ) for topic, partitions in self.offsets.iteritems()]))
- return super(OffsetCommitRequest, self).encode(request)
-
-
-class OffsetFetchRequestV0(AbstractRequest):
- API_KEY = 9
- API_VERSION = 0
- __slots__ = ('consumer_group', 'topic_partitions')
-
- def __init__(self, consumer_group, topic_partitions):
- """
- offsets is a dict of lists of partition ints
- {
- "TopicFoo": [0, 1, 2]
- }
- """
- self.consumer_group = consumer_group
- self.topic_partitions = topic_partitions
-
- def encode(self):
- request = (
- String.encode(self.consumer_group) +
- Array.encode([(
- String.encode(topic) +
- Array.encode([Int32.encode(partition) for partition in partitions])
- ) for topic, partitions in self.topic_partitions.iteritems()])
- )
- return super(OffsetFetchRequest, self).encode(request)
-
-
-class OffsetFetchRequest(OffsetFetchRequestV0):
- """Identical to V0, but offsets fetched from kafka storage not zookeeper"""
- API_VERSION = 1
-
-
-class GroupCoordinatorRequest(AbstractRequest):
- API_KEY = 10
- API_VERSION = 0
- __slots__ = ('group_id',)
-
- def __init__(self, group_id):
- self.group_id = group_id
-
- def encode(self):
- request = String.encode(self.group_id)
- return super(GroupCoordinatorRequest, self).encode(request)
-
-
-
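
api.py shrinks to just RequestHeader; wire framing moves out of the request classes. A sketch of wiring a header onto a request, assuming the caller adds the 4-byte size prefix that the removed AbstractRequestResponse.encode used to provide:

from kafka.protocol.api import RequestHeader
from kafka.protocol.metadata import MetadataRequest
from kafka.protocol.types import Int32

request = MetadataRequest(topics=['my-topic'])
header = RequestHeader(request, correlation_id=1, client_id='kafka-python')

payload = header.encode() + request.encode()
wire_bytes = Int32.encode(len(payload)) + payload   # size-prefixed frame
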
diff --git a/kafka/protocol/commit.py b/kafka/protocol/commit.py
new file mode 100644
index 0000000..5ba0227
--- /dev/null
+++ b/kafka/protocol/commit.py
@@ -0,0 +1,111 @@
+from .struct import Struct
+from .types import Array, Int16, Int32, Int64, Schema, String
+
+
+class OffsetCommitRequest_v2(Struct):
+ API_KEY = 8
+ API_VERSION = 2 # added retention_time, dropped timestamp
+ SCHEMA = Schema(
+ ('consumer_group', String('utf-8')),
+ ('consumer_group_generation_id', Int32),
+ ('consumer_id', String('utf-8')),
+ ('retention_time', Int64),
+ ('topics', Array(
+ ('topic', String('utf-8')),
+ ('partitions', Array(
+ ('partition', Int32),
+ ('offset', Int64),
+ ('metadata', String('utf-8'))))))
+ )
+
+
+class OffsetCommitRequest_v1(Struct):
+ API_KEY = 8
+ API_VERSION = 1 # Kafka-backed storage
+ SCHEMA = Schema(
+ ('consumer_group', String('utf-8')),
+ ('consumer_group_generation_id', Int32),
+ ('consumer_id', String('utf-8')),
+ ('topics', Array(
+ ('topic', String('utf-8')),
+ ('partitions', Array(
+ ('partition', Int32),
+ ('offset', Int64),
+ ('timestamp', Int64),
+ ('metadata', String('utf-8'))))))
+ )
+
+
+class OffsetCommitRequest_v0(Struct):
+ API_KEY = 8
+ API_VERSION = 0 # Zookeeper-backed storage
+ SCHEMA = Schema(
+ ('consumer_group', String('utf-8')),
+ ('topics', Array(
+ ('topic', String('utf-8')),
+ ('partitions', Array(
+ ('partition', Int32),
+ ('offset', Int64),
+ ('metadata', String('utf-8'))))))
+ )
+
+
+class OffsetCommitResponse(Struct):
+ SCHEMA = Schema(
+ ('topics', Array(
+ ('topic', String('utf-8')),
+ ('partitions', Array(
+ ('partition', Int32),
+ ('error_code', Int16)))))
+ )
+
+
+class OffsetFetchRequest_v1(Struct):
+ API_KEY = 9
+ API_VERSION = 1 # kafka-backed storage
+ SCHEMA = Schema(
+ ('consumer_group', String('utf-8')),
+ ('topics', Array(
+ ('topic', String('utf-8')),
+ ('partitions', Array(Int32))))
+ )
+
+
+class OffsetFetchRequest_v0(Struct):
+ API_KEY = 9
+ API_VERSION = 0 # zookeeper-backed storage
+ SCHEMA = Schema(
+ ('consumer_group', String('utf-8')),
+ ('topics', Array(
+ ('topic', String('utf-8')),
+ ('partitions', Array(Int32))))
+ )
+
+
+class OffsetFetchResponse(Struct):
+ SCHEMA = Schema(
+ ('topics', Array(
+ ('topic', String('utf-8')),
+ ('partitions', Array(
+ ('partition', Int32),
+ ('offset', Int64),
+ ('metadata', String('utf-8')),
+ ('error_code', Int16)))))
+ )
+
+
+class GroupCoordinatorRequest(Struct):
+ API_KEY = 10
+ API_VERSION = 0
+ SCHEMA = Schema(
+ ('consumer_group', String('utf-8'))
+ )
+
+
+class GroupCoordinatorResponse(Struct):
+ SCHEMA = Schema(
+ ('error_code', Int16),
+ ('coordinator_id', Int32),
+ ('host', String('utf-8')),
+ ('port', Int32)
+ )
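
Nested fields are now plain tuples and lists matching the Schema shape. A sketch committing one offset with the v0 (zookeeper-backed) request; values are illustrative:

from kafka.protocol.commit import OffsetCommitRequest_v0

request = OffsetCommitRequest_v0(
    consumer_group='my-group',
    topics=[('my-topic', [(0, 1234, '')])]   # (partition, offset, metadata)
)
data = request.encode()
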
diff --git a/kafka/protocol/fetch.py b/kafka/protocol/fetch.py
new file mode 100644
index 0000000..c6d60cc
--- /dev/null
+++ b/kafka/protocol/fetch.py
@@ -0,0 +1,30 @@
+from .message import MessageSet
+from .struct import Struct
+from .types import Array, Int16, Int32, Int64, Schema, String
+
+
+class FetchRequest(Struct):
+ API_KEY = 1
+ API_VERSION = 0
+ SCHEMA = Schema(
+ ('replica_id', Int32),
+ ('max_wait_time', Int32),
+ ('min_bytes', Int32),
+ ('topics', Array(
+ ('topic', String('utf-8')),
+ ('partitions', Array(
+ ('partition', Int32),
+ ('offset', Int64),
+ ('max_bytes', Int32)))))
+ )
+
+class FetchResponse(Struct):
+ SCHEMA = Schema(
+ ('topics', Array(
+ ('topic', String('utf-8')),
+ ('partitions', Array(
+ ('partition', Int32),
+ ('error_code', Int16),
+ ('highwater_offset', Int64),
+ ('message_set', MessageSet)))))
+ )
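
A sketch of a single-partition fetch (values are illustrative; replica_id=-1 marks an ordinary client rather than a broker):

from kafka.protocol.fetch import FetchRequest

request = FetchRequest(
    replica_id=-1,
    max_wait_time=100,    # ms to wait for min_bytes to accumulate
    min_bytes=1,
    topics=[('my-topic', [(0, 0, 1048576)])]   # (partition, offset, max_bytes)
)
data = request.encode()
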
diff --git a/kafka/protocol/message.py b/kafka/protocol/message.py
new file mode 100644
index 0000000..26f5ef6
--- /dev/null
+++ b/kafka/protocol/message.py
@@ -0,0 +1,67 @@
+from .struct import Struct
+from .types import (
+ Int8, Int16, Int32, Int64, Bytes, String, Array, Schema, AbstractType
+)
+from ..util import crc32
+
+
+class Message(Struct):
+ SCHEMA = Schema(
+ ('crc', Int32),
+ ('magic', Int8),
+ ('attributes', Int8),
+ ('key', Bytes),
+ ('value', Bytes)
+ )
+
+ def __init__(self, value, key=None, magic=0, attributes=0, crc=0):
+ self.crc = crc
+ self.magic = magic
+ self.attributes = attributes
+ self.key = key
+ self.value = value
+ self.encode = self._encode_self
+
+ def _encode_self(self, recalc_crc=True):
+ message = Message.SCHEMA.encode(
+ (self.crc, self.magic, self.attributes, self.key, self.value)
+ )
+ if not recalc_crc:
+ return message
+ self.crc = crc32(message[4:])
+ return self.SCHEMA.fields[0].encode(self.crc) + message[4:]
+
+
+class MessageSet(AbstractType):
+ ITEM = Schema(
+ ('offset', Int64),
+ ('message_size', Int32),
+ ('message', Message.SCHEMA)
+ )
+
+ @classmethod
+ def encode(cls, items, size=True, recalc_message_size=True):
+ encoded_values = []
+ for (offset, message_size, message) in items:
+ if isinstance(message, Message):
+ encoded_message = message.encode()
+ else:
+ encoded_message = cls.ITEM.fields[2].encode(message)
+ if recalc_message_size:
+ message_size = len(encoded_message)
+ encoded_values.append(cls.ITEM.fields[0].encode(offset))
+ encoded_values.append(cls.ITEM.fields[1].encode(message_size))
+ encoded_values.append(encoded_message)
+ encoded = b''.join(encoded_values)
+ if not size:
+ return encoded
+ return Int32.encode(len(encoded)) + encoded
+
+ @classmethod
+ def decode(cls, data):
+ size = Int32.decode(data)
+ end = data.tell() + size
+ items = []
+ while data.tell() < end:
+ items.append(cls.ITEM.decode(data))
+ return items
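
MessageSet.decode returns (offset, message_size, message) tuples, with the message itself decoded through Message.SCHEMA rather than the Message class. A round-trip sketch; the message is given as a raw (crc, magic, attributes, key, value) tuple, which takes the Schema path and skips the CRC recomputation a Message instance would trigger:

from io import BytesIO

from kafka.protocol.message import MessageSet

raw_message = (0, 0, 0, b'key', b'value')           # crc of 0 is illustrative only
encoded = MessageSet.encode([(0, 0, raw_message)])  # (offset, message_size, message)
items = MessageSet.decode(BytesIO(encoded))         # [(0, <size>, (0, 0, 0, b'key', b'value'))]
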
diff --git a/kafka/protocol/metadata.py b/kafka/protocol/metadata.py
new file mode 100644
index 0000000..b35e7ef
--- /dev/null
+++ b/kafka/protocol/metadata.py
@@ -0,0 +1,28 @@
+from .struct import Struct
+from .types import Array, Int16, Int32, Schema, String
+
+
+class MetadataRequest(Struct):
+ API_KEY = 3
+ API_VERSION = 0
+ SCHEMA = Schema(
+ ('topics', Array(String('utf-8')))
+ )
+
+
+class MetadataResponse(Struct):
+ SCHEMA = Schema(
+ ('brokers', Array(
+ ('node_id', Int32),
+ ('host', String('utf-8')),
+ ('port', Int32))),
+ ('topics', Array(
+ ('error_code', Int16),
+ ('topic', String('utf-8')),
+ ('partitions', Array(
+ ('error_code', Int16),
+ ('partition', Int32),
+ ('leader', Int32),
+ ('replicas', Array(Int32)),
+ ('isr', Array(Int32))))))
+ )
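
In the wire protocol, an empty topic array asks the broker for metadata on all topics. A sketch:

from kafka.protocol.metadata import MetadataRequest

all_topics = MetadataRequest(topics=[])          # empty list means all topics
one_topic = MetadataRequest(topics=['my-topic'])
data = one_topic.encode()
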
diff --git a/kafka/protocol/offset.py b/kafka/protocol/offset.py
new file mode 100644
index 0000000..942bdbf
--- /dev/null
+++ b/kafka/protocol/offset.py
@@ -0,0 +1,32 @@
+from .struct import Struct
+from .types import Array, Int16, Int32, Int64, Schema, String
+
+
+class OffsetRequest(Struct):
+ API_KEY = 2
+ API_VERSION = 0
+ SCHEMA = Schema(
+ ('replica_id', Int32),
+ ('topics', Array(
+ ('topic', String('utf-8')),
+ ('partitions', Array(
+ ('partition', Int32),
+ ('time', Int64),
+ ('max_offsets', Int32)))))
+ )
+ DEFAULTS = {
+ 'replica_id': -1
+ }
+
+
+class OffsetResponse(Struct):
+ API_KEY = 2
+ API_VERSION = 0
+ SCHEMA = Schema(
+ ('topics', Array(
+ ('topic', String('utf-8')),
+ ('partitions', Array(
+ ('partition', Int32),
+ ('error_code', Int16),
+ ('offsets', Array(Int64))))))
+ )
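
Note that Struct.__init__ never consults DEFAULTS in this commit, so replica_id still has to be passed explicitly. A sketch requesting the latest offset (time=-1; -2 would mean earliest):

from kafka.protocol.offset import OffsetRequest

request = OffsetRequest(
    replica_id=-1,
    topics=[('my-topic', [(0, -1, 1)])]   # (partition, time, max_offsets)
)
data = request.encode()
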
diff --git a/kafka/protocol/produce.py b/kafka/protocol/produce.py
index b875397..532a702 100644
--- a/kafka/protocol/produce.py
+++ b/kafka/protocol/produce.py
@@ -1,59 +1,30 @@
-from .api import AbstractRequest, AbstractResponse, MessageSet
-from .types import Int8, Int16, Int32, Int64, Bytes, String, Array
+from .message import MessageSet
+from .struct import Struct
+from .types import Int8, Int16, Int32, Int64, Bytes, String, Array, Schema
-class ProduceRequest(AbstractRequest):
+class ProduceRequest(Struct):
API_KEY = 0
API_VERSION = 0
- __slots__ = ('required_acks', 'timeout', 'topic_partition_messages', 'compression')
-
- def __init__(self, topic_partition_messages,
- required_acks=-1, timeout=1000, compression=None):
- """
- topic_partition_messages is a dict of dicts of lists (of messages)
- {
- "TopicFoo": {
- 0: [
- Message('foo'),
- Message('bar')
- ],
- 1: [
- Message('fizz'),
- Message('buzz')
- ]
- }
- }
- """
- self.required_acks = required_acks
- self.timeout = timeout
- self.topic_partition_messages = topic_partition_messages
- self.compression = compression
-
- @staticmethod
- def _encode_messages(partition, messages, compression):
- message_set = MessageSet.encode(messages)
-
- if compression:
- # compress message_set data and re-encode as single message
- # then wrap single compressed message in a new message_set
- pass
-
- return (Int32.encode(partition) +
- Int32.encode(len(message_set)) +
- message_set)
-
- def encode(self):
- request = (
- Int16.encode(self.required_acks) +
- Int32.encode(self.timeout) +
- Array.encode([(
- String.encode(topic) +
- Array.encode([
- self._encode_messages(partition, messages, self.compression)
- for partition, messages in partitions.iteritems()])
- ) for topic, partitions in self.topic_partition_messages.iteritems()])
- )
- return super(ProduceRequest, self).encode(request)
-
-
-
+ SCHEMA = Schema(
+ ('required_acks', Int16),
+ ('timeout', Int32),
+ ('topics', Array(
+ ('topic', String('utf-8')),
+ ('partitions', Array(
+ ('partition', Int32),
+ ('messages', MessageSet)))))
+ )
+
+
+class ProduceResponse(Struct):
+ API_KEY = 0
+ API_VERSION = 0
+ SCHEMA = Schema(
+ ('topics', Array(
+ ('topic', String('utf-8')),
+ ('partitions', Array(
+ ('partition', Int32),
+ ('error_code', Int16),
+ ('offset', Int64)))))
+ )
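
MessageSet plugs into a Schema like any other type because it subclasses AbstractType. A sketch producing one message to partition 0, again using a raw message tuple to keep the CRC deterministic:

from kafka.protocol.produce import ProduceRequest

message = (0, 0, 0, None, b'hello')   # (crc, magic, attributes, key, value)
request = ProduceRequest(
    required_acks=1,
    timeout=1000,
    topics=[('my-topic', [(0, [(0, 0, message)])])]
)
data = request.encode()
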
diff --git a/kafka/protocol/struct.py b/kafka/protocol/struct.py
new file mode 100644
index 0000000..77f5fe7
--- /dev/null
+++ b/kafka/protocol/struct.py
@@ -0,0 +1,52 @@
+from collections import namedtuple
+from io import BytesIO
+
+from .abstract import AbstractType
+from .types import Schema
+
+
+class Struct(AbstractType):
+ SCHEMA = Schema()
+
+ def __init__(self, *args, **kwargs):
+ if len(args) == len(self.SCHEMA.fields):
+ for i, name in enumerate(self.SCHEMA.names):
+ self.__dict__[name] = args[i]
+ elif len(args) > 0:
+ raise ValueError('Args must be empty or mirror schema')
+ else:
+ self.__dict__.update(kwargs)
+
+ # overloading encode() to support both class and instance
+ self.encode = self._encode_self
+
+ @classmethod
+ def encode(cls, item):
+ bits = []
+ for i, field in enumerate(cls.SCHEMA.fields):
+ bits.append(field.encode(item[i]))
+ return b''.join(bits)
+
+ def _encode_self(self):
+ return self.SCHEMA.encode(
+ [self.__dict__[name] for name in self.SCHEMA.names]
+ )
+
+ @classmethod
+ def decode(cls, data):
+ if isinstance(data, bytes):
+ data = BytesIO(data)
+ return cls(*[field.decode(data) for field in cls.SCHEMA.fields])
+
+ def __repr__(self):
+ key_vals = ['%s=%r' % (name, self.__dict__[name])
+ for name in self.SCHEMA.names]
+ return self.__class__.__name__ + '(' + ', '.join(key_vals) + ')'
+
+"""
+class MetaStruct(type):
+ def __new__(cls, clsname, bases, dct):
+ nt = namedtuple(clsname, [name for (name, _) in dct['SCHEMA']])
+ bases = tuple([Struct, nt] + list(bases))
+ return super(MetaStruct, cls).__new__(cls, clsname, bases, dct)
+"""
diff --git a/kafka/protocol/types.py b/kafka/protocol/types.py
index 6b257d3..5aa2e41 100644
--- a/kafka/protocol/types.py
+++ b/kafka/protocol/types.py
@@ -1,45 +1,73 @@
-from struct import pack
+from __future__ import absolute_import
+import abc
+from struct import pack, unpack
-class AbstractField(object):
- def __init__(self, name):
- self.name = name
+from .abstract import AbstractType
-class Int8(AbstractField):
+class Int8(AbstractType):
@classmethod
def encode(cls, value):
return pack('>b', value)
+ @classmethod
+ def decode(cls, data):
+ (value,) = unpack('>b', data.read(1))
+ return value
+
-class Int16(AbstractField):
+class Int16(AbstractType):
@classmethod
def encode(cls, value):
return pack('>h', value)
+ @classmethod
+ def decode(cls, data):
+ (value,) = unpack('>h', data.read(2))
+ return value
-class Int32(AbstractField):
+
+class Int32(AbstractType):
@classmethod
def encode(cls, value):
return pack('>i', value)
+ @classmethod
+ def decode(cls, data):
+ (value,) = unpack('>i', data.read(4))
+ return value
+
-class Int64(AbstractField):
+class Int64(AbstractType):
@classmethod
def encode(cls, value):
return pack('>q', value)
-
-class String(AbstractField):
@classmethod
- def encode(cls, value):
+ def decode(cls, data):
+ (value,) = unpack('>q', data.read(8))
+ return value
+
+
+class String(AbstractType):
+ def __init__(self, encoding='utf-8'):
+ self.encoding = encoding
+
+ def encode(self, value):
if value is None:
return Int16.encode(-1)
- else:
- return Int16.encode(len(value)) + value
+ value = str(value).encode(self.encoding)
+ return Int16.encode(len(value)) + value
+
+ def decode(self, data):
+ length = Int16.decode(data)
+ if length < 0:
+ return None
+ return data.read(length).decode(self.encoding)
-class Bytes(AbstractField):
+class Bytes(AbstractType):
@classmethod
def encode(cls, value):
if value is None:
@@ -47,9 +75,52 @@ class Bytes(AbstractField):
else:
return Int32.encode(len(value)) + value
-
-class Array(object):
@classmethod
- def encode(cls, values):
- # Assume that values are already encoded
- return Int32.encode(len(values)) + b''.join(values)
+ def decode(cls, data):
+ length = Int32.decode(data)
+ if length < 0:
+ return None
+ return data.read(length)
+
+
+class Schema(AbstractType):
+ def __init__(self, *fields):
+ if fields:
+ self.names, self.fields = zip(*fields)
+ else:
+ self.names, self.fields = (), ()
+
+ def encode(self, item):
+ if len(item) != len(self.fields):
+ raise ValueError('Item field count does not match Schema')
+ return b''.join([
+ field.encode(item[i])
+ for i, field in enumerate(self.fields)
+ ])
+
+ def decode(self, data):
+ return tuple([field.decode(data) for field in self.fields])
+
+ def __len__(self):
+ return len(self.fields)
+
+
+class Array(AbstractType):
+ def __init__(self, *array_of):
+ if len(array_of) > 1:
+ self.array_of = Schema(*array_of)
+ elif len(array_of) == 1 and (isinstance(array_of[0], AbstractType) or
+ issubclass(array_of[0], AbstractType)):
+ self.array_of = array_of[0]
+ else:
+ raise ValueError('Array instantiated with no array_of type')
+
+ def encode(self, items):
+ return b''.join(
+ [Int32.encode(len(items))] +
+ [self.array_of.encode(item) for item in items]
+ )
+
+ def decode(self, data):
+ length = Int32.decode(data)
+ return [self.array_of.decode(data) for _ in range(length)]
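
The composite types nest freely: an Array built from more than one (name, type) pair wraps them in a Schema. A sketch round-tripping directly through the type objects:

from io import BytesIO

from kafka.protocol.types import Array, Int32, Schema, String

pairs = Array(
    ('name', String('utf-8')),
    ('count', Int32)
)
encoded = pairs.encode([('a', 1), ('b', 2)])
decoded = pairs.decode(BytesIO(encoded))   # [('a', 1), ('b', 2)]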