diff options
Diffstat (limited to 'numpy/lib/format.py')
-rw-r--r-- | numpy/lib/format.py | 44 |
1 file changed, 27 insertions, 17 deletions
diff --git a/numpy/lib/format.py b/numpy/lib/format.py index 1a2133aa9..66a1b356c 100644 --- a/numpy/lib/format.py +++ b/numpy/lib/format.py @@ -314,21 +314,19 @@ def _write_array_header(fp, d, version=None): header = header + ' '*topad + '\n' header = asbytes(_filter_header(header)) - if len(header) >= (256*256) and version == (1, 0): - raise ValueError("header does not fit inside %s bytes required by the" - " 1.0 format" % (256*256)) - if len(header) < (256*256): - header_len_str = struct.pack('<H', len(header)) + hlen = len(header) + if hlen < 256*256 and version in (None, (1, 0)): version = (1, 0) - elif len(header) < (2**32): - header_len_str = struct.pack('<I', len(header)) + header_prefix = magic(1, 0) + struct.pack('<H', hlen) + elif hlen < 2**32 and version in (None, (2, 0)): version = (2, 0) + header_prefix = magic(2, 0) + struct.pack('<I', hlen) else: - raise ValueError("header does not fit inside 4 GiB required by " - "the 2.0 format") + msg = "Header length %s too big for version=%s" + msg %= (hlen, version) + raise ValueError(msg) - fp.write(magic(*version)) - fp.write(header_len_str) + fp.write(header_prefix) fp.write(header) return version @@ -389,7 +387,7 @@ def read_array_header_1_0(fp): If the data is invalid. """ - _read_array_header(fp, version=(1, 0)) + return _read_array_header(fp, version=(1, 0)) def read_array_header_2_0(fp): """ @@ -422,7 +420,7 @@ def read_array_header_2_0(fp): If the data is invalid. """ - _read_array_header(fp, version=(2, 0)) + return _read_array_header(fp, version=(2, 0)) def _filter_header(s): @@ -517,7 +515,7 @@ def _read_array_header(fp, version): return d['shape'], d['fortran_order'], dtype -def write_array(fp, array, version=None, pickle_kwargs=None): +def write_array(fp, array, version=None, allow_pickle=True, pickle_kwargs=None): """ Write an array to an NPY file, including a header. 
@@ -535,6 +533,8 @@ def write_array(fp, array, version=None, pickle_kwargs=None): version : (int, int) or None, optional The version number of the format. None means use the oldest supported version that is able to store the data. Default: None + allow_pickle : bool, optional + Whether to allow writing pickled data. Default: True pickle_kwargs : dict, optional Additional keyword arguments to pass to pickle.dump, excluding 'protocol'. These are only useful when pickling objects in object @@ -543,7 +543,8 @@ def write_array(fp, array, version=None, pickle_kwargs=None): Raises ------ ValueError - If the array cannot be persisted. + If the array cannot be persisted. This includes the case of + allow_pickle=False and array being an object array. Various other errors If the array contains Python objects as part of its dtype, the process of pickling them may raise various errors if the objects @@ -565,6 +566,9 @@ def write_array(fp, array, version=None, pickle_kwargs=None): # We contain Python objects so we cannot write out the data # directly. Instead, we will pickle it out with version 2 of the # pickle protocol. + if not allow_pickle: + raise ValueError("Object arrays cannot be saved when " + "allow_pickle=False") if pickle_kwargs is None: pickle_kwargs = {} pickle.dump(array, fp, protocol=2, **pickle_kwargs) @@ -586,7 +590,7 @@ def write_array(fp, array, version=None, pickle_kwargs=None): fp.write(chunk.tobytes('C')) -def read_array(fp, pickle_kwargs=None): +def read_array(fp, allow_pickle=True, pickle_kwargs=None): """ Read an array from an NPY file. @@ -595,6 +599,8 @@ def read_array(fp, pickle_kwargs=None): fp : file_like object If this is not a real file object, then this may take extra memory and time. + allow_pickle : bool, optional + Whether to allow reading pickled data. Default: True pickle_kwargs : dict Additional keyword arguments to pass to pickle.load. 
These are only useful when loading object arrays saved on Python 2 when using @@ -608,7 +614,8 @@ def read_array(fp, pickle_kwargs=None): Raises ------ ValueError - If the data is invalid. + If the data is invalid, or allow_pickle=False and the file contains + an object array. """ version = read_magic(fp) @@ -622,6 +629,9 @@ def read_array(fp, pickle_kwargs=None): # Now read the actual data. if dtype.hasobject: # The array contained Python objects. We need to unpickle the data. + if not allow_pickle: + raise ValueError("Object arrays cannot be loaded when " + "allow_pickle=False") if pickle_kwargs is None: pickle_kwargs = {} try: |