diff options
author | Pauli Virtanen <pav@iki.fi> | 2015-03-07 00:48:12 +0200 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2015-03-08 19:44:04 +0200 |
commit | 8016a68ab98969630e3c5769f38065c24a802fdc (patch) | |
tree | 4fd5abcee47443d487d3c1cb94fa94067a5f09b7 /numpy/lib | |
parent | bc034dcda527372080ced5b629dc317047ef9336 (diff) | |
download | numpy-8016a68ab98969630e3c5769f38065c24a802fdc.tar.gz |
BUG: enable working around pickle compatibility issues on Py3 in npy files
Add pickle compatibility flags to numpy.save and numpy.load. Allow only
combinations that cannot corrupt binary data in Numpy arrays. Use the
same default values as Python pickle.
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/format.py | 26 | ||||
-rw-r--r-- | numpy/lib/npyio.py | 68 | ||||
-rw-r--r-- | numpy/lib/tests/data/py2-objarr.npy | bin | 0 -> 258 bytes | |||
-rw-r--r-- | numpy/lib/tests/data/py2-objarr.npz | bin | 0 -> 366 bytes | |||
-rw-r--r-- | numpy/lib/tests/data/py3-objarr.npy | bin | 0 -> 341 bytes | |||
-rw-r--r-- | numpy/lib/tests/data/py3-objarr.npz | bin | 0 -> 449 bytes | |||
-rw-r--r-- | numpy/lib/tests/test_format.py | 67 |
7 files changed, 146 insertions, 15 deletions
diff --git a/numpy/lib/format.py b/numpy/lib/format.py index 4ff0a660f..1ff04b68a 100644 --- a/numpy/lib/format.py +++ b/numpy/lib/format.py @@ -517,7 +517,7 @@ def _read_array_header(fp, version): return d['shape'], d['fortran_order'], dtype -def write_array(fp, array, version=None): +def write_array(fp, array, version=None, pickle_kwargs=None): """ Write an array to an NPY file, including a header. @@ -535,6 +535,10 @@ def write_array(fp, array, version=None): version : (int, int) or None, optional The version number of the format. None means use the oldest supported version that is able to store the data. Default: None + pickle_kwargs : dict, optional + Additional keyword arguments to pass to pickle.dump, excluding + 'protocol'. These are only useful when pickling objects in object + arrays on Python 3 to Python 2 compatible format. Raises ------ @@ -561,7 +565,9 @@ def write_array(fp, array, version=None): # We contain Python objects so we cannot write out the data # directly. Instead, we will pickle it out with version 2 of the # pickle protocol. - pickle.dump(array, fp, protocol=2) + if pickle_kwargs is None: + pickle_kwargs = {} + pickle.dump(array, fp, protocol=2, **pickle_kwargs) elif array.flags.f_contiguous and not array.flags.c_contiguous: if isfileobj(fp): array.T.tofile(fp) @@ -580,7 +586,7 @@ def write_array(fp, array, version=None): fp.write(chunk.tobytes('C')) -def read_array(fp): +def read_array(fp, pickle_kwargs=None): """ Read an array from an NPY file. @@ -589,6 +595,9 @@ def read_array(fp): fp : file_like object If this is not a real file object, then this may take extra memory and time. + pickle_kwargs : dict + Additional keyword arguments to pass to pickle.load. These are only + useful when loading object arrays saved on Python 2 when using Python 3. Returns ------- @@ -612,7 +621,16 @@ def read_array(fp): # Now read the actual data. if dtype.hasobject: # The array contained Python objects. We need to unpickle the data. - array = pickle.load(fp) + if pickle_kwargs is None: + pickle_kwargs = {} + try: + array = pickle.load(fp, **pickle_kwargs) + except UnicodeError as err: + if sys.version_info[0] >= 3: + # Friendlier error message + raise UnicodeError("Unpickling a python object failed: %r\n" + "You may need to pass the encoding= option to numpy.load" % (err,)) + raise else: if isfileobj(fp): # We can use the fast fromfile() function. diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 2b01caed9..ba35402d6 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -164,6 +164,10 @@ class NpzFile(object): f : BagObj instance An object on which attribute can be performed as an alternative to getitem access on the `NpzFile` instance itself. + pickle_kwargs : dict, optional + Additional keyword arguments to pass on to pickle.load. + These are only useful when loading object arrays saved on + Python 2 when using Python 3. Parameters ---------- @@ -195,12 +199,13 @@ class NpzFile(object): """ - def __init__(self, fid, own_fid=False): + def __init__(self, fid, own_fid=False, pickle_kwargs=None): # Import is postponed to here since zipfile depends on gzip, an # optional component of the so-called standard library. _zip = zipfile_factory(fid) self._files = _zip.namelist() self.files = [] + self.pickle_kwargs = pickle_kwargs for x in self._files: if x.endswith('.npy'): self.files.append(x[:-4]) @@ -256,7 +261,7 @@ class NpzFile(object): bytes.close() if magic == format.MAGIC_PREFIX: bytes = self.zip.open(key) - return format.read_array(bytes) + return format.read_array(bytes, pickle_kwargs=self.pickle_kwargs) else: return self.zip.read(key) else: @@ -289,7 +294,7 @@ class NpzFile(object): return self.files.__contains__(key) -def load(file, mmap_mode=None): +def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'): """ Load arrays or pickled objects from ``.npy``, ``.npz`` or pickled files. @@ -306,6 +311,18 @@ def load(file, mmap_mode=None): and sliced like any ndarray. Memory mapping is especially useful for accessing small fragments of large files without reading the entire file into memory. + fix_imports : bool, optional + Only useful when loading Python 2 generated pickled files on Python 3, + which includes npy/npz files containing object arrays. If `fix_imports` + is True, pickle will try to map the old Python 2 names to the new names + used in Python 3. + encoding : str, optional + What encoding to use when reading Python 2 strings. Only useful when + loading Python 2 generated pickled files on Python 3, which includes + npy/npz files containing object arrays. Values other than 'latin1', + 'ASCII', and 'bytes' are not allowed, as they can corrupt numerical + data. Default: 'ASCII' + Returns ------- @@ -381,6 +398,26 @@ def load(file, mmap_mode=None): else: fid = file + if encoding not in ('ASCII', 'latin1', 'bytes'): + # The 'encoding' value for pickle also affects what encoding + # the serialized binary data of Numpy arrays is loaded + # in. Pickle does not pass on the encoding information to + # Numpy. The unpickling code in numpy.core.multiarray is + # written to assume that unicode data appearing where binary + # should be is in 'latin1'. 'bytes' is also safe, as is 'ASCII'. + # + # Other encoding values can corrupt binary data, and we + # purposefully disallow them. For the same reason, the errors= + # argument is not exposed, as values other than 'strict' + # result can similarly silently corrupt numerical data. + raise ValueError("encoding must be 'ASCII', 'latin1', or 'bytes'") + + if sys.version_info[0] >= 3: + pickle_kwargs = dict(encoding=encoding, fix_imports=fix_imports) + else: + # Nothing to do on Python 2 + pickle_kwargs = {} + try: # Code to distinguish from NumPy binary files and pickles. _ZIP_PREFIX = asbytes('PK\x03\x04') @@ -392,17 +429,17 @@ def load(file, mmap_mode=None): # Transfer file ownership to NpzFile tmp = own_fid own_fid = False - return NpzFile(fid, own_fid=tmp) + return NpzFile(fid, own_fid=tmp, pickle_kwargs=pickle_kwargs) elif magic == format.MAGIC_PREFIX: # .npy file if mmap_mode: return format.open_memmap(file, mode=mmap_mode) else: - return format.read_array(fid) + return format.read_array(fid, pickle_kwargs=pickle_kwargs) else: # Try a pickle try: - return pickle.load(fid) + return pickle.load(fid, **pickle_kwargs) except: raise IOError( "Failed to interpret file %s as a pickle" % repr(file)) @@ -411,7 +448,7 @@ def load(file, mmap_mode=None): fid.close() -def save(file, arr): +def save(file, arr, fix_imports=True): """ Save an array to a binary file in NumPy ``.npy`` format. @@ -422,6 +459,11 @@ def save(file, arr): then the filename is unchanged. If file is a string, a ``.npy`` extension will be appended to the file name if it does not already have one. + fix_imports : bool, optional + Only useful in forcing objects in object arrays on Python 3 to be pickled + in a Python 2 compatible way. If `fix_imports` is True, pickle will try to + map the new Python 3 names to the old module names used in Python 2, so that + the pickle data stream is readable with Python 2. arr : array_like Array data to be saved. @@ -458,9 +500,15 @@ def save(file, arr): else: fid = file + if sys.version_info[0] >= 3: + pickle_kwargs = dict(fix_imports=fix_imports) + else: + # Nothing to do on Python 2 + pickle_kwargs = None + try: arr = np.asanyarray(arr) - format.write_array(fid, arr) + format.write_array(fid, arr, pickle_kwargs=pickle_kwargs) finally: if own_fid: fid.close() @@ -572,7 +620,7 @@ def savez_compressed(file, *args, **kwds): _savez(file, args, kwds, True) -def _savez(file, args, kwds, compress): +def _savez(file, args, kwds, compress, pickle_kwargs=None): # Import is postponed to here since zipfile depends on gzip, an optional # component of the so-called standard library. import zipfile @@ -606,7 +654,7 @@ def _savez(file, args, kwds, compress): fname = key + '.npy' fid = open(tmpfile, 'wb') try: - format.write_array(fid, np.asanyarray(val)) + format.write_array(fid, np.asanyarray(val), pickle_kwargs=pickle_kwargs) fid.close() fid = None zipf.write(tmpfile, arcname=fname) diff --git a/numpy/lib/tests/data/py2-objarr.npy b/numpy/lib/tests/data/py2-objarr.npy Binary files differnew file mode 100644 index 000000000..12936c92d --- /dev/null +++ b/numpy/lib/tests/data/py2-objarr.npy diff --git a/numpy/lib/tests/data/py2-objarr.npz b/numpy/lib/tests/data/py2-objarr.npz Binary files differnew file mode 100644 index 000000000..68a3b53a1 --- /dev/null +++ b/numpy/lib/tests/data/py2-objarr.npz diff --git a/numpy/lib/tests/data/py3-objarr.npy b/numpy/lib/tests/data/py3-objarr.npy Binary files differnew file mode 100644 index 000000000..6776074b4 --- /dev/null +++ b/numpy/lib/tests/data/py3-objarr.npy diff --git a/numpy/lib/tests/data/py3-objarr.npz b/numpy/lib/tests/data/py3-objarr.npz Binary files differnew file mode 100644 index 000000000..05eac0b76 --- /dev/null +++ b/numpy/lib/tests/data/py3-objarr.npz diff --git a/numpy/lib/tests/test_format.py b/numpy/lib/tests/test_format.py index ee77386bc..d3d22c127 100644 --- a/numpy/lib/tests/test_format.py +++ b/numpy/lib/tests/test_format.py @@ -284,7 +284,7 @@ import warnings from io import BytesIO import numpy as np -from numpy.compat import asbytes, asbytes_nested +from numpy.compat import asbytes, asbytes_nested, sixu from numpy.testing import ( run_module_suite, assert_, assert_array_equal, assert_raises, raises, dec @@ -534,6 +534,71 @@ def test_python2_python3_interoperability(): assert_array_equal(data, np.ones(2)) +def test_pickle_python2_python3(): + # Test that loading object arrays saved on Python 2 works both on + # Python 2 and Python 3 and vice versa + data_dir = os.path.join(os.path.dirname(__file__), 'data') + + if sys.version_info[0] >= 3: + xrange = range + else: + import __builtin__ + xrange = __builtin__.xrange + + expected = np.array([None, xrange, sixu('\u512a\u826f'), + asbytes('\xe4\xb8\x8d\xe8\x89\xaf')], + dtype=object) + + for fname in ['py2-objarr.npy', 'py2-objarr.npz', + 'py3-objarr.npy', 'py3-objarr.npz']: + path = os.path.join(data_dir, fname) + + if (fname.endswith('.npz') and sys.version_info[0] == 2 and + sys.version_info[1] < 7): + # Reading object arrays directly from zipfile appears to fail + # on Py2.6, see cfae0143b4 + continue + + for encoding in ['bytes', 'latin1']: + if (sys.version_info[0] >= 3 and sys.version_info[1] < 4 and + encoding == 'bytes'): + # The bytes encoding is available starting from Python 3.4 + continue + + data_f = np.load(path, encoding=encoding) + if fname.endswith('.npz'): + data = data_f['x'] + data_f.close() + else: + data = data_f + + if sys.version_info[0] >= 3: + if encoding == 'latin1' and fname.startswith('py2'): + assert_(isinstance(data[3], str)) + assert_array_equal(data[:-1], expected[:-1]) + # mojibake occurs + assert_array_equal(data[-1].encode(encoding), expected[-1]) + else: + assert_(isinstance(data[3], bytes)) + assert_array_equal(data, expected) + else: + assert_array_equal(data, expected) + + if sys.version_info[0] >= 3: + if fname.startswith('py2'): + if fname.endswith('.npz'): + data = np.load(path) + assert_raises(UnicodeError, data.__getitem__, 'x') + data.close() + data = np.load(path, fix_imports=False, encoding='latin1') + assert_raises(ImportError, data.__getitem__, 'x') + data.close() + else: + assert_raises(UnicodeError, np.load, path) + assert_raises(ImportError, np.load, path, encoding='latin1', + fix_imports=False) + + def test_version_2_0(): f = BytesIO() # requires more than 2 byte for header |