diff options
Diffstat (limited to 'numpy/lib/npyio.py')
-rw-r--r-- | numpy/lib/npyio.py | 68 |
1 files changed, 58 insertions, 10 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 2b01caed9..ba35402d6 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -164,6 +164,10 @@ class NpzFile(object): f : BagObj instance An object on which attribute can be performed as an alternative to getitem access on the `NpzFile` instance itself. + pickle_kwargs : dict, optional + Additional keyword arguments to pass on to pickle.load. + These are only useful when loading object arrays saved on + Python 2 when using Python 3. Parameters ---------- @@ -195,12 +199,13 @@ class NpzFile(object): """ - def __init__(self, fid, own_fid=False): + def __init__(self, fid, own_fid=False, pickle_kwargs=None): # Import is postponed to here since zipfile depends on gzip, an # optional component of the so-called standard library. _zip = zipfile_factory(fid) self._files = _zip.namelist() self.files = [] + self.pickle_kwargs = pickle_kwargs for x in self._files: if x.endswith('.npy'): self.files.append(x[:-4]) @@ -256,7 +261,7 @@ class NpzFile(object): bytes.close() if magic == format.MAGIC_PREFIX: bytes = self.zip.open(key) - return format.read_array(bytes) + return format.read_array(bytes, pickle_kwargs=self.pickle_kwargs) else: return self.zip.read(key) else: @@ -289,7 +294,7 @@ class NpzFile(object): return self.files.__contains__(key) -def load(file, mmap_mode=None): +def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'): """ Load arrays or pickled objects from ``.npy``, ``.npz`` or pickled files. @@ -306,6 +311,18 @@ def load(file, mmap_mode=None): and sliced like any ndarray. Memory mapping is especially useful for accessing small fragments of large files without reading the entire file into memory. + fix_imports : bool, optional + Only useful when loading Python 2 generated pickled files on Python 3, + which includes npy/npz files containing object arrays. If `fix_imports` + is True, pickle will try to map the old Python 2 names to the new names + used in Python 3. + encoding : str, optional + What encoding to use when reading Python 2 strings. Only useful when + loading Python 2 generated pickled files on Python 3, which includes + npy/npz files containing object arrays. Values other than 'latin1', + 'ASCII', and 'bytes' are not allowed, as they can corrupt numerical + data. Default: 'ASCII' + Returns ------- @@ -381,6 +398,26 @@ def load(file, mmap_mode=None): else: fid = file + if encoding not in ('ASCII', 'latin1', 'bytes'): + # The 'encoding' value for pickle also affects what encoding + # the serialized binary data of Numpy arrays is loaded + # in. Pickle does not pass on the encoding information to + # Numpy. The unpickling code in numpy.core.multiarray is + # written to assume that unicode data appearing where binary + # should be is in 'latin1'. 'bytes' is also safe, as is 'ASCII'. + # + # Other encoding values can corrupt binary data, and we + # purposefully disallow them. For the same reason, the errors= + # argument is not exposed, as values other than 'strict' + # result can similarly silently corrupt numerical data. + raise ValueError("encoding must be 'ASCII', 'latin1', or 'bytes'") + + if sys.version_info[0] >= 3: + pickle_kwargs = dict(encoding=encoding, fix_imports=fix_imports) + else: + # Nothing to do on Python 2 + pickle_kwargs = {} + try: # Code to distinguish from NumPy binary files and pickles. _ZIP_PREFIX = asbytes('PK\x03\x04') @@ -392,17 +429,17 @@ def load(file, mmap_mode=None): # Transfer file ownership to NpzFile tmp = own_fid own_fid = False - return NpzFile(fid, own_fid=tmp) + return NpzFile(fid, own_fid=tmp, pickle_kwargs=pickle_kwargs) elif magic == format.MAGIC_PREFIX: # .npy file if mmap_mode: return format.open_memmap(file, mode=mmap_mode) else: - return format.read_array(fid) + return format.read_array(fid, pickle_kwargs=pickle_kwargs) else: # Try a pickle try: - return pickle.load(fid) + return pickle.load(fid, **pickle_kwargs) except: raise IOError( "Failed to interpret file %s as a pickle" % repr(file)) @@ -411,7 +448,7 @@ def load(file, mmap_mode=None): fid.close() -def save(file, arr): +def save(file, arr, fix_imports=True): """ Save an array to a binary file in NumPy ``.npy`` format. @@ -422,6 +459,11 @@ def save(file, arr): then the filename is unchanged. If file is a string, a ``.npy`` extension will be appended to the file name if it does not already have one. + fix_imports : bool, optional + Only useful in forcing objects in object arrays on Python 3 to be pickled + in a Python 2 compatible way. If `fix_imports` is True, pickle will try to + map the new Python 3 names to the old module names used in Python 2, so that + the pickle data stream is readable with Python 2. arr : array_like Array data to be saved. @@ -458,9 +500,15 @@ def save(file, arr): else: fid = file + if sys.version_info[0] >= 3: + pickle_kwargs = dict(fix_imports=fix_imports) + else: + # Nothing to do on Python 2 + pickle_kwargs = None + try: arr = np.asanyarray(arr) - format.write_array(fid, arr) + format.write_array(fid, arr, pickle_kwargs=pickle_kwargs) finally: if own_fid: fid.close() @@ -572,7 +620,7 @@ def savez_compressed(file, *args, **kwds): _savez(file, args, kwds, True) -def _savez(file, args, kwds, compress): +def _savez(file, args, kwds, compress, pickle_kwargs=None): # Import is postponed to here since zipfile depends on gzip, an optional # component of the so-called standard library. import zipfile @@ -606,7 +654,7 @@ def _savez(file, args, kwds, compress): fname = key + '.npy' fid = open(tmpfile, 'wb') try: - format.write_array(fid, np.asanyarray(val)) + format.write_array(fid, np.asanyarray(val), pickle_kwargs=pickle_kwargs) fid.close() fid = None zipf.write(tmpfile, arcname=fname) |