Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/format.py                | 44
-rw-r--r-- | numpy/lib/function_base.py         | 11
-rw-r--r-- | numpy/lib/npyio.py                 | 44
-rw-r--r-- | numpy/lib/tests/test_format.py     | 60
-rw-r--r-- | numpy/lib/tests/test_type_check.py |  2
-rw-r--r-- | numpy/lib/type_check.py            | 56
6 files changed, 156 insertions, 61 deletions
diff --git a/numpy/lib/format.py b/numpy/lib/format.py
index 1a2133aa9..66a1b356c 100644
--- a/numpy/lib/format.py
+++ b/numpy/lib/format.py
@@ -314,21 +314,19 @@ def _write_array_header(fp, d, version=None):
     header = header + ' '*topad + '\n'
     header = asbytes(_filter_header(header))
 
-    if len(header) >= (256*256) and version == (1, 0):
-        raise ValueError("header does not fit inside %s bytes required by the"
-                         " 1.0 format" % (256*256))
-    if len(header) < (256*256):
-        header_len_str = struct.pack('<H', len(header))
+    hlen = len(header)
+    if hlen < 256*256 and version in (None, (1, 0)):
         version = (1, 0)
-    elif len(header) < (2**32):
-        header_len_str = struct.pack('<I', len(header))
+        header_prefix = magic(1, 0) + struct.pack('<H', hlen)
+    elif hlen < 2**32 and version in (None, (2, 0)):
         version = (2, 0)
+        header_prefix = magic(2, 0) + struct.pack('<I', hlen)
     else:
-        raise ValueError("header does not fit inside 4 GiB required by "
-                         "the 2.0 format")
+        msg = "Header length %s too big for version=%s"
+        msg %= (hlen, version)
+        raise ValueError(msg)
 
-    fp.write(magic(*version))
-    fp.write(header_len_str)
+    fp.write(header_prefix)
     fp.write(header)
     return version
 
@@ -389,7 +387,7 @@ def read_array_header_1_0(fp):
         If the data is invalid.
 
     """
-    _read_array_header(fp, version=(1, 0))
+    return _read_array_header(fp, version=(1, 0))
 
 
 def read_array_header_2_0(fp):
@@ -422,7 +420,7 @@ def read_array_header_2_0(fp):
         If the data is invalid.
 
     """
-    _read_array_header(fp, version=(2, 0))
+    return _read_array_header(fp, version=(2, 0))
 
 
 def _filter_header(s):
@@ -517,7 +515,7 @@ def _read_array_header(fp, version):
     return d['shape'], d['fortran_order'], dtype
 
 
-def write_array(fp, array, version=None, pickle_kwargs=None):
+def write_array(fp, array, version=None, allow_pickle=True, pickle_kwargs=None):
     """
     Write an array to an NPY file, including a header.
 
@@ -535,6 +533,8 @@ def write_array(fp, array, version=None, pickle_kwargs=None):
     version : (int, int) or None, optional
         The version number of the format. None means use the oldest
         supported version that is able to store the data. Default: None
+    allow_pickle : bool, optional
+        Whether to allow writing pickled data. Default: True
     pickle_kwargs : dict, optional
         Additional keyword arguments to pass to pickle.dump, excluding
         'protocol'. These are only useful when pickling objects in object
@@ -543,7 +543,8 @@ def write_array(fp, array, version=None, pickle_kwargs=None):
     Raises
     ------
     ValueError
-        If the array cannot be persisted.
+        If the array cannot be persisted. This includes the case of
+        allow_pickle=False and array being an object array.
     Various other errors
         If the array contains Python objects as part of its dtype, the
         process of pickling them may raise various errors if the objects
@@ -565,6 +566,9 @@ def write_array(fp, array, version=None, pickle_kwargs=None):
         # We contain Python objects so we cannot write out the data
         # directly. Instead, we will pickle it out with version 2 of the
         # pickle protocol.
+        if not allow_pickle:
+            raise ValueError("Object arrays cannot be saved when "
+                             "allow_pickle=False")
         if pickle_kwargs is None:
             pickle_kwargs = {}
         pickle.dump(array, fp, protocol=2, **pickle_kwargs)
@@ -586,7 +590,7 @@ def write_array(fp, array, version=None, pickle_kwargs=None):
             fp.write(chunk.tobytes('C'))
 
 
-def read_array(fp, pickle_kwargs=None):
+def read_array(fp, allow_pickle=True, pickle_kwargs=None):
     """
     Read an array from an NPY file.
 
@@ -595,6 +599,8 @@ def read_array(fp, pickle_kwargs=None):
     fp : file_like object
         If this is not a real file object, then this may take extra memory
         and time.
+    allow_pickle : bool, optional
+        Whether to allow reading pickled data. Default: True
     pickle_kwargs : dict
         Additional keyword arguments to pass to pickle.load. These are only
         useful when loading object arrays saved on Python 2 when using
@@ -608,7 +614,8 @@ def read_array(fp, pickle_kwargs=None):
     Raises
     ------
     ValueError
-        If the data is invalid.
+        If the data is invalid, or allow_pickle=False and the file contains
+        an object array.
 
     """
     version = read_magic(fp)
@@ -622,6 +629,9 @@ def read_array(fp, pickle_kwargs=None):
     # Now read the actual data.
     if dtype.hasobject:
         # The array contained Python objects. We need to unpickle the data.
+        if not allow_pickle:
+            raise ValueError("Object arrays cannot be loaded when "
+                             "allow_pickle=False")
         if pickle_kwargs is None:
             pickle_kwargs = {}
         try:
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 9aec98cc8..d22e8c047 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -906,9 +906,9 @@ def gradient(f, *varargs, **kwargs):
 
     Returns
     -------
-    gradient : ndarray
-        N arrays of the same shape as `f` giving the derivative of `f` with
-        respect to each dimension.
+    gradient : list of ndarray
+        Each element of `list` has the same shape as `f` giving the derivative
+        of `f` with respect to each dimension.
 
     Examples
     --------
@@ -918,6 +918,10 @@ def gradient(f, *varargs, **kwargs):
     >>> np.gradient(x, 2)
     array([ 0.5 ,  0.75,  1.25,  1.75,  2.25,  2.5 ])
 
+    For two dimensional arrays, the return will be two arrays ordered by
+    axis. In this example the first array stands for the gradient in
+    rows and the second one in columns direction:
+
     >>> np.gradient(np.array([[1, 2, 6], [3, 4, 5]], dtype=np.float))
     [array([[ 2.,  2., -1.],
            [ 2.,  2., -1.]]), array([[ 1. ,  2.5,  4. ],
@@ -3735,6 +3739,7 @@ def insert(arr, obj, values, axis=None):
            [3, 5, 3]])
 
     Difference between sequence and scalars:
+
     >>> np.insert(a, [1], [[1],[2],[3]], axis=1)
     array([[1, 1, 1],
            [2, 2, 2],
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index b56d7d5a9..ec89397a0 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -164,6 +164,8 @@ class NpzFile(object):
     f : BagObj instance
         An object on which attribute can be performed as an alternative
         to getitem access on the `NpzFile` instance itself.
+    allow_pickle : bool, optional
+        Allow loading pickled data. Default: True
     pickle_kwargs : dict, optional
         Additional keyword arguments to pass on to pickle.load.
         These are only useful when loading object arrays saved on
@@ -199,12 +201,14 @@ class NpzFile(object):
 
     """
 
-    def __init__(self, fid, own_fid=False, pickle_kwargs=None):
+    def __init__(self, fid, own_fid=False, allow_pickle=True,
+                 pickle_kwargs=None):
         # Import is postponed to here since zipfile depends on gzip, an
         # optional component of the so-called standard library.
         _zip = zipfile_factory(fid)
         self._files = _zip.namelist()
         self.files = []
+        self.allow_pickle = allow_pickle
         self.pickle_kwargs = pickle_kwargs
         for x in self._files:
             if x.endswith('.npy'):
@@ -262,6 +266,7 @@ class NpzFile(object):
             if magic == format.MAGIC_PREFIX:
                 bytes = self.zip.open(key)
                 return format.read_array(bytes,
+                                         allow_pickle=self.allow_pickle,
                                          pickle_kwargs=self.pickle_kwargs)
             else:
                 return self.zip.read(key)
@@ -295,7 +300,8 @@ class NpzFile(object):
         return self.files.__contains__(key)
 
 
-def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
+def load(file, mmap_mode=None, allow_pickle=True, fix_imports=True,
+         encoding='ASCII'):
     """
     Load arrays or pickled objects from ``.npy``, ``.npz`` or pickled files.
 
@@ -312,6 +318,12 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
         and sliced like any ndarray. Memory mapping is especially useful
         for accessing small fragments of large files without reading the
        entire file into memory.
+    allow_pickle : bool, optional
+        Allow loading pickled object arrays stored in npy files. Reasons for
+        disallowing pickles include security, as loading pickled data can
+        execute arbitrary code. If pickles are disallowed, loading object
+        arrays will fail.
+        Default: True
     fix_imports : bool, optional
         Only useful when loading Python 2 generated pickled files on Python 3,
         which includes npy/npz files containing object arrays. If `fix_imports`
@@ -324,7 +336,6 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
         'ASCII', and 'bytes' are not allowed, as they can corrupt numerical
         data. Default: 'ASCII'
 
-
     Returns
     -------
     result : array, tuple, dict, etc.
@@ -335,6 +346,8 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
     ------
     IOError
         If the input file does not exist or cannot be read.
+    ValueError
+        The file contains an object array, but allow_pickle=False given.
 
     See Also
     --------
@@ -430,15 +443,20 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
             # Transfer file ownership to NpzFile
             tmp = own_fid
             own_fid = False
-            return NpzFile(fid, own_fid=tmp, pickle_kwargs=pickle_kwargs)
+            return NpzFile(fid, own_fid=tmp, allow_pickle=allow_pickle,
+                           pickle_kwargs=pickle_kwargs)
         elif magic == format.MAGIC_PREFIX:
             # .npy file
             if mmap_mode:
                 return format.open_memmap(file, mode=mmap_mode)
             else:
-                return format.read_array(fid, pickle_kwargs=pickle_kwargs)
+                return format.read_array(fid, allow_pickle=allow_pickle,
+                                         pickle_kwargs=pickle_kwargs)
         else:
             # Try a pickle
+            if not allow_pickle:
+                raise ValueError("allow_pickle=False, but file does not contain "
+                                 "non-pickled data")
             try:
                 return pickle.load(fid, **pickle_kwargs)
             except:
@@ -449,7 +467,7 @@ def load(file, mmap_mode=None, fix_imports=True, encoding='ASCII'):
             fid.close()
 
 
-def save(file, arr, fix_imports=True):
+def save(file, arr, allow_pickle=True, fix_imports=True):
     """
     Save an array to a binary file in NumPy ``.npy`` format.
 
@@ -460,6 +478,14 @@ def save(file, arr, fix_imports=True):
         then the filename is unchanged. If file is a string, a ``.npy``
         extension will be appended to the file name if it does not already
         have one.
+    allow_pickle : bool, optional
+        Allow saving object arrays using Python pickles. Reasons for disallowing
+        pickles include security (loading pickled data can execute arbitrary
+        code) and portability (pickled objects may not be loadable on different
+        Python installations, for example if the stored objects require libraries
+        that are not available, and not all pickled data is compatible between
+        Python 2 and Python 3).
+        Default: True
     fix_imports : bool, optional
         Only useful in forcing objects in object arrays on Python 3 to be
         pickled in a Python 2 compatible way. If `fix_imports` is True, pickle
@@ -509,7 +535,8 @@ def save(file, arr, fix_imports=True):
 
     try:
         arr = np.asanyarray(arr)
-        format.write_array(fid, arr, pickle_kwargs=pickle_kwargs)
+        format.write_array(fid, arr, allow_pickle=allow_pickle,
+                           pickle_kwargs=pickle_kwargs)
     finally:
         if own_fid:
             fid.close()
@@ -621,7 +648,7 @@ def savez_compressed(file, *args, **kwds):
     _savez(file, args, kwds, True)
 
 
-def _savez(file, args, kwds, compress, pickle_kwargs=None):
+def _savez(file, args, kwds, compress, allow_pickle=True, pickle_kwargs=None):
     # Import is postponed to here since zipfile depends on gzip, an optional
     # component of the so-called standard library.
     import zipfile
@@ -656,6 +683,7 @@ def _savez(file, args, kwds, compress, pickle_kwargs=None):
         fid = open(tmpfile, 'wb')
         try:
             format.write_array(fid, np.asanyarray(val),
+                               allow_pickle=allow_pickle,
                                pickle_kwargs=pickle_kwargs)
             fid.close()
             fid = None
diff --git a/numpy/lib/tests/test_format.py b/numpy/lib/tests/test_format.py
index 169f01182..4f8a65148 100644
--- a/numpy/lib/tests/test_format.py
+++ b/numpy/lib/tests/test_format.py
@@ -599,6 +599,22 @@ def test_pickle_python2_python3():
                           encoding='latin1', fix_imports=False)
 
 
+def test_pickle_disallow():
+    data_dir = os.path.join(os.path.dirname(__file__), 'data')
+
+    path = os.path.join(data_dir, 'py2-objarr.npy')
+    assert_raises(ValueError, np.load, path,
+                  allow_pickle=False, encoding='latin1')
+
+    path = os.path.join(data_dir, 'py2-objarr.npz')
+    f = np.load(path, allow_pickle=False, encoding='latin1')
+    assert_raises(ValueError, f.__getitem__, 'x')
+
+    path = os.path.join(tempdir, 'pickle-disabled.npy')
+    assert_raises(ValueError, np.save, path, np.array([None], dtype=object),
+                  allow_pickle=False)
+
+
 def test_version_2_0():
     f = BytesIO()
     # requires more than 2 byte for header
@@ -694,6 +710,26 @@ malformed_magic = asbytes_nested([
     '',
 ])
 
+def test_read_magic():
+    s1 = BytesIO()
+    s2 = BytesIO()
+
+    arr = np.ones((3, 6), dtype=float)
+
+    format.write_array(s1, arr, version=(1, 0))
+    format.write_array(s2, arr, version=(2, 0))
+
+    s1.seek(0)
+    s2.seek(0)
+
+    version1 = format.read_magic(s1)
+    version2 = format.read_magic(s2)
+
+    assert_(version1 == (1, 0))
+    assert_(version2 == (2, 0))
+
+    assert_(s1.tell() == format.MAGIC_LEN)
+    assert_(s2.tell() == format.MAGIC_LEN)
 
 def test_read_magic_bad_magic():
     for magic in malformed_magic:
@@ -724,6 +760,30 @@ def test_large_header():
     assert_raises(ValueError, format.write_array_header_1_0, s, d)
 
 
+def test_read_array_header_1_0():
+    s = BytesIO()
+
+    arr = np.ones((3, 6), dtype=float)
+    format.write_array(s, arr, version=(1, 0))
+
+    s.seek(format.MAGIC_LEN)
+    shape, fortran, dtype = format.read_array_header_1_0(s)
+
+    assert_((shape, fortran, dtype) == ((3, 6), False, float))
+
+
+def test_read_array_header_2_0():
+    s = BytesIO()
+
+    arr = np.ones((3, 6), dtype=float)
+    format.write_array(s, arr, version=(2, 0))
+
+    s.seek(format.MAGIC_LEN)
+    shape, fortran, dtype = format.read_array_header_2_0(s)
+
+    assert_((shape, fortran, dtype) == ((3, 6), False, float))
+
+
 def test_bad_header():
     # header of length less than 2 should fail
     s = BytesIO()
diff --git a/numpy/lib/tests/test_type_check.py b/numpy/lib/tests/test_type_check.py
index 3931f95e5..7afd1206c 100644
--- a/numpy/lib/tests/test_type_check.py
+++ b/numpy/lib/tests/test_type_check.py
@@ -277,6 +277,8 @@ class TestNanToNum(TestCase):
     def test_integer(self):
         vals = nan_to_num(1)
         assert_all(vals == 1)
+        vals = nan_to_num([1])
+        assert_array_equal(vals, np.array([1], np.int))
 
     def test_complex_good(self):
         vals = nan_to_num(1+1j)
diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py
index a45d0bd86..99677b394 100644
--- a/numpy/lib/type_check.py
+++ b/numpy/lib/type_check.py
@@ -324,12 +324,13 @@ def nan_to_num(x):
 
     Returns
     -------
-    out : ndarray, float
-        Array with the same shape as `x` and dtype of the element in `x` with
-        the greatest precision. NaN is replaced by zero, and infinity
-        (-infinity) is replaced by the largest (smallest or most negative)
-        floating point value that fits in the output dtype. All finite numbers
-        are upcast to the output dtype (default float64).
+    out : ndarray
+        New Array with the same shape as `x` and dtype of the element in
+        `x` with the greatest precision. If `x` is inexact, then NaN is
+        replaced by zero, and infinity (-infinity) is replaced by the
+        largest (smallest or most negative) floating point value that fits
+        in the output dtype. If `x` is not inexact, then a copy of `x` is
+        returned.
 
     See Also
     --------
@@ -354,33 +355,22 @@ def nan_to_num(x):
            -1.28000000e+002,   1.28000000e+002])
 
     """
-    try:
-        t = x.dtype.type
-    except AttributeError:
-        t = obj2sctype(type(x))
-    if issubclass(t, _nx.complexfloating):
-        return nan_to_num(x.real) + 1j * nan_to_num(x.imag)
-    else:
-        try:
-            y = x.copy()
-        except AttributeError:
-            y = array(x)
-    if not issubclass(t, _nx.integer):
-        if not y.shape:
-            y = array([x])
-            scalar = True
-        else:
-            scalar = False
-        are_inf = isposinf(y)
-        are_neg_inf = isneginf(y)
-        are_nan = isnan(y)
-        maxf, minf = _getmaxmin(y.dtype.type)
-        y[are_nan] = 0
-        y[are_inf] = maxf
-        y[are_neg_inf] = minf
-        if scalar:
-            y = y[0]
-    return y
+    x = _nx.array(x, subok=True)
+    xtype = x.dtype.type
+    if not issubclass(xtype, _nx.inexact):
+        return x
+
+    iscomplex = issubclass(xtype, _nx.complexfloating)
+    isscalar = (x.ndim == 0)
+
+    x = x[None] if isscalar else x
+    dest = (x.real, x.imag) if iscomplex else (x,)
+    maxf, minf = _getmaxmin(x.real.dtype)
+    for d in dest:
+        _nx.copyto(d, 0.0, where=isnan(d))
+        _nx.copyto(d, maxf, where=isposinf(d))
+        _nx.copyto(d, minf, where=isneginf(d))
+    return x[0] if isscalar else x
 
 
 #-----------------------------------------------------------------------------
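
Not part of the commit above, just an illustrative sketch of how the new allow_pickle switch behaves once these hunks are applied (the error text is quoted from the new check in format.write_array; the nan_to_num behaviour follows the rewritten docstring). Assumes a NumPy build that includes this change:

    import io
    import numpy as np

    # Plain numeric data never needs pickling, so allow_pickle=False is harmless.
    buf = io.BytesIO()
    np.save(buf, np.arange(3), allow_pickle=False)
    buf.seek(0)
    print(np.load(buf, allow_pickle=False))    # [0 1 2]

    # Object arrays can only be serialized via pickle, so both save and load
    # now refuse them when allow_pickle=False.
    obj = np.array([None, {'a': 1}], dtype=object)
    try:
        np.save(io.BytesIO(), obj, allow_pickle=False)
    except ValueError as exc:
        print(exc)    # Object arrays cannot be saved when allow_pickle=False

    # Rewritten nan_to_num: inexact input has nan/inf replaced in a copy,
    # non-inexact (e.g. integer) input passes through as an array.
    print(np.nan_to_num(np.array([np.nan, np.inf, -np.inf])))
    print(np.nan_to_num([1]))    # [1]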