author    | Dana Powers <dana.powers@rd.io> | 2015-11-28 19:41:06 +0800
committer | Zack Dever <zack.dever@rd.io>   | 2015-12-04 11:25:39 -0800
commit    | a85e09df89a43de5b659a0fa4ed35bec37c60e04 (patch)
tree      | a539af32fe502006c1f35b96d8ae36225292f7a5
parent    | e24a4d5f5252d6f97ac586e328b95779ef83f4b6 (diff)
download  | kafka-python-a85e09df89a43de5b659a0fa4ed35bec37c60e04.tar.gz
Rework protocol type definition: AbstractType, Schema, Struct
 kafka/protocol/abstract.py |  13
 kafka/protocol/api.py      | 309
 kafka/protocol/commit.py   | 111
 kafka/protocol/fetch.py    |  30
 kafka/protocol/message.py  |  67
 kafka/protocol/metadata.py |  28
 kafka/protocol/offset.py   |  32
 kafka/protocol/produce.py  |  81
 kafka/protocol/struct.py   |  52
 kafka/protocol/types.py    | 109
 10 files changed, 461 insertions(+), 371 deletions(-)
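The commit replaces per-request encode() methods and their hand-built byte concatenation with declarative Schema definitions: each Struct subclass states its wire format once, and encoding/decoding is driven from that schema. A minimal usage sketch built only on the modules this diff adds — the length-prefix framing and socket I/O that the old AbstractRequestResponse.encode() handled are left to the caller, so that part is an assumption about intended use, not code from the commit:

    from kafka.protocol.api import RequestHeader
    from kafka.protocol.metadata import MetadataRequest

    # Positional args must mirror the SCHEMA fields (here: one field, 'topics')
    request = MetadataRequest(['my-topic'])

    # RequestHeader reads API_KEY / API_VERSION off the request instance
    header = RequestHeader(request, correlation_id=1, client_id='kafka-python')

    # Instances encode themselves; a real client would still prepend an
    # Int32 size prefix before writing payload to the socket
    payload = header.encode() + request.encode()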
diff --git a/kafka/protocol/abstract.py b/kafka/protocol/abstract.py
new file mode 100644
index 0000000..9c53c8c
--- /dev/null
+++ b/kafka/protocol/abstract.py
@@ -0,0 +1,13 @@
+import abc
+
+
+class AbstractType(object):
+    __metaclass__ = abc.ABCMeta
+
+    @abc.abstractmethod
+    def encode(cls, value):
+        pass
+
+    @abc.abstractmethod
+    def decode(cls, data):
+        pass
diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py
index 8ea820b..0c23437 100644
--- a/kafka/protocol/api.py
+++ b/kafka/protocol/api.py
@@ -1,301 +1,16 @@
-import struct
+from .struct import Struct
+from .types import Int16, Int32, String, Schema
 
-from .types import (
-    Int8, Int16, Int32, Int64, Bytes, String, Array
-)
-from ..util import crc32
 
+class RequestHeader(Struct):
+    SCHEMA = Schema(
+        ('api_key', Int16),
+        ('api_version', Int16),
+        ('correlation_id', Int32),
+        ('client_id', String('utf-8'))
+    )
 
-class Message(object):
-    MAGIC_BYTE = 0
-    __slots__ = ('magic', 'attributes', 'key', 'value')
-
-    def __init__(self, value, key=None, magic=0, attributes=0):
-        self.magic = magic
-        self.attributes = attributes
-        self.key = key
-        self.value = value
-
-    def encode(self):
-        message = (
-            Int8.encode(self.magic) +
-            Int8.encode(self.attributes) +
-            Bytes.encode(self.key) +
-            Bytes.encode(self.value)
+    def __init__(self, request, correlation_id=0, client_id='kafka-python'):
+        super(RequestHeader, self).__init__(
+            request.API_KEY, request.API_VERSION, correlation_id, client_id
         )
-        return (
-            struct.pack('>I', crc32(message)) +
-            message
-        )
-
-
-class MessageSet(object):
-
-    @staticmethod
-    def _encode_one(message):
-        encoded = message.encode()
-        return (Int64.encode(0) + Int32.encode(len(encoded)) + encoded)
-
-    @staticmethod
-    def encode(messages):
-        return b''.join(map(MessageSet._encode_one, messages))
-
-
-class AbstractRequestResponse(object):
-    @classmethod
-    def encode(cls, message):
-        return Int32.encode(len(message)) + message
-
-
-class AbstractRequest(AbstractRequestResponse):
-    @classmethod
-    def encode(cls, request, correlation_id=0, client_id='kafka-python'):
-        request = (Int16.encode(cls.API_KEY) +
-                   Int16.encode(cls.API_VERSION) +
-                   Int32.encode(correlation_id) +
-                   String.encode(client_id) +
-                   request)
-        return super(AbstractRequest, cls).encode(request)
-
-
-class FetchRequest(AbstractRequest):
-    API_KEY = 1
-    API_VERSION = 0
-    __slots__ = ('replica_id', 'max_wait_time', 'min_bytes', 'topic_partition_offsets')
-
-    def __init__(self, topic_partition_offsets,
-                 max_wait_time=-1, min_bytes=0, replica_id=-1):
-        """
-        topic_partition_offsets is a dict of dicts of (offset, max_bytes) tuples
-        {
-            "TopicFoo": {
-                0: (1234, 1048576),
-                1: (1324, 1048576)
-            }
-        }
-        """
-        self.topic_partition_offsets = topic_partition_offsets
-        self.max_wait_time = max_wait_time
-        self.min_bytes = min_bytes
-        self.replica_id = replica_id
-
-    def encode(self):
-        request = (
-            Int32.encode(self.replica_id) +
-            Int32.encode(self.max_wait_time) +
-            Int32.encode(self.min_bytes) +
-            Array.encode([(
-                String.encode(topic) +
-                Array.encode([(
-                    Int32.encode(partition) +
-                    Int64.encode(offset) +
-                    Int32.encode(max_bytes)
-                ) for partition, (offset, max_bytes) in partitions.iteritems()])
-            ) for topic, partitions in self.topic_partition_offsets.iteritems()]))
-        return super(FetchRequest, self).encode(request)
-
-
-class OffsetRequest(AbstractRequest):
-    API_KEY = 2
-    API_VERSION = 0
-    __slots__ = ('replica_id', 'topic_partition_times')
-
-    def __init__(self, topic_partition_times, replica_id=-1):
-        """
-        topic_partition_times is a dict of dicts of (time, max_offsets) tuples
-        {
-            "TopicFoo": {
-                0: (-1, 1),
-                1: (-1, 1)
-            }
-        }
-        """
-        self.topic_partition_times = topic_partition_times
-        self.replica_id = replica_id
-
-    def encode(self):
-        request = (
-            Int32.encode(self.replica_id) +
-            Array.encode([(
-                String.encode(topic) +
-                Array.encode([(
-                    Int32.encode(partition) +
-                    Int64.encode(time) +
-                    Int32.encode(max_offsets)
-                ) for partition, (time, max_offsets) in partitions.iteritems()])
-            ) for topic, partitions in self.topic_partition_times.iteritems()]))
-        return super(OffsetRequest, self).encode(request)
-
-
-class MetadataRequest(AbstractRequest):
-    API_KEY = 3
-    API_VERSION = 0
-    __slots__ = ('topics')
-
-    def __init__(self, *topics):
-        self.topics = topics
-
-    def encode(self):
-        request = Array.encode(map(String.encode, self.topics))
-        return super(MetadataRequest, self).encode(request)
-
-
-# Non-user facing control APIs 4-7
-
-
-class OffsetCommitRequestV0(AbstractRequest):
-    API_KEY = 8
-    API_VERSION = 0
-    __slots__ = ('consumer_group_id', 'offsets')
-
-    def __init__(self, consumer_group_id, offsets):
-        """
-        offsets is a dict of dicts of (offset, metadata) tuples
-        {
-            "TopicFoo": {
-                0: (1234, ""),
-                1: (1243, "")
-            }
-        }
-        """
-        self.consumer_group_id = consumer_group_id
-        self.offsets = offsets
-
-    def encode(self):
-        request = (
-            String.encode(self.consumer_group_id) +
-            Array.encode([(
-                String.encode(topic) +
-                Array.encode([(
-                    Int32.encode(partition) +
-                    Int64.encode(offset) +
-                    String.encode(metadata)
-                ) for partition, (offset, metadata) in partitions.iteritems()])
-            ) for topic, partitions in self.offsets.iteritems()]))
-        return super(OffsetCommitRequestV0, self).encode(request)
-
-
-class OffsetCommitRequestV1(AbstractRequest):
-    API_KEY = 8
-    API_VERSION = 1
-    __slots__ = ('consumer_group_id', 'consumer_group_generation_id',
-                 'consumer_id', 'offsets')
-
-    def __init__(self, consumer_group_id, consumer_group_generation_id,
-                 consumer_id, offsets):
-        """
-        offsets is a dict of dicts of (offset, timestamp, metadata) tuples
-        {
-            "TopicFoo": {
-                0: (1234, 1448198827, ""),
-                1: (1243, 1448198827, "")
-            }
-        }
-        """
-        self.consumer_group_id = consumer_group_id
-        self.consumer_group_generation_id = consumer_group_generation_id
-        self.consumer_id = consumer_id
-        self.offsets = offsets
-
-    def encode(self):
-        request = (
-            String.encode(self.consumer_group_id) +
-            Int32.encode(self.consumer_group_generation_id) +
-            String.encode(self.consumer_id) +
-            Array.encode([(
-                String.encode(topic) +
-                Array.encode([(
-                    Int32.encode(partition) +
-                    Int64.encode(offset) +
-                    Int64.encode(timestamp) +
-                    String.encode(metadata)
-                ) for partition, (offset, timestamp, metadata) in partitions.iteritems()])
-            ) for topic, partitions in self.offsets.iteritems()]))
-        return super(OffsetCommitRequestV1, self).encode(request)
-
-
-class OffsetCommitRequest(AbstractRequest):
-    API_KEY = 8
-    API_VERSION = 2
-    __slots__ = ('consumer_group_id', 'consumer_group_generation_id',
-                 'consumer_id', 'retention_time', 'offsets')
-
-    def __init__(self, consumer_group_id, consumer_group_generation_id,
-                 consumer_id, retention_time, offsets):
-        """
-        offsets is a dict of dicts of (offset, metadata) tuples
-        {
-            "TopicFoo": {
-                0: (1234, ""),
-                1: (1243, "")
-            }
-        }
-        """
-        self.consumer_group_id = consumer_group_id
-        self.consumer_group_generation_id = consumer_group_generation_id
-        self.consumer_id = consumer_id
-        self.retention_time = retention_time
-        self.offsets = offsets
-
-    def encode(self):
-        request = (
-            String.encode(self.consumer_group_id) +
-            Int32.encode(self.consumer_group_generation_id) +
-            String.encode(self.consumer_id) +
-            Int64.encode(self.retention_time) +
-            Array.encode([(
-                String.encode(topic) +
-                Array.encode([(
-                    Int32.encode(partition) +
-                    Int64.encode(offset) +
-                    String.encode(metadata)
-                ) for partition, (offset, timestamp, metadata) in partitions.iteritems()])
-            ) for topic, partitions in self.offsets.iteritems()]))
-        return super(OffsetCommitRequest, self).encode(request)
-
-
-class OffsetFetchRequestV0(AbstractRequest):
-    API_KEY = 9
-    API_VERSION = 0
-    __slots__ = ('consumer_group', 'topic_partitions')
-
-    def __init__(self, consumer_group, topic_partitions):
-        """
-        offsets is a dict of lists of partition ints
-        {
-            "TopicFoo": [0, 1, 2]
-        }
-        """
-        self.consumer_group = consumer_group
-        self.topic_partitions = topic_partitions
-
-    def encode(self):
-        request = (
-            String.encode(self.consumer_group) +
-            Array.encode([(
-                String.encode(topic) +
-                Array.encode([Int32.encode(partition) for partition in partitions])
-            ) for topic, partitions in self.topic_partitions.iteritems()])
-        )
-        return super(OffsetFetchRequest, self).encode(request)
-
-
-class OffsetFetchRequest(OffsetFetchRequestV0):
-    """Identical to V0, but offsets fetched from kafka storage not zookeeper"""
-    API_VERSION = 1
-
-
-class GroupCoordinatorRequest(AbstractRequest):
-    API_KEY = 10
-    API_VERSION = 0
-    __slots__ = ('group_id',)
-
-    def __init__(self, group_id):
-        self.group_id = group_id
-
-    def encode(self):
-        request = String.encode(self.group_id)
-        return super(GroupCoordinatorRequest, self).encode(request)
-
-
diff --git a/kafka/protocol/commit.py b/kafka/protocol/commit.py
new file mode 100644
index 0000000..5ba0227
--- /dev/null
+++ b/kafka/protocol/commit.py
@@ -0,0 +1,111 @@
+from .struct import Struct
+from .types import Array, Int16, Int32, Int64, Schema, String
+
+
+class OffsetCommitRequest_v2(Struct):
+    API_KEY = 8
+    API_VERSION = 2 # added retention_time, dropped timestamp
+    SCHEMA = Schema(
+        ('consumer_group', String('utf-8')),
+        ('consumer_group_generation_id', Int32),
+        ('consumer_id', String('utf-8')),
+        ('retention_time', Int64),
+        ('topics', Array(
+            ('topic', String('utf-8')),
+            ('partitions', Array(
+                ('partition', Int32),
+                ('offset', Int64),
+                ('metadata', String('utf-8'))))))
+    )
+
+
+class OffsetCommitRequest_v1(Struct):
+    API_KEY = 8
+    API_VERSION = 1 # Kafka-backed storage
+    SCHEMA = Schema(
+        ('consumer_group', String('utf-8')),
+        ('consumer_group_generation_id', Int32),
+        ('consumer_id', String('utf-8')),
+        ('topics', Array(
+            ('topic', String('utf-8')),
+            ('partitions', Array(
+                ('partition', Int32),
+                ('offset', Int64),
+                ('timestamp', Int64),
+                ('metadata', String('utf-8'))))))
+    )
+
+
+class OffsetCommitRequest_v0(Struct):
+    API_KEY = 8
+    API_VERSION = 0 # Zookeeper-backed storage
+    SCHEMA = Schema(
+        ('consumer_group', String('utf-8')),
+        ('topics', Array(
+            ('topic', String('utf-8')),
+            ('partitions', Array(
+                ('partition', Int32),
+                ('offset', Int64),
+                ('metadata', String('utf-8'))))))
+    )
+
+
+class OffsetCommitResponse(Struct):
+    SCHEMA = Schema(
+        ('topics', Array(
+            ('topic', String('utf-8')),
+            ('partitions', Array(
+                ('partition', Int32),
+                ('error_code', Int16)))))
+    )
+
+
+class OffsetFetchRequest_v1(Struct):
+    API_KEY = 9
+    API_VERSION = 1 # kafka-backed storage
+    SCHEMA = Schema(
+        ('consumer_group', String('utf-8')),
+        ('topics', Array(
+            ('topic', String('utf-8')),
+            ('partitions', Array(Int32))))
+    )
+
+
+class OffsetFetchRequest_v0(Struct):
+    API_KEY = 9
+    API_VERSION = 0 # zookeeper-backed storage
+    SCHEMA = Schema(
+        ('consumer_group', String('utf-8')),
+        ('topics', Array(
+            ('topic', String('utf-8')),
+            ('partitions', Array(Int32))))
+    )
+
+
+class OffsetFetchResponse(Struct):
+    SCHEMA = Schema(
+        ('topics', Array(
+            ('topic', String('utf-8')),
+            ('partitions', Array(
+                ('partition', Int32),
+                ('offset', Int64),
+                ('metadata', String('utf-8')),
+                ('error_code', Int16)))))
+    )
+
+
+class GroupCoordinatorRequest(Struct):
+    API_KEY = 10
+    API_VERSION = 0
+    SCHEMA = Schema(
+        ('consumer_group', String('utf-8'))
+    )
+
+
+class GroupCoordinatorResponse(Struct):
+    SCHEMA = Schema(
+        ('error_code', Int16),
+        ('coordinator_id', Int32),
+        ('host', String('utf-8')),
+        ('port', Int32)
+    )
diff --git a/kafka/protocol/fetch.py b/kafka/protocol/fetch.py
new file mode 100644
index 0000000..c6d60cc
--- /dev/null
+++ b/kafka/protocol/fetch.py
@@ -0,0 +1,30 @@
+from .message import MessageSet
+from .struct import Struct
+from .types import Array, Int16, Int32, Int64, Schema, String
+
+
+class FetchRequest(Struct):
+    API_KEY = 1
+    API_VERSION = 0
+    SCHEMA = Schema(
+        ('replica_id', Int32),
+        ('max_wait_time', Int32),
+        ('min_bytes', Int32),
+        ('topics', Array(
+            ('topic', String('utf-8')),
+            ('partitions', Array(
+                ('partition', Int32),
+                ('offset', Int64),
+                ('max_bytes', Int32)))))
+    )
+
+class FetchResponse(Struct):
+    SCHEMA = Schema(
+        ('topics', Array(
+            ('topics', String('utf-8')),
+            ('partitions', Array(
+                ('partition', Int32),
+                ('error_code', Int16),
+                ('highwater_offset', Int64),
+                ('message_set', MessageSet)))))
+    )
diff --git a/kafka/protocol/message.py b/kafka/protocol/message.py
new file mode 100644
index 0000000..26f5ef6
--- /dev/null
+++ b/kafka/protocol/message.py
@@ -0,0 +1,67 @@
+from .struct import Struct
+from .types import (
+    Int8, Int16, Int32, Int64, Bytes, String, Array, Schema, AbstractType
+)
+from ..util import crc32
+
+
+class Message(Struct):
+    SCHEMA = Schema(
+        ('crc', Int32),
+        ('magic', Int8),
+        ('attributes', Int8),
+        ('key', Bytes),
+        ('value', Bytes)
+    )
+
+    def __init__(self, value, key=None, magic=0, attributes=0, crc=0):
+        self.crc = crc
+        self.magic = magic
+        self.attributes = attributes
+        self.key = key
+        self.value = value
+        self.encode = self._encode_self
+
+    def _encode_self(self, recalc_crc=True):
+        message = Message.SCHEMA.encode(
+            (self.crc, self.magic, self.attributes, self.key, self.value)
+        )
+        if not recalc_crc:
+            return message
+        self.crc = crc32(message[4:])
+        return self.SCHEMA.fields[0].encode(self.crc) + message[4:]
+
+
+class MessageSet(AbstractType):
+    ITEM = Schema(
+        ('offset', Int64),
+        ('message_size', Int32),
+        ('message', Message.SCHEMA)
+    )
+
+    @classmethod
+    def encode(cls, items, size=True, recalc_message_size=True):
+        encoded_values = []
+        for (offset, message_size, message) in items:
+            if isinstance(message, Message):
+                encoded_message = message.encode()
+            else:
+                encoded_message = cls.ITEM.fields[2].encode(message)
+            if recalc_message_size:
+                message_size = len(encoded_message)
+            encoded_values.append(cls.ITEM.fields[0].encode(offset))
+            encoded_values.append(cls.ITEM.fields[1].encode(message_size))
+            encoded_values.append(encoded_message)
+        encoded = b''.join(encoded_values)
+        if not size:
+            return encoded
+        return Int32.encode(len(encoded)) + encoded
+
+    @classmethod
+    def decode(cls, data):
+        size = Int32.decode(data)
+        end = data.tell() + size
+        items = []
+        while data.tell() < end:
+            items.append(cls.ITEM.decode(data))
+        return items
diff --git a/kafka/protocol/metadata.py b/kafka/protocol/metadata.py
new file mode 100644
index 0000000..b35e7ef
--- /dev/null
+++ b/kafka/protocol/metadata.py
@@ -0,0 +1,28 @@
+from .struct import Struct
+from .types import Array, Int16, Int32, Schema, String
+
+
+class MetadataRequest(Struct):
+    API_KEY = 3
+    API_VERSION = 0
+    SCHEMA = Schema(
+        ('topics', Array(String('utf-8')))
+    )
+
+
+class MetadataResponse(Struct):
+    SCHEMA = Schema(
+        ('brokers', Array(
+            ('node_id', Int32),
+            ('host', String('utf-8')),
+            ('port', Int32))),
+        ('topics', Array(
+            ('error_code', Int16),
+            ('topic', String('utf-8')),
+            ('partitions', Array(
+                ('error_code', Int16),
+                ('partition', Int32),
+                ('leader', Int32),
+                ('replicas', Array(Int32)),
+                ('isr', Array(Int32))))))
+    )
diff --git a/kafka/protocol/offset.py b/kafka/protocol/offset.py
new file mode 100644
index 0000000..942bdbf
--- /dev/null
+++ b/kafka/protocol/offset.py
@@ -0,0 +1,32 @@
+from .struct import Struct
+from .types import Array, Int16, Int32, Int64, Schema, String
+
+
+class OffsetRequest(Struct):
+    API_KEY = 2
+    API_VERSION = 0
+    SCHEMA = Schema(
+        ('replica_id', Int32),
+        ('topics', Array(
+            ('topic', String('utf-8')),
+            ('partitions', Array(
+                ('partition', Int32),
+                ('time', Int64),
+                ('max_offsets', Int32)))))
+    )
+    DEFAULTS = {
+        'replica_id': -1
+    }
+
+
+class OffsetResponse(Struct):
+    API_KEY = 2
+    API_VERSION = 0
+    SCHEMA = Schema(
+        ('topics', Array(
+            ('topic', String('utf-8')),
+            ('partitions', Array(
+                ('partition', Int32),
+                ('error_code', Int16),
+                ('offsets', Array(Int64))))))
+    )
diff --git a/kafka/protocol/produce.py b/kafka/protocol/produce.py
index b875397..532a702 100644
--- a/kafka/protocol/produce.py
+++ b/kafka/protocol/produce.py
@@ -1,59 +1,30 @@
-from .api import AbstractRequest, AbstractResponse, MessageSet
-from .types import Int8, Int16, Int32, Int64, Bytes, String, Array
+from .message import MessageSet
+from .struct import Struct
+from .types import Int8, Int16, Int32, Int64, Bytes, String, Array, Schema
 
 
-class ProduceRequest(AbstractRequest):
+class ProduceRequest(Struct):
     API_KEY = 0
     API_VERSION = 0
-    __slots__ = ('required_acks', 'timeout', 'topic_partition_messages', 'compression')
-
-    def __init__(self, topic_partition_messages,
-                 required_acks=-1, timeout=1000, compression=None):
-        """
-        topic_partition_messages is a dict of dicts of lists (of messages)
-        {
-            "TopicFoo": {
-                0: [
-                    Message('foo'),
-                    Message('bar')
-                ],
-                1: [
-                    Message('fizz'),
-                    Message('buzz')
-                ]
-            }
-        }
-        """
-        self.required_acks = required_acks
-        self.timeout = timeout
-        self.topic_partition_messages = topic_partition_messages
-        self.compression = compression
-
-    @staticmethod
-    def _encode_messages(partition, messages, compression):
-        message_set = MessageSet.encode(messages)
-
-        if compression:
-            # compress message_set data and re-encode as single message
-            # then wrap single compressed message in a new message_set
-            pass
-
-        return (Int32.encode(partition) +
-                Int32.encode(len(message_set)) +
-                message_set)
-
-    def encode(self):
-        request = (
-            Int16.encode(self.required_acks) +
-            Int32.encode(self.timeout) +
-            Array.encode([(
-                String.encode(topic) +
-                Array.encode([
-                    self._encode_messages(partition, messages, self.compression)
-                    for partition, messages in partitions.iteritems()])
-            ) for topic, partitions in self.topic_partition_messages.iteritems()])
-        )
-        return super(ProduceRequest, self).encode(request)
-
-
-
+    SCHEMA = Schema(
+        ('required_acks', Int16),
+        ('timeout', Int32),
+        ('topics', Array(
+            ('topic', String('utf-8')),
+            ('partitions', Array(
+                ('partition', Int32),
+                ('messages', MessageSet)))))
+    )
+
+
+class ProduceResponse(Struct):
+    API_KEY = 0
+    API_VERSION = 0
+    SCHEMA = Schema(
+        ('topics', Array(
+            ('topic', String('utf-8')),
+            ('partitions', Array(
+                ('partition', Int32),
+                ('error_code', Int16),
+                ('offset', Int64)))))
+    )
diff --git a/kafka/protocol/struct.py b/kafka/protocol/struct.py
new file mode 100644
index 0000000..77f5fe7
--- /dev/null
+++ b/kafka/protocol/struct.py
@@ -0,0 +1,52 @@
+from collections import namedtuple
+from io import BytesIO
+
+from .abstract import AbstractType
+from .types import Schema
+
+
+class Struct(AbstractType):
+    SCHEMA = Schema()
+
+    def __init__(self, *args, **kwargs):
+        if len(args) == len(self.SCHEMA.fields):
+            for i, name in enumerate(self.SCHEMA.names):
+                self.__dict__[name] = args[i]
+        elif len(args) > 0:
+            raise ValueError('Args must be empty or mirror schema')
+        else:
+            self.__dict__.update(kwargs)
+
+        # overloading encode() to support both class and instance
+        self.encode = self._encode_self
+
+    @classmethod
+    def encode(cls, item):
+        bits = []
+        for i, field in enumerate(cls.SCHEMA.fields):
+            bits.append(field.encode(item[i]))
+        return b''.join(bits)
+
+    def _encode_self(self):
+        return self.SCHEMA.encode(
+            [self.__dict__[name] for name in self.SCHEMA.names]
+        )
+
+    @classmethod
+    def decode(cls, data):
+        if isinstance(data, bytes):
+            data = BytesIO(data)
+        return cls(*[field.decode(data) for field in cls.SCHEMA.fields])
+
+    def __repr__(self):
+        key_vals = ['%s=%r' % (name, self.__dict__[name])
+                    for name in self.SCHEMA.names]
+        return self.__class__.__name__ + '(' + ', '.join(key_vals) + ')'
+
+"""
+class MetaStruct(type):
+    def __new__(cls, clsname, bases, dct):
+        nt = namedtuple(clsname, [name for (name, _) in dct['SCHEMA']])
+        bases = tuple([Struct, nt] + list(bases))
+        return super(MetaStruct, cls).__new__(cls, clsname, bases, dct)
+"""
diff --git a/kafka/protocol/types.py b/kafka/protocol/types.py
index 6b257d3..5aa2e41 100644
--- a/kafka/protocol/types.py
+++ b/kafka/protocol/types.py
@@ -1,45 +1,73 @@
-from struct import pack
+from __future__ import absolute_import
+import abc
+from struct import pack, unpack
 
-class AbstractField(object):
-    def __init__(self, name):
-        self.name = name
+from .abstract import AbstractType
 
-class Int8(AbstractField):
+class Int8(AbstractType):
     @classmethod
     def encode(cls, value):
         return pack('>b', value)
 
+    @classmethod
+    def decode(cls, data):
+        (value,) = unpack('>b', data.read(1))
+        return value
+
 
-class Int16(AbstractField):
+class Int16(AbstractType):
     @classmethod
     def encode(cls, value):
         return pack('>h', value)
 
+    @classmethod
+    def decode(cls, data):
+        (value,) = unpack('>h', data.read(2))
+        return value
 
-class Int32(AbstractField):
+
+class Int32(AbstractType):
     @classmethod
     def encode(cls, value):
         return pack('>i', value)
 
+    @classmethod
+    def decode(cls, data):
+        (value,) = unpack('>i', data.read(4))
+        return value
+
 
-class Int64(AbstractField):
+class Int64(AbstractType):
     @classmethod
     def encode(cls, value):
         return pack('>q', value)
 
-
-class String(AbstractField):
     @classmethod
-    def encode(cls, value):
+    def decode(cls, data):
+        (value,) = unpack('>q', data.read(8))
+        return value
+
+
+class String(AbstractType):
+    def __init__(self, encoding='utf-8'):
+        self.encoding = encoding
+
+    def encode(self, value):
         if value is None:
             return Int16.encode(-1)
-        else:
-            return Int16.encode(len(value)) + value
+        value = str(value).encode(self.encoding)
+        return Int16.encode(len(value)) + value
+
+    def decode(self, data):
+        length = Int16.decode(data)
+        if length < 0:
+            return None
+        return data.read(length).decode(self.encoding)
 
-class Bytes(AbstractField):
+class Bytes(AbstractType):
     @classmethod
     def encode(cls, value):
         if value is None:
@@ -47,9 +75,52 @@ class Bytes(AbstractField):
         else:
             return Int32.encode(len(value)) + value
 
-
-class Array(object):
     @classmethod
-    def encode(cls, values):
-        # Assume that values are already encoded
-        return Int32.encode(len(values)) + b''.join(values)
+    def decode(cls, data):
+        length = Int32.decode(data)
+        if length < 0:
+            return None
+        return data.read(length)
+
+
+class Schema(AbstractType):
+    def __init__(self, *fields):
+        if fields:
+            self.names, self.fields = zip(*fields)
+        else:
+            self.names, self.fields = (), ()
+
+    def encode(self, item):
+        if len(item) != len(self.fields):
+            raise ValueError('Item field count does not match Schema')
+        return b''.join([
+            field.encode(item[i])
+            for i, field in enumerate(self.fields)
+        ])
+
+    def decode(self, data):
+        return tuple([field.decode(data) for field in self.fields])
+
+    def __len__(self):
+        return len(self.fields)
+
+
+class Array(AbstractType):
+    def __init__(self, *array_of):
+        if len(array_of) > 1:
+            self.array_of = Schema(*array_of)
+        elif len(array_of) == 1 and (isinstance(array_of[0], AbstractType) or
+                                     issubclass(array_of[0], AbstractType)):
+            self.array_of = array_of[0]
+        else:
+            raise ValueError('Array instantiated with no array_of type')
+
+    def encode(self, items):
+        return b''.join(
+            [Int32.encode(len(items))] +
+            [self.array_of.encode(item) for item in items]
+        )
+
+    def decode(self, data):
+        length = Int32.decode(data)
+        return [self.array_of.decode(data) for _ in range(length)]
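Decoding is the other half of the rework: Struct.decode() accepts raw bytes, wraps them in a BytesIO, and reads each schema field in order, so response classes round-trip cleanly. A short sketch against GroupCoordinatorResponse from the diff above; the field values are illustrative, not taken from the commit:

    from kafka.protocol.commit import GroupCoordinatorResponse

    # Class-level encode takes a sequence matching the schema:
    # (error_code, coordinator_id, host, port)
    raw = GroupCoordinatorResponse.SCHEMA.encode((0, 1, 'broker1.example.com', 9092))

    # decode() wraps the bytes in BytesIO and returns a populated instance
    resp = GroupCoordinatorResponse.decode(raw)
    assert resp.error_code == 0 and resp.coordinator_id == 1
    assert resp.host == 'broker1.example.com' and resp.port == 9092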