diff options
Diffstat (limited to 'numpy/lib/format.py')
-rw-r--r-- | numpy/lib/format.py | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/numpy/lib/format.py b/numpy/lib/format.py index 4ff0a660f..1ff04b68a 100644 --- a/numpy/lib/format.py +++ b/numpy/lib/format.py @@ -517,7 +517,7 @@ def _read_array_header(fp, version): return d['shape'], d['fortran_order'], dtype -def write_array(fp, array, version=None): +def write_array(fp, array, version=None, pickle_kwargs=None): """ Write an array to an NPY file, including a header. @@ -535,6 +535,10 @@ def write_array(fp, array, version=None): version : (int, int) or None, optional The version number of the format. None means use the oldest supported version that is able to store the data. Default: None + pickle_kwargs : dict, optional + Additional keyword arguments to pass to pickle.dump, excluding + 'protocol'. These are only useful when pickling objects in object + arrays on Python 3 to Python 2 compatible format. Raises ------ @@ -561,7 +565,9 @@ def write_array(fp, array, version=None): # We contain Python objects so we cannot write out the data # directly. Instead, we will pickle it out with version 2 of the # pickle protocol. - pickle.dump(array, fp, protocol=2) + if pickle_kwargs is None: + pickle_kwargs = {} + pickle.dump(array, fp, protocol=2, **pickle_kwargs) elif array.flags.f_contiguous and not array.flags.c_contiguous: if isfileobj(fp): array.T.tofile(fp) @@ -580,7 +586,7 @@ def write_array(fp, array, version=None): fp.write(chunk.tobytes('C')) -def read_array(fp): +def read_array(fp, pickle_kwargs=None): """ Read an array from an NPY file. @@ -589,6 +595,9 @@ def read_array(fp): fp : file_like object If this is not a real file object, then this may take extra memory and time. + pickle_kwargs : dict + Additional keyword arguments to pass to pickle.load. These are only + useful when loading object arrays saved on Python 2 when using Python 3. Returns ------- @@ -612,7 +621,16 @@ def read_array(fp): # Now read the actual data. if dtype.hasobject: # The array contained Python objects. We need to unpickle the data. - array = pickle.load(fp) + if pickle_kwargs is None: + pickle_kwargs = {} + try: + array = pickle.load(fp, **pickle_kwargs) + except UnicodeError as err: + if sys.version_info[0] >= 3: + # Friendlier error message + raise UnicodeError("Unpickling a python object failed: %r\n" + "You may need to pass the encoding= option to numpy.load" % (err,)) + raise else: if isfileobj(fp): # We can use the fast fromfile() function. |