diff options
author | Jaime <jaime.frio@gmail.com> | 2015-06-05 08:43:34 -0700 |
---|---|---|
committer | Jaime <jaime.frio@gmail.com> | 2015-06-05 08:43:34 -0700 |
commit | a99c77577850054c55850102132dcf3eb43dea6a (patch) | |
tree | 8e0fe904c161e1122f0b6d3a3012fc84148354df | |
parent | 30d755d8737505717d54ed32501261bb94130a7f (diff) | |
parent | e2cd6fa869cec6d92062fb687d8e6952c1202017 (diff) | |
download | numpy-a99c77577850054c55850102132dcf3eb43dea6a.tar.gz |
Merge pull request #5548 from ahaldane/view_objects_owned
ENH: structured datatype safety checks
-rw-r--r-- | numpy/core/_internal.py | 168 | ||||
-rw-r--r-- | numpy/core/src/multiarray/getset.c | 19 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 43 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 96 | ||||
-rw-r--r-- | numpy/core/tests/test_records.py | 9 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 5 | ||||
-rw-r--r-- | numpy/lib/tests/test_recfunctions.py | 30 |
7 files changed, 332 insertions, 38 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index d32f59390..e80c22dfe 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -305,6 +305,174 @@ def _index_fields(ary, fields): copy_dtype = {'names':view_dtype['names'], 'formats':view_dtype['formats']} return array(view, dtype=copy_dtype, copy=True) +def _get_all_field_offsets(dtype, base_offset=0): + """ Returns the types and offsets of all fields in a (possibly structured) + data type, including nested fields and subarrays. + + Parameters + ---------- + dtype : data-type + Data type to extract fields from. + base_offset : int, optional + Additional offset to add to all field offsets. + + Returns + ------- + fields : list of (data-type, int) pairs + A flat list of (dtype, byte offset) pairs. + + """ + fields = [] + if dtype.fields is not None: + for name in dtype.names: + sub_dtype = dtype.fields[name][0] + sub_offset = dtype.fields[name][1] + base_offset + fields.extend(_get_all_field_offsets(sub_dtype, sub_offset)) + else: + if dtype.shape: + sub_offsets = _get_all_field_offsets(dtype.base, base_offset) + count = 1 + for dim in dtype.shape: + count *= dim + fields.extend((typ, off + dtype.base.itemsize*j) + for j in range(count) for (typ, off) in sub_offsets) + else: + fields.append((dtype, base_offset)) + return fields + +def _check_field_overlap(new_fields, old_fields): + """ Perform object memory overlap tests for two data-types (see + _view_is_safe). + + This function checks that new fields only access memory contained in old + fields, and that non-object fields are not interpreted as objects and vice + versa. + + Parameters + ---------- + new_fields : list of (data-type, int) pairs + Flat list of (dtype, byte offset) pairs for the new data type, as + returned by _get_all_field_offsets. + old_fields: list of (data-type, int) pairs + Flat list of (dtype, byte offset) pairs for the old data type, as + returned by _get_all_field_offsets. + + Raises + ------ + TypeError + If the new fields are incompatible with the old fields + + """ + from .numerictypes import object_ + from .multiarray import dtype + + #first go byte by byte and check we do not access bytes not in old_fields + new_bytes = set() + for tp, off in new_fields: + new_bytes.update(set(range(off, off+tp.itemsize))) + old_bytes = set() + for tp, off in old_fields: + old_bytes.update(set(range(off, off+tp.itemsize))) + if new_bytes.difference(old_bytes): + raise TypeError("view would access data parent array doesn't own") + + #next check that we do not interpret non-Objects as Objects, and vv + obj_offsets = [off for (tp, off) in old_fields if tp.type is object_] + obj_size = dtype(object_).itemsize + + for fld_dtype, fld_offset in new_fields: + if fld_dtype.type is object_: + # check we do not create object views where + # there are no objects. + if fld_offset not in obj_offsets: + raise TypeError("cannot view non-Object data as Object type") + else: + # next check we do not create non-object views + # where there are already objects. + # see validate_object_field_overlap for a similar computation. + for obj_offset in obj_offsets: + if (fld_offset < obj_offset + obj_size and + obj_offset < fld_offset + fld_dtype.itemsize): + raise TypeError("cannot view Object as non-Object type") + +def _getfield_is_safe(oldtype, newtype, offset): + """ Checks safety of getfield for object arrays. + + As in _view_is_safe, we need to check that memory containing objects is not + reinterpreted as a non-object datatype and vice versa. + + Parameters + ---------- + oldtype : data-type + Data type of the original ndarray. + newtype : data-type + Data type of the field being accessed by ndarray.getfield + offset : int + Offset of the field being accessed by ndarray.getfield + + Raises + ------ + TypeError + If the field access is invalid + + """ + new_fields = _get_all_field_offsets(newtype, offset) + old_fields = _get_all_field_offsets(oldtype) + # raises if there is a problem + _check_field_overlap(new_fields, old_fields) + +def _view_is_safe(oldtype, newtype): + """ Checks safety of a view involving object arrays, for example when + doing:: + + np.zeros(10, dtype=oldtype).view(newtype) + + We need to check that + 1) No memory that is not an object will be interpreted as a object, + 2) No memory containing an object will be interpreted as an arbitrary type. + Both cases can cause segfaults, eg in the case the view is written to. + Strategy here is to also disallow views where newtype has any field in a + place oldtype doesn't. + + Parameters + ---------- + oldtype : data-type + Data type of original ndarray + newtype : data-type + Data type of the view + + Raises + ------ + TypeError + If the new type is incompatible with the old type. + + """ + new_fields = _get_all_field_offsets(newtype) + new_size = newtype.itemsize + + old_fields = _get_all_field_offsets(oldtype) + old_size = oldtype.itemsize + + # if the itemsizes are not equal, we need to check that all the + # 'tiled positions' of the object match up. Here, we allow + # for arbirary itemsizes (even those possibly disallowed + # due to stride/data length issues). + if old_size == new_size: + new_num = old_num = 1 + else: + gcd_new_old = _gcd(new_size, old_size) + new_num = old_size // gcd_new_old + old_num = new_size // gcd_new_old + + # get position of fields within the tiling + new_fieldtile = [(tp, off + new_size*j) + for j in range(new_num) for (tp, off) in new_fields] + old_fieldtile = [(tp, off + old_size*j) + for j in range(old_num) for (tp, off) in old_fields] + + # raises if there is a problem + _check_field_overlap(new_fieldtile, old_fieldtile) + # Given a string containing a PEP 3118 format specifier, # construct a Numpy dtype diff --git a/numpy/core/src/multiarray/getset.c b/numpy/core/src/multiarray/getset.c index cbc798273..9ba12b092 100644 --- a/numpy/core/src/multiarray/getset.c +++ b/numpy/core/src/multiarray/getset.c @@ -9,8 +9,8 @@ #include "numpy/arrayobject.h" #include "npy_config.h" - #include "npy_pycompat.h" +#include "npy_import.h" #include "common.h" #include "scalartypes.h" @@ -434,6 +434,13 @@ array_descr_set(PyArrayObject *self, PyObject *arg) npy_intp newdim; int i; char *msg = "new type not compatible with array."; + PyObject *safe; + static PyObject *checkfunc = NULL; + + npy_cache_pyfunc("numpy.core._internal", "_view_is_safe", &checkfunc); + if (checkfunc == NULL) { + return -1; + } if (arg == NULL) { PyErr_SetString(PyExc_AttributeError, @@ -448,15 +455,13 @@ array_descr_set(PyArrayObject *self, PyObject *arg) return -1; } - if (PyDataType_FLAGCHK(newtype, NPY_ITEM_HASOBJECT) || - PyDataType_FLAGCHK(newtype, NPY_ITEM_IS_POINTER) || - PyDataType_FLAGCHK(PyArray_DESCR(self), NPY_ITEM_HASOBJECT) || - PyDataType_FLAGCHK(PyArray_DESCR(self), NPY_ITEM_IS_POINTER)) { - PyErr_SetString(PyExc_TypeError, - "Cannot change data-type for object array."); + /* check that we are not reinterpreting memory containing Objects */ + safe = PyObject_CallFunction(checkfunc, "OO", PyArray_DESCR(self), newtype); + if (safe == NULL) { Py_DECREF(newtype); return -1; } + Py_DECREF(safe); if (newtype->elsize == 0) { /* Allow a void view */ diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index a51453dd1..d06e4a512 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -10,6 +10,7 @@ #include "npy_config.h" #include "npy_pycompat.h" +#include "npy_import.h" #include "ufunc_override.h" #include "common.h" #include "ctors.h" @@ -358,15 +359,23 @@ NPY_NO_EXPORT PyObject * PyArray_GetField(PyArrayObject *self, PyArray_Descr *typed, int offset) { PyObject *ret = NULL; + PyObject *safe; + static PyObject *checkfunc = NULL; - if (offset < 0 || (offset + typed->elsize) > PyArray_DESCR(self)->elsize) { - PyErr_Format(PyExc_ValueError, - "Need 0 <= offset <= %d for requested type " - "but received offset = %d", - PyArray_DESCR(self)->elsize-typed->elsize, offset); - Py_DECREF(typed); + npy_cache_pyfunc("numpy.core._internal", "_getfield_is_safe", &checkfunc); + if (checkfunc == NULL) { return NULL; } + + /* check that we are not reinterpreting memory containing Objects */ + /* only returns True or raises */ + safe = PyObject_CallFunction(checkfunc, "OOi", PyArray_DESCR(self), + typed, offset); + if (safe == NULL) { + return NULL; + } + Py_DECREF(safe); + ret = PyArray_NewFromDescr(Py_TYPE(self), typed, PyArray_NDIM(self), PyArray_DIMS(self), @@ -417,23 +426,12 @@ PyArray_SetField(PyArrayObject *self, PyArray_Descr *dtype, PyObject *ret = NULL; int retval = 0; - if (offset < 0 || (offset + dtype->elsize) > PyArray_DESCR(self)->elsize) { - PyErr_Format(PyExc_ValueError, - "Need 0 <= offset <= %d for requested type " - "but received offset = %d", - PyArray_DESCR(self)->elsize-dtype->elsize, offset); - Py_DECREF(dtype); - return -1; - } - ret = PyArray_NewFromDescr(Py_TYPE(self), - dtype, PyArray_NDIM(self), PyArray_DIMS(self), - PyArray_STRIDES(self), PyArray_BYTES(self) + offset, - PyArray_FLAGS(self), (PyObject *)self); + /* getfield returns a view we can write to */ + ret = PyArray_GetField(self, dtype, offset); if (ret == NULL) { return -1; } - PyArray_UpdateFlags((PyArrayObject *)ret, NPY_ARRAY_UPDATE_ALL); retval = PyArray_CopyObject((PyArrayObject *)ret, val); Py_DECREF(ret); return retval; @@ -455,13 +453,6 @@ array_setfield(PyArrayObject *self, PyObject *args, PyObject *kwds) return NULL; } - if (PyDataType_REFCHK(PyArray_DESCR(self))) { - PyErr_SetString(PyExc_RuntimeError, - "cannot call setfield on an object array"); - Py_DECREF(dtype); - return NULL; - } - if (PyArray_SetField(self, dtype, offset, value) < 0) { return NULL; } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index f37668de9..74c57e18a 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -809,6 +809,14 @@ class TestStructured(TestCase): t = [('a', '>i4'), ('b', '<f8'), ('c', 'i4')] assert_(not np.can_cast(a.dtype, t, casting=casting)) + def test_objview(self): + # https://github.com/numpy/numpy/issues/3286 + a = np.array([], dtype=[('a', 'f'), ('b', 'f'), ('c', 'O')]) + a[['a', 'b']] # TypeError? + + # https://github.com/numpy/numpy/issues/3253 + dat2 = np.zeros(3, [('A', 'i'), ('B', '|O')]) + new2 = dat2[['B', 'A']] # TypeError? class TestBool(TestCase): def test_test_interning(self): @@ -3577,8 +3585,9 @@ class TestRecord(TestCase): class TestView(TestCase): def test_basic(self): - x = np.array([(1, 2, 3, 4), (5, 6, 7, 8)], dtype=[('r', np.int8), ('g', np.int8), - ('b', np.int8), ('a', np.int8)]) + x = np.array([(1, 2, 3, 4), (5, 6, 7, 8)], + dtype=[('r', np.int8), ('g', np.int8), + ('b', np.int8), ('a', np.int8)]) # We must be specific about the endianness here: y = x.view(dtype='<i4') # ... and again without the keyword. @@ -5523,6 +5532,89 @@ class TestHashing(TestCase): x = np.array([]) self.assertFalse(isinstance(x, collections.Hashable)) +from numpy.core._internal import _view_is_safe + +class TestObjViewSafetyFuncs: + def test_view_safety(self): + psize = dtype('p').itemsize + + # creates dtype but with extra character code - for missing 'p' fields + def mtype(s): + n, offset, fields = 0, 0, [] + for c in s.split(','): #subarrays won't work + if c != '-': + fields.append(('f{0}'.format(n), c, offset)) + n += 1 + offset += dtype(c).itemsize if c != '-' else psize + + names, formats, offsets = zip(*fields) + return dtype({'names': names, 'formats': formats, + 'offsets': offsets, 'itemsize': offset}) + + # test nonequal itemsizes with objects: + # these should succeed: + _view_is_safe(dtype('O,p,O,p'), dtype('O,p,O,p,O,p')) + _view_is_safe(dtype('O,O'), dtype('O,O,O')) + # these should fail: + assert_raises(TypeError, _view_is_safe, dtype('O,O,p'), dtype('O,O')) + assert_raises(TypeError, _view_is_safe, dtype('O,O,p'), dtype('O,p')) + assert_raises(TypeError, _view_is_safe, dtype('O,O,p'), dtype('p,O')) + + # test nonequal itemsizes with missing fields: + # these should succeed: + _view_is_safe(mtype('-,p,-,p'), mtype('-,p,-,p,-,p')) + _view_is_safe(dtype('p,p'), dtype('p,p,p')) + # these should fail: + assert_raises(TypeError, _view_is_safe, mtype('p,p,-'), mtype('p,p')) + assert_raises(TypeError, _view_is_safe, mtype('p,p,-'), mtype('p,-')) + assert_raises(TypeError, _view_is_safe, mtype('p,p,-'), mtype('-,p')) + + # scans through positions at which we can view a type + def scanView(d1, otype): + goodpos = [] + for shift in range(d1.itemsize - dtype(otype).itemsize+1): + d2 = dtype({'names': ['f0'], 'formats': [otype], + 'offsets': [shift], 'itemsize': d1.itemsize}) + try: + _view_is_safe(d1, d2) + except TypeError: + pass + else: + goodpos.append(shift) + return goodpos + + # test partial overlap with object field + assert_equal(scanView(dtype('p,O,p,p,O,O'), 'p'), + [0] + list(range(2*psize, 3*psize+1))) + assert_equal(scanView(dtype('p,O,p,p,O,O'), 'O'), + [psize, 4*psize, 5*psize]) + + # test partial overlap with missing field + assert_equal(scanView(mtype('p,-,p,p,-,-'), 'p'), + [0] + list(range(2*psize, 3*psize+1))) + + # test nested structures with objects: + nestedO = dtype([('f0', 'p'), ('f1', 'p,O,p')]) + assert_equal(scanView(nestedO, 'p'), list(range(psize+1)) + [3*psize]) + assert_equal(scanView(nestedO, 'O'), [2*psize]) + + # test nested structures with missing fields: + nestedM = dtype([('f0', 'p'), ('f1', mtype('p,-,p'))]) + assert_equal(scanView(nestedM, 'p'), list(range(psize+1)) + [3*psize]) + + # test subarrays with objects + subarrayO = dtype('p,(2,3)O,p') + assert_equal(scanView(subarrayO, 'p'), [0, 7*psize]) + assert_equal(scanView(subarrayO, 'O'), + list(range(psize, 6*psize+1, psize))) + + #test dtype with overlapping fields + overlapped = dtype({'names': ['f0', 'f1', 'f2', 'f3'], + 'formats': ['p', 'p', 'p', 'p'], + 'offsets': [0, 1, 3*psize-1, 3*psize], + 'itemsize': 4*psize}) + assert_equal(scanView(overlapped, 'p'), [0, 1, 3*psize-1, 3*psize]) + if __name__ == "__main__": run_module_suite() diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py index 2f86b7ccb..b07b3c876 100644 --- a/numpy/core/tests/test_records.py +++ b/numpy/core/tests/test_records.py @@ -219,6 +219,15 @@ class TestRecord(TestCase): assert_equal(a, pickle.loads(pickle.dumps(a))) assert_equal(a[0], pickle.loads(pickle.dumps(a[0]))) + def test_objview_record(self): + # https://github.com/numpy/numpy/issues/2599 + dt = np.dtype([('foo', 'i8'), ('bar', 'O')]) + r = np.zeros((1,3), dtype=dt).view(np.recarray) + r.foo = np.array([1, 2, 3]) # TypeError? + + # https://github.com/numpy/numpy/issues/3256 + ra = np.recarray((2,), dtype=[('x', object), ('y', float), ('z', int)]) + ra[['x','y']] #TypeError? def test_find_duplicate(): l1 = [1, 2, 3, 4, 5, 6] diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index 9e8511a01..399470214 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -1739,10 +1739,9 @@ class TestRegression(TestCase): assert_equal(np.add.accumulate(x[:-1, 0]), []) def test_objectarray_setfield(self): - # Setfield directly manipulates the raw array data, - # so is invalid for object arrays. + # Setfield should not overwrite Object fields with non-Object data x = np.array([1, 2, 3], dtype=object) - assert_raises(RuntimeError, x.setfield, 4, np.int32, 0) + assert_raises(TypeError, x.setfield, 4, np.int32, 0) def test_setting_rank0_string(self): "Ticket #1736" diff --git a/numpy/lib/tests/test_recfunctions.py b/numpy/lib/tests/test_recfunctions.py index 51a2077eb..13e75cbd0 100644 --- a/numpy/lib/tests/test_recfunctions.py +++ b/numpy/lib/tests/test_recfunctions.py @@ -700,6 +700,36 @@ class TestJoinBy2(TestCase): assert_equal(test.dtype, control.dtype) assert_equal(test, control) +class TestAppendFieldsObj(TestCase): + """ + Test append_fields with arrays containing objects + """ + # https://github.com/numpy/numpy/issues/2346 + + def setUp(self): + from datetime import date + self.data = dict(obj=date(2000, 1, 1)) + + def test_append_to_objects(self): + "Test append_fields when the base array contains objects" + obj = self.data['obj'] + x = np.array([(obj, 1.), (obj, 2.)], + dtype=[('A', object), ('B', float)]) + y = np.array([10, 20], dtype=int) + test = append_fields(x, 'C', data=y, usemask=False) + control = np.array([(obj, 1.0, 10), (obj, 2.0, 20)], + dtype=[('A', object), ('B', float), ('C', int)]) + assert_equal(test, control) + + def test_append_with_objects(self): + "Test append_fields when the appended data contains objects" + obj = self.data['obj'] + x = np.array([(10, 1.), (20, 2.)], dtype=[('A', int), ('B', float)]) + y = np.array([obj, obj], dtype=object) + test = append_fields(x, 'C', data=y, dtypes=object, usemask=False) + control = np.array([(10, 1.0, obj), (20, 2.0, obj)], + dtype=[('A', int), ('B', float), ('C', object)]) + assert_equal(test, control) if __name__ == '__main__': run_module_suite() |