summaryrefslogtreecommitdiff
path: root/Lib/pickle.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/pickle.py')
-rw-r--r--Lib/pickle.py47
1 files changed, 29 insertions, 18 deletions
diff --git a/Lib/pickle.py b/Lib/pickle.py
index 3b139844c4..040ecb245f 100644
--- a/Lib/pickle.py
+++ b/Lib/pickle.py
@@ -258,24 +258,20 @@ class _Unframer:
# Tools used for pickling.
-def _getattribute(obj, name, allow_qualname=False):
- dotted_path = name.split(".")
- if not allow_qualname and len(dotted_path) > 1:
- raise AttributeError("Can't get qualified attribute {!r} on {!r}; " +
- "use protocols >= 4 to enable support"
- .format(name, obj))
- for subpath in dotted_path:
+def _getattribute(obj, name):
+ for subpath in name.split('.'):
if subpath == '<locals>':
raise AttributeError("Can't get local attribute {!r} on {!r}"
.format(name, obj))
try:
+ parent = obj
obj = getattr(obj, subpath)
except AttributeError:
raise AttributeError("Can't get attribute {!r} on {!r}"
.format(name, obj))
- return obj
+ return obj, parent
-def whichmodule(obj, name, allow_qualname=False):
+def whichmodule(obj, name):
"""Find the module an object belong to."""
module_name = getattr(obj, '__module__', None)
if module_name is not None:
@@ -286,7 +282,7 @@ def whichmodule(obj, name, allow_qualname=False):
if module_name == '__main__' or module is None:
continue
try:
- if _getattribute(module, name, allow_qualname) is obj:
+ if _getattribute(module, name)[0] is obj:
return module_name
except AttributeError:
pass
@@ -533,7 +529,11 @@ class _Pickler:
self.save(pid, save_persistent_id=False)
self.write(BINPERSID)
else:
- self.write(PERSID + str(pid).encode("ascii") + b'\n')
+ try:
+ self.write(PERSID + str(pid).encode("ascii") + b'\n')
+ except UnicodeEncodeError:
+ raise PicklingError(
+ "persistent IDs in protocol 0 must be ASCII strings")
def save_reduce(self, func, args, state=None, listitems=None,
dictitems=None, obj=None):
@@ -899,16 +899,16 @@ class _Pickler:
write = self.write
memo = self.memo
- if name is None and self.proto >= 4:
+ if name is None:
name = getattr(obj, '__qualname__', None)
if name is None:
name = obj.__name__
- module_name = whichmodule(obj, name, allow_qualname=self.proto >= 4)
+ module_name = whichmodule(obj, name)
try:
__import__(module_name, level=0)
module = sys.modules[module_name]
- obj2 = _getattribute(module, name, allow_qualname=self.proto >= 4)
+ obj2, parent = _getattribute(module, name)
except (ImportError, KeyError, AttributeError):
raise PicklingError(
"Can't pickle %r: it's not found as %s.%s" %
@@ -930,11 +930,16 @@ class _Pickler:
else:
write(EXT4 + pack("<i", code))
return
+ lastname = name.rpartition('.')[2]
+ if parent is module:
+ name = lastname
# Non-ASCII identifiers are supported only with protocols >= 3.
if self.proto >= 4:
self.save(module_name)
self.save(name)
write(STACK_GLOBAL)
+ elif parent is not module:
+ self.save_reduce(getattr, (parent, lastname))
elif self.proto >= 3:
write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n')
@@ -994,7 +999,7 @@ class _Unpickler:
meets this interface.
Optional keyword arguments are *fix_imports*, *encoding* and
- *errors*, which are used to control compatiblity support for
+ *errors*, which are used to control compatibility support for
pickle stream generated by Python 2. If *fix_imports* is True,
pickle will try to map the old Python 2 names to the new names
used in Python 3. The *encoding* and *errors* tell pickle how
@@ -1074,7 +1079,11 @@ class _Unpickler:
dispatch[FRAME[0]] = load_frame
def load_persid(self):
- pid = self.readline()[:-1].decode("ascii")
+ try:
+ pid = self.readline()[:-1].decode("ascii")
+ except UnicodeDecodeError:
+ raise UnpicklingError(
+ "persistent IDs in protocol 0 must be ASCII strings")
self.append(self.persistent_load(pid))
dispatch[PERSID[0]] = load_persid
@@ -1381,8 +1390,10 @@ class _Unpickler:
elif module in _compat_pickle.IMPORT_MAPPING:
module = _compat_pickle.IMPORT_MAPPING[module]
__import__(module, level=0)
- return _getattribute(sys.modules[module], name,
- allow_qualname=self.proto >= 4)
+ if self.proto >= 4:
+ return _getattribute(sys.modules[module], name)[0]
+ else:
+ return getattr(sys.modules[module], name)
def load_reduce(self):
stack = self.stack