summaryrefslogtreecommitdiff
path: root/numpy/lib/npyio.py
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2015-03-07 00:48:12 +0200
committerPauli Virtanen <pav@iki.fi>2015-03-08 19:44:04 +0200
commit8016a68ab98969630e3c5769f38065c24a802fdc (patch)
tree4fd5abcee47443d487d3c1cb94fa94067a5f09b7 /numpy/lib/npyio.py
parentbc034dcda527372080ced5b629dc317047ef9336 (diff)
downloadnumpy-8016a68ab98969630e3c5769f38065c24a802fdc.tar.gz
BUG: enable working around pickle compatibility issues on Py3 in npy files
Add pickle compatibility flags to numpy.save and numpy.load. Allow only combinations that cannot corrupt binary data in Numpy arrays. Use the same default values as Python pickle.
Diffstat (limited to 'numpy/lib/npyio.py')
-rw-r--r--numpy/lib/npyio.py68
1 files changed, 58 insertions, 10 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index 2b01caed9..ba35402d6 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -164,6 +164,10 @@ class NpzFile(object):
f : BagObj instance
An object on which attribute can be performed as an alternative
to getitem access on the `NpzFile` instance itself.
+ pickle_kwargs : dict, optional
+ Additional keyword arguments to pass on to pickle.load.
+ These are only useful when loading object arrays saved on
+ Python 2 when using Python 3.
Parameters
----------
@@ -195,12 +199,13 @@ class NpzFile(object):
"""
- def __init__(self, fid, own_fid=False):
+ def __init__(self, fid, own_fid=False, pickle_kwargs=None):
# Import is postponed to here since zipfile depends on gzip, an
# optional component of the so-called standard library.
_zip = zipfile_factory(fid)
self._files = _zip.namelist()
self.files = []
+ self.pickle_kwargs = pickle_kwargs
for x in self._files:
if x.endswith('.npy'):
self.files.append(x[:-4])
@@ -256,7 +261,7 @@ class NpzFile(object):
bytes.close()
if magic == format.MAGIC_PREFIX:
bytes = self.zip.open(key)
- return format.read_array(bytes)
+ return format.read_array(bytes, pickle_kwargs=self.pickle_kwargs)
else:
return self.zip.read(key)
else:
@@ -289,7 +294,7 @@ class NpzFile(object):
return self.files.__contains__(key)
-def load(file, mmap_mode=None):
+def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
"""
Load arrays or pickled objects from ``.npy``, ``.npz`` or pickled files.
@@ -306,6 +311,18 @@ def load(file, mmap_mode=None):
and sliced like any ndarray. Memory mapping is especially useful
for accessing small fragments of large files without reading the
entire file into memory.
+ fix_imports : bool, optional
+ Only useful when loading Python 2 generated pickled files on Python 3,
+ which includes npy/npz files containing object arrays. If `fix_imports`
+ is True, pickle will try to map the old Python 2 names to the new names
+ used in Python 3.
+ encoding : str, optional
+ What encoding to use when reading Python 2 strings. Only useful when
+ loading Python 2 generated pickled files on Python 3, which includes
+ npy/npz files containing object arrays. Values other than 'latin1',
+ 'ASCII', and 'bytes' are not allowed, as they can corrupt numerical
+ data. Default: 'ASCII'
+
Returns
-------
@@ -381,6 +398,26 @@ def load(file, mmap_mode=None):
else:
fid = file
+ if encoding not in ('ASCII', 'latin1', 'bytes'):
+ # The 'encoding' value for pickle also affects what encoding
+ # the serialized binary data of Numpy arrays is loaded
+ # in. Pickle does not pass on the encoding information to
+ # Numpy. The unpickling code in numpy.core.multiarray is
+ # written to assume that unicode data appearing where binary
+ # should be is in 'latin1'. 'bytes' is also safe, as is 'ASCII'.
+ #
+ # Other encoding values can corrupt binary data, and we
+ # purposefully disallow them. For the same reason, the errors=
+ # argument is not exposed, as values other than 'strict'
+ # result can similarly silently corrupt numerical data.
+ raise ValueError("encoding must be 'ASCII', 'latin1', or 'bytes'")
+
+ if sys.version_info[0] >= 3:
+ pickle_kwargs = dict(encoding=encoding, fix_imports=fix_imports)
+ else:
+ # Nothing to do on Python 2
+ pickle_kwargs = {}
+
try:
# Code to distinguish from NumPy binary files and pickles.
_ZIP_PREFIX = asbytes('PK\x03\x04')
@@ -392,17 +429,17 @@ def load(file, mmap_mode=None):
# Transfer file ownership to NpzFile
tmp = own_fid
own_fid = False
- return NpzFile(fid, own_fid=tmp)
+ return NpzFile(fid, own_fid=tmp, pickle_kwargs=pickle_kwargs)
elif magic == format.MAGIC_PREFIX:
# .npy file
if mmap_mode:
return format.open_memmap(file, mode=mmap_mode)
else:
- return format.read_array(fid)
+ return format.read_array(fid, pickle_kwargs=pickle_kwargs)
else:
# Try a pickle
try:
- return pickle.load(fid)
+ return pickle.load(fid, **pickle_kwargs)
except:
raise IOError(
"Failed to interpret file %s as a pickle" % repr(file))
@@ -411,7 +448,7 @@ def load(file, mmap_mode=None):
fid.close()
-def save(file, arr):
+def save(file, arr, fix_imports=True):
"""
Save an array to a binary file in NumPy ``.npy`` format.
@@ -422,6 +459,11 @@ def save(file, arr):
then the filename is unchanged. If file is a string, a ``.npy``
extension will be appended to the file name if it does not already
have one.
+ fix_imports : bool, optional
+ Only useful in forcing objects in object arrays on Python 3 to be pickled
+ in a Python 2 compatible way. If `fix_imports` is True, pickle will try to
+ map the new Python 3 names to the old module names used in Python 2, so that
+ the pickle data stream is readable with Python 2.
arr : array_like
Array data to be saved.
@@ -458,9 +500,15 @@ def save(file, arr):
else:
fid = file
+ if sys.version_info[0] >= 3:
+ pickle_kwargs = dict(fix_imports=fix_imports)
+ else:
+ # Nothing to do on Python 2
+ pickle_kwargs = None
+
try:
arr = np.asanyarray(arr)
- format.write_array(fid, arr)
+ format.write_array(fid, arr, pickle_kwargs=pickle_kwargs)
finally:
if own_fid:
fid.close()
@@ -572,7 +620,7 @@ def savez_compressed(file, *args, **kwds):
_savez(file, args, kwds, True)
-def _savez(file, args, kwds, compress):
+def _savez(file, args, kwds, compress, pickle_kwargs=None):
# Import is postponed to here since zipfile depends on gzip, an optional
# component of the so-called standard library.
import zipfile
@@ -606,7 +654,7 @@ def _savez(file, args, kwds, compress):
fname = key + '.npy'
fid = open(tmpfile, 'wb')
try:
- format.write_array(fid, np.asanyarray(val))
+ format.write_array(fid, np.asanyarray(val), pickle_kwargs=pickle_kwargs)
fid.close()
fid = None
zipf.write(tmpfile, arcname=fname)