diff options
Diffstat (limited to 'Lib/pickle.py')
| -rw-r--r-- | Lib/pickle.py | 244 | 
1 files changed, 96 insertions, 148 deletions
diff --git a/Lib/pickle.py b/Lib/pickle.py index 1d8185c6c6..5d22fceaf6 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -26,9 +26,10 @@ Misc variables:  from types import FunctionType, BuiltinFunctionType  from copyreg import dispatch_table  from copyreg import _extension_registry, _inverted_registry, _extension_cache -import marshal +from itertools import islice  import sys -import struct +from sys import maxsize +from struct import pack, unpack  import re  import io  import codecs @@ -58,11 +59,6 @@ HIGHEST_PROTOCOL = 3  # there are too many issues with that.  DEFAULT_PROTOCOL = 3 -# Why use struct.pack() for pickling but marshal.loads() for -# unpickling?  struct.pack() is 40% faster than marshal.dumps(), but -# marshal.loads() is twice as fast as struct.unpack()! -mloads = marshal.loads -  class PickleError(Exception):      """A common base class for the other pickling exceptions."""      pass @@ -94,7 +90,7 @@ class _Stop(Exception):  # Jython has PyStringMap; it's a dict subclass with string keys  try:      from org.python.core import PyStringMap -except ImportError: +except ModuleNotFoundError:      PyStringMap = None  # Pickle opcodes.  See pickletools.py for extensive docs.  The listing @@ -231,7 +227,7 @@ class _Pickler:              raise PicklingError("Pickler.__init__() was not called by "                                  "%s.__init__()" % (self.__class__.__name__,))          if self.proto >= 2: -            self.write(PROTO + bytes([self.proto])) +            self.write(PROTO + pack("<B", self.proto))          self.save(obj)          self.write(STOP) @@ -258,20 +254,20 @@ class _Pickler:          self.memo[id(obj)] = memo_len, obj      # Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i. -    def put(self, i, pack=struct.pack): +    def put(self, i):          if self.bin:              if i < 256: -                return BINPUT + bytes([i]) +                return BINPUT + pack("<B", i)              else:                  return LONG_BINPUT + pack("<I", i)          return PUT + repr(i).encode("ascii") + b'\n'      # Return a GET (BINGET, LONG_BINGET) opcode string, with argument i. -    def get(self, i, pack=struct.pack): +    def get(self, i):          if self.bin:              if i < 256: -                return BINGET + bytes([i]) +                return BINGET + pack("<B", i)              else:                  return LONG_BINGET + pack("<I", i) @@ -286,20 +282,20 @@ class _Pickler:          # Check the memo          x = self.memo.get(id(obj)) -        if x: +        if x is not None:              self.write(self.get(x[0]))              return          # Check the type dispatch table          t = type(obj)          f = self.dispatch.get(t) -        if f: +        if f is not None:              f(self, obj) # Call unbound method with explicit self              return          # Check private dispatch table if any, or else copyreg.dispatch_table          reduce = getattr(self, 'dispatch_table', dispatch_table).get(t) -        if reduce: +        if reduce is not None:              rv = reduce(obj)          else:              # Check for a class with a custom metaclass; treat as regular class @@ -313,11 +309,11 @@ class _Pickler:              # Check for a __reduce_ex__ method, fall back to __reduce__              reduce = getattr(obj, "__reduce_ex__", None) -            if reduce: +            if reduce is not None:                  rv = reduce(self.proto)              else:                  reduce = getattr(obj, "__reduce__", None) -                if reduce: +                if reduce is not None:                      rv = reduce()                  else:                      raise PicklingError("Can't pickle %r object: %r" % @@ -448,12 +444,12 @@ class _Pickler:      def save_bool(self, obj):          if self.proto >= 2: -            self.write(obj and NEWTRUE or NEWFALSE) +            self.write(NEWTRUE if obj else NEWFALSE)          else: -            self.write(obj and TRUE or FALSE) +            self.write(TRUE if obj else FALSE)      dispatch[bool] = save_bool -    def save_long(self, obj, pack=struct.pack): +    def save_long(self, obj):          if self.bin:              # If the int is small enough to fit in a signed 4-byte 2's-comp              # format, we can store it more efficiently than the general @@ -461,39 +457,36 @@ class _Pickler:              # First one- and two-byte unsigned ints:              if obj >= 0:                  if obj <= 0xff: -                    self.write(BININT1 + bytes([obj])) +                    self.write(BININT1 + pack("<B", obj))                      return                  if obj <= 0xffff: -                    self.write(BININT2 + bytes([obj&0xff, obj>>8])) +                    self.write(BININT2 + pack("<H", obj))                      return              # Next check for 4-byte signed ints: -            high_bits = obj >> 31  # note that Python shift sign-extends -            if high_bits == 0 or high_bits == -1: -                # All high bits are copies of bit 2**31, so the value -                # fits in a 4-byte signed int. +            if -0x80000000 <= obj <= 0x7fffffff:                  self.write(BININT + pack("<i", obj))                  return          if self.proto >= 2:              encoded = encode_long(obj)              n = len(encoded)              if n < 256: -                self.write(LONG1 + bytes([n]) + encoded) +                self.write(LONG1 + pack("<B", n) + encoded)              else:                  self.write(LONG4 + pack("<i", n) + encoded)              return          self.write(LONG + repr(obj).encode("ascii") + b'L\n')      dispatch[int] = save_long -    def save_float(self, obj, pack=struct.pack): +    def save_float(self, obj):          if self.bin:              self.write(BINFLOAT + pack('>d', obj))          else:              self.write(FLOAT + repr(obj).encode("ascii") + b'\n')      dispatch[float] = save_float -    def save_bytes(self, obj, pack=struct.pack): +    def save_bytes(self, obj):          if self.proto < 3: -            if len(obj) == 0: +            if not obj: # bytes object is empty                  self.save_reduce(bytes, (), obj=obj)              else:                  self.save_reduce(codecs.encode, @@ -501,13 +494,13 @@ class _Pickler:              return          n = len(obj)          if n < 256: -            self.write(SHORT_BINBYTES + bytes([n]) + bytes(obj)) +            self.write(SHORT_BINBYTES + pack("<B", n) + obj)          else: -            self.write(BINBYTES + pack("<I", n) + bytes(obj)) +            self.write(BINBYTES + pack("<I", n) + obj)          self.memoize(obj)      dispatch[bytes] = save_bytes -    def save_str(self, obj, pack=struct.pack): +    def save_str(self, obj):          if self.bin:              encoded = obj.encode('utf-8', 'surrogatepass')              n = len(encoded) @@ -515,39 +508,36 @@ class _Pickler:          else:              obj = obj.replace("\\", "\\u005c")              obj = obj.replace("\n", "\\u000a") -            self.write(UNICODE + bytes(obj.encode('raw-unicode-escape')) + -                       b'\n') +            self.write(UNICODE + obj.encode('raw-unicode-escape') + b'\n')          self.memoize(obj)      dispatch[str] = save_str      def save_tuple(self, obj): -        write = self.write -        proto = self.proto - -        n = len(obj) -        if n == 0: -            if proto: -                write(EMPTY_TUPLE) +        if not obj: # tuple is empty +            if self.bin: +                self.write(EMPTY_TUPLE)              else: -                write(MARK + TUPLE) +                self.write(MARK + TUPLE)              return +        n = len(obj)          save = self.save          memo = self.memo -        if n <= 3 and proto >= 2: +        if n <= 3 and self.proto >= 2:              for element in obj:                  save(element)              # Subtle.  Same as in the big comment below.              if id(obj) in memo:                  get = self.get(memo[id(obj)][0]) -                write(POP * n + get) +                self.write(POP * n + get)              else: -                write(_tuplesize2code[n]) +                self.write(_tuplesize2code[n])                  self.memoize(obj)              return          # proto 0 or proto 1 and tuple isn't empty, or proto > 1 and tuple          # has more than 3 elements. +        write = self.write          write(MARK)          for element in obj:              save(element) @@ -561,25 +551,23 @@ class _Pickler:              # could have been done in the "for element" loop instead, but              # recursive tuples are a rare thing.              get = self.get(memo[id(obj)][0]) -            if proto: +            if self.bin:                  write(POP_MARK + get)              else:   # proto 0 -- POP_MARK not available                  write(POP * (n+1) + get)              return          # No recursion. -        self.write(TUPLE) +        write(TUPLE)          self.memoize(obj)      dispatch[tuple] = save_tuple      def save_list(self, obj): -        write = self.write -          if self.bin: -            write(EMPTY_LIST) +            self.write(EMPTY_LIST)          else:   # proto 0 -- can't use EMPTY_LIST -            write(MARK + LIST) +            self.write(MARK + LIST)          self.memoize(obj)          self._batch_appends(obj) @@ -599,17 +587,9 @@ class _Pickler:                  write(APPEND)              return -        items = iter(items) -        r = range(self._BATCHSIZE) -        while items is not None: -            tmp = [] -            for i in r: -                try: -                    x = next(items) -                    tmp.append(x) -                except StopIteration: -                    items = None -                    break +        it = iter(items) +        while True: +            tmp = list(islice(it, self._BATCHSIZE))              n = len(tmp)              if n > 1:                  write(MARK) @@ -620,14 +600,14 @@ class _Pickler:                  save(tmp[0])                  write(APPEND)              # else tmp is empty, and we're done +            if n < self._BATCHSIZE: +                return      def save_dict(self, obj): -        write = self.write -          if self.bin: -            write(EMPTY_DICT) +            self.write(EMPTY_DICT)          else:   # proto 0 -- can't use EMPTY_DICT -            write(MARK + DICT) +            self.write(MARK + DICT)          self.memoize(obj)          self._batch_setitems(obj.items()) @@ -648,16 +628,9 @@ class _Pickler:                  write(SETITEM)              return -        items = iter(items) -        r = range(self._BATCHSIZE) -        while items is not None: -            tmp = [] -            for i in r: -                try: -                    tmp.append(next(items)) -                except StopIteration: -                    items = None -                    break +        it = iter(items) +        while True: +            tmp = list(islice(it, self._BATCHSIZE))              n = len(tmp)              if n > 1:                  write(MARK) @@ -671,8 +644,10 @@ class _Pickler:                  save(v)                  write(SETITEM)              # else tmp is empty, and we're done +            if n < self._BATCHSIZE: +                return -    def save_global(self, obj, name=None, pack=struct.pack): +    def save_global(self, obj, name=None):          write = self.write          memo = self.memo @@ -702,9 +677,9 @@ class _Pickler:              if code:                  assert code > 0                  if code <= 0xff: -                    write(EXT1 + bytes([code])) +                    write(EXT1 + pack("<B", code))                  elif code <= 0xffff: -                    write(EXT2 + bytes([code&0xff, code>>8])) +                    write(EXT2 + pack("<H", code))                  else:                      write(EXT4 + pack("<i", code))                  return @@ -732,25 +707,6 @@ class _Pickler:      dispatch[BuiltinFunctionType] = save_global      dispatch[type] = save_global -# Pickling helpers - -def _keep_alive(x, memo): -    """Keeps a reference to the object x in the memo. - -    Because we remember objects by their id, we have -    to assure that possibly temporary objects are kept -    alive by referencing them. -    We store a reference at the id of the memo, which should -    normally not be used unless someone tries to deepcopy -    the memo itself... -    """ -    try: -        memo[id(memo)].append(x) -    except KeyError: -        # aha, this is the first one :-) -        memo[id(memo)]=[x] - -  # A cache for whichmodule(), mapping a function object to the name of  # the module in which the function was found. @@ -832,7 +788,7 @@ class _Unpickler:          read = self.read          dispatch = self.dispatch          try: -            while 1: +            while True:                  key = read(1)                  if not key:                      raise EOFError @@ -862,7 +818,7 @@ class _Unpickler:      dispatch = {}      def load_proto(self): -        proto = ord(self.read(1)) +        proto = self.read(1)[0]          if not 0 <= proto <= HIGHEST_PROTOCOL:              raise ValueError("unsupported pickle protocol: %d" % proto)          self.proto = proto @@ -897,43 +853,40 @@ class _Unpickler:          elif data == TRUE[1:]:              val = True          else: -            try: -                val = int(data, 0) -            except ValueError: -                val = int(data, 0) +            val = int(data, 0)          self.append(val)      dispatch[INT[0]] = load_int      def load_binint(self): -        self.append(mloads(b'i' + self.read(4))) +        self.append(unpack('<i', self.read(4))[0])      dispatch[BININT[0]] = load_binint      def load_binint1(self): -        self.append(ord(self.read(1))) +        self.append(self.read(1)[0])      dispatch[BININT1[0]] = load_binint1      def load_binint2(self): -        self.append(mloads(b'i' + self.read(2) + b'\000\000')) +        self.append(unpack('<H', self.read(2))[0])      dispatch[BININT2[0]] = load_binint2      def load_long(self): -        val = self.readline()[:-1].decode("ascii") -        if val and val[-1] == 'L': +        val = self.readline()[:-1] +        if val and val[-1] == b'L'[0]:              val = val[:-1]          self.append(int(val, 0))      dispatch[LONG[0]] = load_long      def load_long1(self): -        n = ord(self.read(1)) +        n = self.read(1)[0]          data = self.read(n)          self.append(decode_long(data))      dispatch[LONG1[0]] = load_long1      def load_long4(self): -        n = mloads(b'i' + self.read(4)) +        n, = unpack('<i', self.read(4))          if n < 0:              # Corrupt or hostile pickle -- we never write one like this -            raise UnpicklingError("LONG pickle has negative byte count"); +            raise UnpicklingError("LONG pickle has negative byte count")          data = self.read(n)          self.append(decode_long(data))      dispatch[LONG4[0]] = load_long4 @@ -942,39 +895,36 @@ class _Unpickler:          self.append(float(self.readline()[:-1]))      dispatch[FLOAT[0]] = load_float -    def load_binfloat(self, unpack=struct.unpack): +    def load_binfloat(self):          self.append(unpack('>d', self.read(8))[0])      dispatch[BINFLOAT[0]] = load_binfloat      def load_string(self): -        orig = self.readline() -        rep = orig[:-1] -        for q in (b'"', b"'"): # double or single quote -            if rep.startswith(q): -                if len(rep) < 2 or not rep.endswith(q): -                    raise ValueError("insecure string pickle") -                rep = rep[len(q):-len(q)] -                break +        data = self.readline()[:-1] +        # Strip outermost quotes +        if len(data) >= 2 and data[0] == data[-1] and data[0] in b'"\'': +            data = data[1:-1]          else: -            raise ValueError("insecure string pickle: %r" % orig) -        self.append(codecs.escape_decode(rep)[0] +            raise UnpicklingError("the STRING opcode argument must be quoted") +        self.append(codecs.escape_decode(data)[0]                      .decode(self.encoding, self.errors))      dispatch[STRING[0]] = load_string      def load_binstring(self):          # Deprecated BINSTRING uses signed 32-bit length -        len = mloads(b'i' + self.read(4)) +        len, = unpack('<i', self.read(4))          if len < 0: -            raise UnpicklingError("BINSTRING pickle has negative byte count"); +            raise UnpicklingError("BINSTRING pickle has negative byte count")          data = self.read(len)          value = str(data, self.encoding, self.errors)          self.append(value)      dispatch[BINSTRING[0]] = load_binstring -    def load_binbytes(self, unpack=struct.unpack, maxsize=sys.maxsize): +    def load_binbytes(self):          len, = unpack('<I', self.read(4))          if len > maxsize: -            raise UnpicklingError("BINBYTES exceeds system's maximum size of %d bytes" % maxsize); +            raise UnpicklingError("BINBYTES exceeds system's maximum size " +                                  "of %d bytes" % maxsize)          self.append(self.read(len))      dispatch[BINBYTES[0]] = load_binbytes @@ -982,23 +932,24 @@ class _Unpickler:          self.append(str(self.readline()[:-1], 'raw-unicode-escape'))      dispatch[UNICODE[0]] = load_unicode -    def load_binunicode(self, unpack=struct.unpack, maxsize=sys.maxsize): +    def load_binunicode(self):          len, = unpack('<I', self.read(4))          if len > maxsize: -            raise UnpicklingError("BINUNICODE exceeds system's maximum size of %d bytes" % maxsize); +            raise UnpicklingError("BINUNICODE exceeds system's maximum size " +                                  "of %d bytes" % maxsize)          self.append(str(self.read(len), 'utf-8', 'surrogatepass'))      dispatch[BINUNICODE[0]] = load_binunicode      def load_short_binstring(self): -        len = ord(self.read(1)) -        data = bytes(self.read(len)) +        len = self.read(1)[0] +        data = self.read(len)          value = str(data, self.encoding, self.errors)          self.append(value)      dispatch[SHORT_BINSTRING[0]] = load_short_binstring      def load_short_binbytes(self): -        len = ord(self.read(1)) -        self.append(bytes(self.read(len))) +        len = self.read(1)[0] +        self.append(self.read(len))      dispatch[SHORT_BINBYTES[0]] = load_short_binbytes      def load_tuple(self): @@ -1037,12 +988,9 @@ class _Unpickler:      def load_dict(self):          k = self.marker() -        d = {}          items = self.stack[k+1:] -        for i in range(0, len(items), 2): -            key = items[i] -            value = items[i+1] -            d[key] = value +        d = {items[i]: items[i+1] +             for i in range(0, len(items), 2)}          self.stack[k:] = [d]      dispatch[DICT[0]] = load_dict @@ -1094,17 +1042,17 @@ class _Unpickler:      dispatch[GLOBAL[0]] = load_global      def load_ext1(self): -        code = ord(self.read(1)) +        code = self.read(1)[0]          self.get_extension(code)      dispatch[EXT1[0]] = load_ext1      def load_ext2(self): -        code = mloads(b'i' + self.read(2) + b'\000\000') +        code, = unpack('<H', self.read(2))          self.get_extension(code)      dispatch[EXT2[0]] = load_ext2      def load_ext4(self): -        code = mloads(b'i' + self.read(4)) +        code, = unpack('<i', self.read(4))          self.get_extension(code)      dispatch[EXT4[0]] = load_ext4 @@ -1118,7 +1066,7 @@ class _Unpickler:          if not key:              if code <= 0: # note that 0 is forbidden                  # Corrupt or hostile pickle. -                raise UnpicklingError("EXT specifies code <= 0"); +                raise UnpicklingError("EXT specifies code <= 0")              raise ValueError("unregistered extension code %d" % code)          obj = self.find_class(*key)          _extension_cache[code] = obj @@ -1172,7 +1120,7 @@ class _Unpickler:          self.append(self.memo[i])      dispatch[BINGET[0]] = load_binget -    def load_long_binget(self, unpack=struct.unpack): +    def load_long_binget(self):          i, = unpack('<I', self.read(4))          self.append(self.memo[i])      dispatch[LONG_BINGET[0]] = load_long_binget @@ -1191,7 +1139,7 @@ class _Unpickler:          self.memo[i] = self.stack[-1]      dispatch[BINPUT[0]] = load_binput -    def load_long_binput(self, unpack=struct.unpack, maxsize=sys.maxsize): +    def load_long_binput(self):          i, = unpack('<I', self.read(4))          if i > maxsize:              raise ValueError("negative LONG_BINPUT argument") @@ -1242,7 +1190,7 @@ class _Unpickler:          state = stack.pop()          inst = stack[-1]          setstate = getattr(inst, "__setstate__", None) -        if setstate: +        if setstate is not None:              setstate(state)              return          slotstate = None @@ -1348,7 +1296,7 @@ def loads(s, *, fix_imports=True, encoding="ASCII", errors="strict"):  # Use the faster _pickle if possible  try:      from _pickle import * -except ImportError: +except ModuleNotFoundError:      Pickler, Unpickler = _Pickler, _Unpickler  # Doctest  | 
