diff options
Diffstat (limited to 'numpy/lib/format.py')
-rw-r--r-- | numpy/lib/format.py | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/numpy/lib/format.py b/numpy/lib/format.py index a0405b310..66a1b356c 100644 --- a/numpy/lib/format.py +++ b/numpy/lib/format.py @@ -515,7 +515,7 @@ def _read_array_header(fp, version): return d['shape'], d['fortran_order'], dtype -def write_array(fp, array, version=None, pickle_kwargs=None): +def write_array(fp, array, version=None, allow_pickle=True, pickle_kwargs=None): """ Write an array to an NPY file, including a header. @@ -533,6 +533,8 @@ def write_array(fp, array, version=None, pickle_kwargs=None): version : (int, int) or None, optional The version number of the format. None means use the oldest supported version that is able to store the data. Default: None + allow_pickle : bool, optional + Whether to allow writing pickled data. Default: True pickle_kwargs : dict, optional Additional keyword arguments to pass to pickle.dump, excluding 'protocol'. These are only useful when pickling objects in object @@ -541,7 +543,8 @@ def write_array(fp, array, version=None, pickle_kwargs=None): Raises ------ ValueError - If the array cannot be persisted. + If the array cannot be persisted. This includes the case of + allow_pickle=False and array being an object array. Various other errors If the array contains Python objects as part of its dtype, the process of pickling them may raise various errors if the objects @@ -563,6 +566,9 @@ def write_array(fp, array, version=None, pickle_kwargs=None): # We contain Python objects so we cannot write out the data # directly. Instead, we will pickle it out with version 2 of the # pickle protocol. + if not allow_pickle: + raise ValueError("Object arrays cannot be saved when " + "allow_pickle=False") if pickle_kwargs is None: pickle_kwargs = {} pickle.dump(array, fp, protocol=2, **pickle_kwargs) @@ -584,7 +590,7 @@ def write_array(fp, array, version=None, pickle_kwargs=None): fp.write(chunk.tobytes('C')) -def read_array(fp, pickle_kwargs=None): +def read_array(fp, allow_pickle=True, pickle_kwargs=None): """ Read an array from an NPY file. @@ -593,6 +599,8 @@ def read_array(fp, pickle_kwargs=None): fp : file_like object If this is not a real file object, then this may take extra memory and time. + allow_pickle : bool, optional + Whether to allow reading pickled data. Default: True pickle_kwargs : dict Additional keyword arguments to pass to pickle.load. These are only useful when loading object arrays saved on Python 2 when using @@ -606,7 +614,8 @@ def read_array(fp, pickle_kwargs=None): Raises ------ ValueError - If the data is invalid. + If the data is invalid, or allow_pickle=False and the file contains + an object array. """ version = read_magic(fp) @@ -620,6 +629,9 @@ def read_array(fp, pickle_kwargs=None): # Now read the actual data. if dtype.hasobject: # The array contained Python objects. We need to unpickle the data. + if not allow_pickle: + raise ValueError("Object arrays cannot be loaded when " + "allow_pickle=False") if pickle_kwargs is None: pickle_kwargs = {} try: |