diff options
author | Pauli Virtanen <pav@iki.fi> | 2015-03-07 00:48:12 +0200 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2015-03-08 19:44:04 +0200 |
commit | 8016a68ab98969630e3c5769f38065c24a802fdc (patch) | |
tree | 4fd5abcee47443d487d3c1cb94fa94067a5f09b7 /numpy/lib/npyio.py | |
parent | bc034dcda527372080ced5b629dc317047ef9336 (diff) | |
download | numpy-8016a68ab98969630e3c5769f38065c24a802fdc.tar.gz |
BUG: enable working around pickle compatibility issues on Py3 in npy files
Add pickle compatibility flags to numpy.save and numpy.load. Allow only
combinations that cannot corrupt binary data in Numpy arrays. Use the
same default values as Python pickle.
Diffstat (limited to 'numpy/lib/npyio.py')
-rw-r--r-- | numpy/lib/npyio.py | 68 |
1 files changed, 58 insertions, 10 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 2b01caed9..ba35402d6 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -164,6 +164,10 @@ class NpzFile(object): f : BagObj instance An object on which attribute can be performed as an alternative to getitem access on the `NpzFile` instance itself. + pickle_kwargs : dict, optional + Additional keyword arguments to pass on to pickle.load. + These are only useful when loading object arrays saved on + Python 2 when using Python 3. Parameters ---------- @@ -195,12 +199,13 @@ class NpzFile(object): """ - def __init__(self, fid, own_fid=False): + def __init__(self, fid, own_fid=False, pickle_kwargs=None): # Import is postponed to here since zipfile depends on gzip, an # optional component of the so-called standard library. _zip = zipfile_factory(fid) self._files = _zip.namelist() self.files = [] + self.pickle_kwargs = pickle_kwargs for x in self._files: if x.endswith('.npy'): self.files.append(x[:-4]) @@ -256,7 +261,7 @@ class NpzFile(object): bytes.close() if magic == format.MAGIC_PREFIX: bytes = self.zip.open(key) - return format.read_array(bytes) + return format.read_array(bytes, pickle_kwargs=self.pickle_kwargs) else: return self.zip.read(key) else: @@ -289,7 +294,7 @@ class NpzFile(object): return self.files.__contains__(key) -def load(file, mmap_mode=None): +def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'): """ Load arrays or pickled objects from ``.npy``, ``.npz`` or pickled files. @@ -306,6 +311,18 @@ def load(file, mmap_mode=None): and sliced like any ndarray. Memory mapping is especially useful for accessing small fragments of large files without reading the entire file into memory. + fix_imports : bool, optional + Only useful when loading Python 2 generated pickled files on Python 3, + which includes npy/npz files containing object arrays. If `fix_imports` + is True, pickle will try to map the old Python 2 names to the new names + used in Python 3. + encoding : str, optional + What encoding to use when reading Python 2 strings. Only useful when + loading Python 2 generated pickled files on Python 3, which includes + npy/npz files containing object arrays. Values other than 'latin1', + 'ASCII', and 'bytes' are not allowed, as they can corrupt numerical + data. Default: 'ASCII' + Returns ------- @@ -381,6 +398,26 @@ def load(file, mmap_mode=None): else: fid = file + if encoding not in ('ASCII', 'latin1', 'bytes'): + # The 'encoding' value for pickle also affects what encoding + # the serialized binary data of Numpy arrays is loaded + # in. Pickle does not pass on the encoding information to + # Numpy. The unpickling code in numpy.core.multiarray is + # written to assume that unicode data appearing where binary + # should be is in 'latin1'. 'bytes' is also safe, as is 'ASCII'. + # + # Other encoding values can corrupt binary data, and we + # purposefully disallow them. For the same reason, the errors= + # argument is not exposed, as values other than 'strict' + # result can similarly silently corrupt numerical data. + raise ValueError("encoding must be 'ASCII', 'latin1', or 'bytes'") + + if sys.version_info[0] >= 3: + pickle_kwargs = dict(encoding=encoding, fix_imports=fix_imports) + else: + # Nothing to do on Python 2 + pickle_kwargs = {} + try: # Code to distinguish from NumPy binary files and pickles. _ZIP_PREFIX = asbytes('PK\x03\x04') @@ -392,17 +429,17 @@ def load(file, mmap_mode=None): # Transfer file ownership to NpzFile tmp = own_fid own_fid = False - return NpzFile(fid, own_fid=tmp) + return NpzFile(fid, own_fid=tmp, pickle_kwargs=pickle_kwargs) elif magic == format.MAGIC_PREFIX: # .npy file if mmap_mode: return format.open_memmap(file, mode=mmap_mode) else: - return format.read_array(fid) + return format.read_array(fid, pickle_kwargs=pickle_kwargs) else: # Try a pickle try: - return pickle.load(fid) + return pickle.load(fid, **pickle_kwargs) except: raise IOError( "Failed to interpret file %s as a pickle" % repr(file)) @@ -411,7 +448,7 @@ def load(file, mmap_mode=None): fid.close() -def save(file, arr): +def save(file, arr, fix_imports=True): """ Save an array to a binary file in NumPy ``.npy`` format. @@ -422,6 +459,11 @@ def save(file, arr): then the filename is unchanged. If file is a string, a ``.npy`` extension will be appended to the file name if it does not already have one. + fix_imports : bool, optional + Only useful in forcing objects in object arrays on Python 3 to be pickled + in a Python 2 compatible way. If `fix_imports` is True, pickle will try to + map the new Python 3 names to the old module names used in Python 2, so that + the pickle data stream is readable with Python 2. arr : array_like Array data to be saved. @@ -458,9 +500,15 @@ def save(file, arr): else: fid = file + if sys.version_info[0] >= 3: + pickle_kwargs = dict(fix_imports=fix_imports) + else: + # Nothing to do on Python 2 + pickle_kwargs = None + try: arr = np.asanyarray(arr) - format.write_array(fid, arr) + format.write_array(fid, arr, pickle_kwargs=pickle_kwargs) finally: if own_fid: fid.close() @@ -572,7 +620,7 @@ def savez_compressed(file, *args, **kwds): _savez(file, args, kwds, True) -def _savez(file, args, kwds, compress): +def _savez(file, args, kwds, compress, pickle_kwargs=None): # Import is postponed to here since zipfile depends on gzip, an optional # component of the so-called standard library. import zipfile @@ -606,7 +654,7 @@ def _savez(file, args, kwds, compress): fname = key + '.npy' fid = open(tmpfile, 'wb') try: - format.write_array(fid, np.asanyarray(val)) + format.write_array(fid, np.asanyarray(val), pickle_kwargs=pickle_kwargs) fid.close() fid = None zipf.write(tmpfile, arcname=fname) |