#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

import datetime, string
from packer import Packer
from datatypes import serial, timestamp, RangedSet, Struct, UUID
from ops import Compound, PRIMITIVE, COMPOUND


class CodecException(Exception):
    """Raised when a value cannot be encoded or decoded data is inconsistent."""
    pass


def direct(t):
    """Return a selector that ignores its argument and always yields *t*."""
    return lambda x: t


def map_str(s):
    """Choose the encoding name for the byte string *s*.

    Returns "vbin16" if any character has an ordinal of 0x80 or above,
    otherwise "str16".
    """
    for c in s:
        if ord(c) >= 0x80:
            return "vbin16"
    return "str16"


class Codec(Packer):
    """Typed reader/writer built on pack/unpack/read/write primitives.

    The methods here call self.pack, self.unpack, self.read and self.write;
    read and write are supplied by subclasses such as StringCodec, while
    pack and unpack are presumably inherited from Packer (not visible here).
    """

    # Maps a Python class to a selector; calling the selector with the value
    # yields the name of the encoding used to look up PRIMITIVE.
    ENCODINGS = {
        bool: direct("boolean"),
        unicode: direct("str16"),
        str: map_str,
        buffer: direct("vbin32"),
        int: direct("int64"),
        long: direct("int64"),
        float: direct("double"),
        None.__class__: direct("void"),
        list: direct("list"),
        tuple: direct("list"),
        dict: direct("map"),
        timestamp: direct("datetime"),
        datetime.datetime: direct("datetime"),
        UUID: direct("uuid"),
        Compound: direct("struct32")
    }

    def encoding(self, obj):
        """Return the PRIMITIVE entry used to encode *obj*.

        Raises CodecException if neither obj's class nor any of its base
        classes appears in ENCODINGS.
        """
        enc = self._encoding(obj.__class__, obj)
        if enc is None:
            # Wrap in a 1-tuple so a tuple-like obj cannot break formatting.
            raise CodecException("no encoding for %r" % (obj,))
        return PRIMITIVE[enc]

    def _encoding(self, klass, obj):
        """Depth-first search of *klass* and its bases for an encoding name.

        Returns the encoding name, or None when nothing matches.
        """
        if klass in self.ENCODINGS:
            return self.ENCODINGS[klass](obj)
        for base in klass.__bases__:
            result = self._encoding(base, obj)
            if result is not None:
                return result
        return None

    def read_primitive(self, type):
        """Read a value by dispatching to read_<type.NAME>()."""
        return getattr(self, "read_%s" % type.NAME)()

    def write_primitive(self, type, v):
        """Write *v* by dispatching to write_<type.NAME>(v)."""
        getattr(self, "write_%s" % type.NAME)(v)

    def read_void(self):
        """A void occupies no bytes; always returns None."""
        return None

    def write_void(self, v):
        """Write nothing; *v* is expected to be None."""
        assert v is None

    def read_bit(self):
        """A bit has no payload of its own; reading one yields True."""
        return True

    def write_bit(self, b):
        """Write nothing; raises ValueError if *b* is false."""
        if not b:
            raise ValueError(b)

    def read_uint8(self):
        return self.unpack("!B")

    def write_uint8(self, n):
        """Write *n* as an unsigned byte; CodecException if out of range."""
        if n < 0 or n > 255:
            raise CodecException("Cannot encode %d as uint8" % n)
        return self.pack("!B", n)

    def read_int8(self):
        return self.unpack("!b")

    def write_int8(self, n):
        """Write *n* as a signed byte; CodecException if out of range."""
        if n < -128 or n > 127:
            raise CodecException("Cannot encode %d as int8" % n)
        self.pack("!b", n)

    def read_char(self):
        return self.unpack("!c")

    def write_char(self, c):
        self.pack("!c", c)

    def read_boolean(self):
        """Read one unsigned byte; any non-zero value is True."""
        return self.read_uint8() != 0

    def write_boolean(self, b):
        """Write 1 for a true value and 0 for a false one, as a uint8."""
        if b:
            n = 1
        else:
            n = 0
        self.write_uint8(n)

    def read_uint16(self):
        return self.unpack("!H")

    def write_uint16(self, n):
        """Write *n* as an unsigned 16-bit integer; CodecException if out of range."""
        if n < 0 or n > 65535:
            raise CodecException("Cannot encode %d as uint16" % n)
        self.pack("!H", n)

    def read_int16(self):
        return self.unpack("!h")

    def write_int16(self, n):
        """Write *n* as a signed 16-bit integer; CodecException if out of range."""
        if n < -32768 or n > 32767:
            raise CodecException("Cannot encode %d as int16" % n)
        self.pack("!h", n)

    def read_uint32(self):
        return self.unpack("!L")

    def write_uint32(self, n):
        """Write *n* as an unsigned 32-bit integer; CodecException if out of range."""
        if n < 0 or n > 4294967295:
            raise CodecException("Cannot encode %d as uint32" % n)
        self.pack("!L", n)

    def read_int32(self):
        return self.unpack("!l")

    def write_int32(self, n):
        """Write *n* as a signed 32-bit integer; CodecException if out of range."""
        if n < -2147483648 or n > 2147483647:
            raise CodecException("Cannot encode %d as int32" % n)
        self.pack("!l", n)

    def read_float(self):
        return self.unpack("!f")

    def write_float(self, f):
        self.pack("!f", f)

    def read_sequence_no(self):
        """Read a uint32 and wrap it in serial."""
        return serial(self.read_uint32())

    def write_sequence_no(self, n):
        """Write n.value as a uint32."""
        self.write_uint32(n.value)

    def read_uint64(self):
        return self.unpack("!Q")

    def write_uint64(self, n):
        """Write *n* as an unsigned 64-bit integer; CodecException if out of range."""
        # Checked here, like the narrower writers, so callers get a
        # CodecException rather than whatever pack does on overflow.
        if n < 0 or n > 18446744073709551615:
            raise CodecException("Cannot encode %d as uint64" % n)
        self.pack("!Q", n)

    def read_int64(self):
        return self.unpack("!q")

    def write_int64(self, n):
        """Write *n* as a signed 64-bit integer; CodecException if out of range."""
        if n < -9223372036854775808 or n > 9223372036854775807:
            raise CodecException("Cannot encode %d as int64" % n)
        self.pack("!q", n)

    def read_datetime(self):
        """Read a uint64 and wrap it in timestamp."""
        return timestamp(self.read_uint64())

    def write_datetime(self, t):
        """Write *t* as a uint64, converting datetime.datetime via timestamp()."""
        if isinstance(t, datetime.datetime):
            t = timestamp(t)
        self.write_uint64(t)

    def read_double(self):
        return self.unpack("!d")

    def write_double(self, d):
        self.pack("!d", d)

    def read_vbin8(self):
        """Read a uint8 length followed by that many bytes."""
        return self.read(self.read_uint8())

    def write_vbin8(self, b):
        """Write *b* prefixed with its length as a uint8."""
        if isinstance(b, buffer):
            b = str(b)
        self.write_uint8(len(b))
        self.write(b)

    def read_str8(self):
        return self.read_vbin8().decode("utf8")

    def write_str8(self, s):
        self.write_vbin8(s.encode("utf8"))

    def read_str16(self):
        return self.read_vbin16().decode("utf8")

    def write_str16(self, s):
        self.write_vbin16(s.encode("utf8"))

    def read_str16_latin(self):
        return self.read_vbin16().decode("iso-8859-15")

    def write_str16_latin(self, s):
        self.write_vbin16(s.encode("iso-8859-15"))

    def read_vbin16(self):
        """Read a uint16 length followed by that many bytes."""
        return self.read(self.read_uint16())

    def write_vbin16(self, b):
        """Write *b* prefixed with its length as a uint16."""
        if isinstance(b, buffer):
            b = str(b)
        self.write_uint16(len(b))
        self.write(b)

    def read_sequence_set(self):
        """Read a uint16 byte size followed by (lower, upper) sequence pairs.

        Each range is two sequence numbers, i.e. 8 bytes, so size // 8
        ranges are read and added to a new RangedSet.
        """
        result = RangedSet()
        size = self.read_uint16()
        nranges = size // 8
        while nranges > 0:
            lower = self.read_sequence_no()
            upper = self.read_sequence_no()
            result.add(lower, upper)
            nranges -= 1
        return result

    def write_sequence_set(self, ss):
        """Write ss.ranges as a uint16 byte size followed by the pairs."""
        size = 8*len(ss.ranges)
        self.write_uint16(size)
        for rng in ss.ranges:
            self.write_sequence_no(rng.lower)
            self.write_sequence_no(rng.upper)

    def read_vbin32(self):
        """Read a uint32 length followed by that many bytes."""
        return self.read(self.read_uint32())

    def write_vbin32(self, b):
        """Write *b* prefixed with its length as a uint32."""
        if isinstance(b, buffer):
            b = str(b)
        self.write_uint32(len(b))
        self.write(b)

    def read_map(self):
        """Read a vbin32 and decode it as a dict; empty payload gives None.

        The payload is a uint32 entry count followed by entries of
        (str8 key, uint8 type code, value).
        """
        sc = StringCodec(self.read_vbin32())
        if not sc.encoded:
            return None
        # The entry count is consumed but not used: decoding simply
        # continues until the payload is exhausted.
        sc.read_uint32()
        result = {}
        while sc.encoded:
            k = sc.read_str8()
            code = sc.read_uint8()
            ptype = PRIMITIVE[code]
            v = sc.read_primitive(ptype)
            result[k] = v
        return result

    def _write_map_elem(self, k, v):
        """Return the encoded bytes of one map entry: key, type code, value."""
        ptype = self.encoding(v)
        sc = StringCodec()
        sc.write_str8(k)
        sc.write_uint8(ptype.CODE)
        sc.write_primitive(ptype, v)
        return sc.encoded

    def write_map(self, m):
        """Write dict *m* as a vbin32; None is written as an empty vbin32."""
        sc = StringCodec()
        if m is not None:
            sc.write_uint32(len(m))
            sc.write("".join([self._write_map_elem(k, v)
                              for k, v in m.items()]))
        self.write_vbin32(sc.encoded)

    def read_array(self):
        """Read a vbin32 holding same-typed items; empty payload gives None.

        The payload is a uint8 type code, a uint32 count, then the items.
        """
        sc = StringCodec(self.read_vbin32())
        if not sc.encoded:
            return None
        ptype = PRIMITIVE[sc.read_uint8()]
        count = sc.read_uint32()
        result = []
        while count > 0:
            result.append(sc.read_primitive(ptype))
            count -= 1
        return result

    def write_array(self, a):
        """Write sequence *a* as a vbin32; None is written as an empty vbin32.

        The element type is taken from a[0]; an empty sequence uses the
        encoding of None.  Every element is written with that one type.
        """
        sc = StringCodec()
        if a is not None:
            if len(a) > 0:
                ptype = self.encoding(a[0])
            else:
                ptype = self.encoding(None)
            sc.write_uint8(ptype.CODE)
            sc.write_uint32(len(a))
            for o in a:
                sc.write_primitive(ptype, o)
        self.write_vbin32(sc.encoded)

    def read_list(self):
        """Read a vbin32 holding individually typed items; empty gives None.

        The payload is a uint32 count, then (uint8 type code, value) pairs.
        """
        sc = StringCodec(self.read_vbin32())
        if not sc.encoded:
            return None
        count = sc.read_uint32()
        result = []
        while count > 0:
            ptype = PRIMITIVE[sc.read_uint8()]
            result.append(sc.read_primitive(ptype))
            count -= 1
        return result

    def write_list(self, l):
        """Write sequence *l* as a vbin32; None is written as an empty vbin32."""
        sc = StringCodec()
        if l is not None:
            sc.write_uint32(len(l))
            for o in l:
                ptype = self.encoding(o)
                sc.write_uint8(ptype.CODE)
                sc.write_primitive(ptype, o)
        self.write_vbin32(sc.encoded)

    def read_struct32(self):
        """Read a uint32 size and a uint16 code, then the matching compound."""
        # The size prefix is consumed but not needed to decode the fields.
        self.read_uint32()
        code = self.read_uint16()
        cls = COMPOUND[code]
        op = cls()
        self.read_fields(op)
        return op

    def write_struct32(self, value):
        self.write_compound(value)

    def read_compound(self, cls):
        """Read an instance of compound class *cls*.

        Consumes the size prefix (cls.SIZE bytes wide) and, when cls.CODE is
        not None, a uint16 code that must equal cls.CODE.

        Raises CodecException if the decoded code does not match.
        """
        # The size prefix is consumed but not needed to decode the fields.
        self.read_size(cls.SIZE)
        if cls.CODE is not None:
            code = self.read_uint16()
            # Explicit raise instead of assert so the check on decoded data
            # still runs under python -O.
            if code != cls.CODE:
                raise CodecException(
                    "compound code mismatch: expected %r, got %r"
                    % (cls.CODE, code))
        op = cls()
        self.read_fields(op)
        return op

    def write_compound(self, op):
        """Write *op* as size prefix, optional uint16 code, then its fields."""
        sc = StringCodec()
        if op.CODE is not None:
            sc.write_uint16(op.CODE)
        sc.write_fields(op)
        self.write_size(op.SIZE, len(sc.encoded))
        self.write(sc.encoded)

    def read_fields(self, op):
        """Read op.PACK flag bytes, then each field whose flag bit is set.

        Flag byte i supplies bits 8*i .. 8*i+7; bit i corresponds to
        op.FIELDS[i].  Fields whose bit is clear are left untouched on op.
        """
        flags = 0
        for i in range(op.PACK):
            flags |= (self.read_uint8() << 8*i)

        for i, f in enumerate(op.FIELDS):
            if flags & (0x1 << i):
                # COMPOUND is defined in ops; has_key is kept because its
                # concrete type is not visible from this module.
                if COMPOUND.has_key(f.type):
                    value = self.read_compound(COMPOUND[f.type])
                else:
                    value = getattr(self, "read_%s" % f.type)()
                setattr(op, f.name, value)

    def write_fields(self, op):
        """Write op.PACK flag bytes, then every field that is present.

        A "bit" field is present when its value is true; any other field is
        present when its value is not None.
        """
        flags = 0
        for i, f in enumerate(op.FIELDS):
            value = getattr(op, f.name)
            if f.type == "bit":
                present = value
            else:
                # Identity test: avoids calling comparison operators that
                # field values may define.
                present = value is not None
            if present:
                flags |= (0x1 << i)

        for i in range(op.PACK):
            self.write_uint8((flags >> 8*i) & 0xFF)

        for i, f in enumerate(op.FIELDS):
            if flags & (0x1 << i):
                # COMPOUND is defined in ops; has_key is kept because its
                # concrete type is not visible from this module.
                if COMPOUND.has_key(f.type):
                    enc = self.write_compound
                else:
                    enc = getattr(self, "write_%s" % f.type)
                value = getattr(op, f.name)
                enc(value)

    def read_size(self, width):
        """Read an unsigned integer *width* bytes wide; None if width <= 0."""
        if width > 0:
            attr = "read_uint%d" % (width*8)
            return getattr(self, attr)()
        return None

    def write_size(self, width, n):
        """Write *n* as an unsigned integer *width* bytes wide; no-op if width <= 0."""
        if width > 0:
            attr = "write_uint%d" % (width*8)
            getattr(self, attr)(n)

    def read_uuid(self):
        """Read 16 bytes and wrap them in UUID."""
        return UUID(bytes=self.unpack("16s"))

    def write_uuid(self, s):
        """Write 16 bytes; a UUID instance is written as its .bytes."""
        if isinstance(s, UUID):
            s = s.bytes
        self.pack("16s", s)

    def read_bin128(self):
        return self.unpack("16s")

    def write_bin128(self, b):
        self.pack("16s", b)


class StringCodec(Codec):
    """Codec over an in-memory string held in self.encoded.

    Reads consume from the front of the string; writes append to its end.
    """

    def __init__(self, encoded = ""):
        self.encoded = encoded

    def read(self, n):
        """Remove and return the first *n* bytes (fewer if less remain)."""
        result = self.encoded[:n]
        self.encoded = self.encoded[n:]
        return result

    def write(self, s):
        """Append *s* to the buffer."""
        self.encoded += s