summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaime <jaime.frio@gmail.com>2015-06-05 08:43:34 -0700
committerJaime <jaime.frio@gmail.com>2015-06-05 08:43:34 -0700
commita99c77577850054c55850102132dcf3eb43dea6a (patch)
tree8e0fe904c161e1122f0b6d3a3012fc84148354df
parent30d755d8737505717d54ed32501261bb94130a7f (diff)
parente2cd6fa869cec6d92062fb687d8e6952c1202017 (diff)
downloadnumpy-a99c77577850054c55850102132dcf3eb43dea6a.tar.gz
Merge pull request #5548 from ahaldane/view_objects_owned
ENH: structured datatype safety checks
-rw-r--r--numpy/core/_internal.py168
-rw-r--r--numpy/core/src/multiarray/getset.c19
-rw-r--r--numpy/core/src/multiarray/methods.c43
-rw-r--r--numpy/core/tests/test_multiarray.py96
-rw-r--r--numpy/core/tests/test_records.py9
-rw-r--r--numpy/core/tests/test_regression.py5
-rw-r--r--numpy/lib/tests/test_recfunctions.py30
7 files changed, 332 insertions, 38 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index d32f59390..e80c22dfe 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -305,6 +305,174 @@ def _index_fields(ary, fields):
copy_dtype = {'names':view_dtype['names'], 'formats':view_dtype['formats']}
return array(view, dtype=copy_dtype, copy=True)
+def _get_all_field_offsets(dtype, base_offset=0):
+ """ Returns the types and offsets of all fields in a (possibly structured)
+ data type, including nested fields and subarrays.
+
+ Parameters
+ ----------
+ dtype : data-type
+ Data type to extract fields from.
+ base_offset : int, optional
+ Additional offset to add to all field offsets.
+
+ Returns
+ -------
+ fields : list of (data-type, int) pairs
+ A flat list of (dtype, byte offset) pairs.
+
+ """
+ fields = []
+ if dtype.fields is not None:
+ for name in dtype.names:
+ sub_dtype = dtype.fields[name][0]
+ sub_offset = dtype.fields[name][1] + base_offset
+ fields.extend(_get_all_field_offsets(sub_dtype, sub_offset))
+ else:
+ if dtype.shape:
+ sub_offsets = _get_all_field_offsets(dtype.base, base_offset)
+ count = 1
+ for dim in dtype.shape:
+ count *= dim
+ fields.extend((typ, off + dtype.base.itemsize*j)
+ for j in range(count) for (typ, off) in sub_offsets)
+ else:
+ fields.append((dtype, base_offset))
+ return fields
+
+def _check_field_overlap(new_fields, old_fields):
+ """ Perform object memory overlap tests for two data-types (see
+ _view_is_safe).
+
+ This function checks that new fields only access memory contained in old
+ fields, and that non-object fields are not interpreted as objects and vice
+ versa.
+
+ Parameters
+ ----------
+ new_fields : list of (data-type, int) pairs
+ Flat list of (dtype, byte offset) pairs for the new data type, as
+ returned by _get_all_field_offsets.
+ old_fields: list of (data-type, int) pairs
+ Flat list of (dtype, byte offset) pairs for the old data type, as
+ returned by _get_all_field_offsets.
+
+ Raises
+ ------
+ TypeError
+ If the new fields are incompatible with the old fields
+
+ """
+ from .numerictypes import object_
+ from .multiarray import dtype
+
+ #first go byte by byte and check we do not access bytes not in old_fields
+ new_bytes = set()
+ for tp, off in new_fields:
+ new_bytes.update(set(range(off, off+tp.itemsize)))
+ old_bytes = set()
+ for tp, off in old_fields:
+ old_bytes.update(set(range(off, off+tp.itemsize)))
+ if new_bytes.difference(old_bytes):
+ raise TypeError("view would access data parent array doesn't own")
+
+ #next check that we do not interpret non-Objects as Objects, and vv
+ obj_offsets = [off for (tp, off) in old_fields if tp.type is object_]
+ obj_size = dtype(object_).itemsize
+
+ for fld_dtype, fld_offset in new_fields:
+ if fld_dtype.type is object_:
+ # check we do not create object views where
+ # there are no objects.
+ if fld_offset not in obj_offsets:
+ raise TypeError("cannot view non-Object data as Object type")
+ else:
+ # next check we do not create non-object views
+ # where there are already objects.
+ # see validate_object_field_overlap for a similar computation.
+ for obj_offset in obj_offsets:
+ if (fld_offset < obj_offset + obj_size and
+ obj_offset < fld_offset + fld_dtype.itemsize):
+ raise TypeError("cannot view Object as non-Object type")
+
+def _getfield_is_safe(oldtype, newtype, offset):
+ """ Checks safety of getfield for object arrays.
+
+ As in _view_is_safe, we need to check that memory containing objects is not
+ reinterpreted as a non-object datatype and vice versa.
+
+ Parameters
+ ----------
+ oldtype : data-type
+ Data type of the original ndarray.
+ newtype : data-type
+ Data type of the field being accessed by ndarray.getfield
+ offset : int
+ Offset of the field being accessed by ndarray.getfield
+
+ Raises
+ ------
+ TypeError
+ If the field access is invalid
+
+ """
+ new_fields = _get_all_field_offsets(newtype, offset)
+ old_fields = _get_all_field_offsets(oldtype)
+ # raises if there is a problem
+ _check_field_overlap(new_fields, old_fields)
+
+def _view_is_safe(oldtype, newtype):
+ """ Checks safety of a view involving object arrays, for example when
+ doing::
+
+ np.zeros(10, dtype=oldtype).view(newtype)
+
+ We need to check that
+ 1) No memory that is not an object will be interpreted as a object,
+ 2) No memory containing an object will be interpreted as an arbitrary type.
+ Both cases can cause segfaults, eg in the case the view is written to.
+ Strategy here is to also disallow views where newtype has any field in a
+ place oldtype doesn't.
+
+ Parameters
+ ----------
+ oldtype : data-type
+ Data type of original ndarray
+ newtype : data-type
+ Data type of the view
+
+ Raises
+ ------
+ TypeError
+ If the new type is incompatible with the old type.
+
+ """
+ new_fields = _get_all_field_offsets(newtype)
+ new_size = newtype.itemsize
+
+ old_fields = _get_all_field_offsets(oldtype)
+ old_size = oldtype.itemsize
+
+ # if the itemsizes are not equal, we need to check that all the
+ # 'tiled positions' of the object match up. Here, we allow
+ # for arbirary itemsizes (even those possibly disallowed
+ # due to stride/data length issues).
+ if old_size == new_size:
+ new_num = old_num = 1
+ else:
+ gcd_new_old = _gcd(new_size, old_size)
+ new_num = old_size // gcd_new_old
+ old_num = new_size // gcd_new_old
+
+ # get position of fields within the tiling
+ new_fieldtile = [(tp, off + new_size*j)
+ for j in range(new_num) for (tp, off) in new_fields]
+ old_fieldtile = [(tp, off + old_size*j)
+ for j in range(old_num) for (tp, off) in old_fields]
+
+ # raises if there is a problem
+ _check_field_overlap(new_fieldtile, old_fieldtile)
+
# Given a string containing a PEP 3118 format specifier,
# construct a Numpy dtype
diff --git a/numpy/core/src/multiarray/getset.c b/numpy/core/src/multiarray/getset.c
index cbc798273..9ba12b092 100644
--- a/numpy/core/src/multiarray/getset.c
+++ b/numpy/core/src/multiarray/getset.c
@@ -9,8 +9,8 @@
#include "numpy/arrayobject.h"
#include "npy_config.h"
-
#include "npy_pycompat.h"
+#include "npy_import.h"
#include "common.h"
#include "scalartypes.h"
@@ -434,6 +434,13 @@ array_descr_set(PyArrayObject *self, PyObject *arg)
npy_intp newdim;
int i;
char *msg = "new type not compatible with array.";
+ PyObject *safe;
+ static PyObject *checkfunc = NULL;
+
+ npy_cache_pyfunc("numpy.core._internal", "_view_is_safe", &checkfunc);
+ if (checkfunc == NULL) {
+ return -1;
+ }
if (arg == NULL) {
PyErr_SetString(PyExc_AttributeError,
@@ -448,15 +455,13 @@ array_descr_set(PyArrayObject *self, PyObject *arg)
return -1;
}
- if (PyDataType_FLAGCHK(newtype, NPY_ITEM_HASOBJECT) ||
- PyDataType_FLAGCHK(newtype, NPY_ITEM_IS_POINTER) ||
- PyDataType_FLAGCHK(PyArray_DESCR(self), NPY_ITEM_HASOBJECT) ||
- PyDataType_FLAGCHK(PyArray_DESCR(self), NPY_ITEM_IS_POINTER)) {
- PyErr_SetString(PyExc_TypeError,
- "Cannot change data-type for object array.");
+ /* check that we are not reinterpreting memory containing Objects */
+ safe = PyObject_CallFunction(checkfunc, "OO", PyArray_DESCR(self), newtype);
+ if (safe == NULL) {
Py_DECREF(newtype);
return -1;
}
+ Py_DECREF(safe);
if (newtype->elsize == 0) {
/* Allow a void view */
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index a51453dd1..d06e4a512 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -10,6 +10,7 @@
#include "npy_config.h"
#include "npy_pycompat.h"
+#include "npy_import.h"
#include "ufunc_override.h"
#include "common.h"
#include "ctors.h"
@@ -358,15 +359,23 @@ NPY_NO_EXPORT PyObject *
PyArray_GetField(PyArrayObject *self, PyArray_Descr *typed, int offset)
{
PyObject *ret = NULL;
+ PyObject *safe;
+ static PyObject *checkfunc = NULL;
- if (offset < 0 || (offset + typed->elsize) > PyArray_DESCR(self)->elsize) {
- PyErr_Format(PyExc_ValueError,
- "Need 0 <= offset <= %d for requested type "
- "but received offset = %d",
- PyArray_DESCR(self)->elsize-typed->elsize, offset);
- Py_DECREF(typed);
+ npy_cache_pyfunc("numpy.core._internal", "_getfield_is_safe", &checkfunc);
+ if (checkfunc == NULL) {
return NULL;
}
+
+ /* check that we are not reinterpreting memory containing Objects */
+ /* only returns True or raises */
+ safe = PyObject_CallFunction(checkfunc, "OOi", PyArray_DESCR(self),
+ typed, offset);
+ if (safe == NULL) {
+ return NULL;
+ }
+ Py_DECREF(safe);
+
ret = PyArray_NewFromDescr(Py_TYPE(self),
typed,
PyArray_NDIM(self), PyArray_DIMS(self),
@@ -417,23 +426,12 @@ PyArray_SetField(PyArrayObject *self, PyArray_Descr *dtype,
PyObject *ret = NULL;
int retval = 0;
- if (offset < 0 || (offset + dtype->elsize) > PyArray_DESCR(self)->elsize) {
- PyErr_Format(PyExc_ValueError,
- "Need 0 <= offset <= %d for requested type "
- "but received offset = %d",
- PyArray_DESCR(self)->elsize-dtype->elsize, offset);
- Py_DECREF(dtype);
- return -1;
- }
- ret = PyArray_NewFromDescr(Py_TYPE(self),
- dtype, PyArray_NDIM(self), PyArray_DIMS(self),
- PyArray_STRIDES(self), PyArray_BYTES(self) + offset,
- PyArray_FLAGS(self), (PyObject *)self);
+ /* getfield returns a view we can write to */
+ ret = PyArray_GetField(self, dtype, offset);
if (ret == NULL) {
return -1;
}
- PyArray_UpdateFlags((PyArrayObject *)ret, NPY_ARRAY_UPDATE_ALL);
retval = PyArray_CopyObject((PyArrayObject *)ret, val);
Py_DECREF(ret);
return retval;
@@ -455,13 +453,6 @@ array_setfield(PyArrayObject *self, PyObject *args, PyObject *kwds)
return NULL;
}
- if (PyDataType_REFCHK(PyArray_DESCR(self))) {
- PyErr_SetString(PyExc_RuntimeError,
- "cannot call setfield on an object array");
- Py_DECREF(dtype);
- return NULL;
- }
-
if (PyArray_SetField(self, dtype, offset, value) < 0) {
return NULL;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index f37668de9..74c57e18a 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -809,6 +809,14 @@ class TestStructured(TestCase):
t = [('a', '>i4'), ('b', '<f8'), ('c', 'i4')]
assert_(not np.can_cast(a.dtype, t, casting=casting))
+ def test_objview(self):
+ # https://github.com/numpy/numpy/issues/3286
+ a = np.array([], dtype=[('a', 'f'), ('b', 'f'), ('c', 'O')])
+ a[['a', 'b']] # TypeError?
+
+ # https://github.com/numpy/numpy/issues/3253
+ dat2 = np.zeros(3, [('A', 'i'), ('B', '|O')])
+ new2 = dat2[['B', 'A']] # TypeError?
class TestBool(TestCase):
def test_test_interning(self):
@@ -3577,8 +3585,9 @@ class TestRecord(TestCase):
class TestView(TestCase):
def test_basic(self):
- x = np.array([(1, 2, 3, 4), (5, 6, 7, 8)], dtype=[('r', np.int8), ('g', np.int8),
- ('b', np.int8), ('a', np.int8)])
+ x = np.array([(1, 2, 3, 4), (5, 6, 7, 8)],
+ dtype=[('r', np.int8), ('g', np.int8),
+ ('b', np.int8), ('a', np.int8)])
# We must be specific about the endianness here:
y = x.view(dtype='<i4')
# ... and again without the keyword.
@@ -5523,6 +5532,89 @@ class TestHashing(TestCase):
x = np.array([])
self.assertFalse(isinstance(x, collections.Hashable))
+from numpy.core._internal import _view_is_safe
+
+class TestObjViewSafetyFuncs:
+ def test_view_safety(self):
+ psize = dtype('p').itemsize
+
+ # creates dtype but with extra character code - for missing 'p' fields
+ def mtype(s):
+ n, offset, fields = 0, 0, []
+ for c in s.split(','): #subarrays won't work
+ if c != '-':
+ fields.append(('f{0}'.format(n), c, offset))
+ n += 1
+ offset += dtype(c).itemsize if c != '-' else psize
+
+ names, formats, offsets = zip(*fields)
+ return dtype({'names': names, 'formats': formats,
+ 'offsets': offsets, 'itemsize': offset})
+
+ # test nonequal itemsizes with objects:
+ # these should succeed:
+ _view_is_safe(dtype('O,p,O,p'), dtype('O,p,O,p,O,p'))
+ _view_is_safe(dtype('O,O'), dtype('O,O,O'))
+ # these should fail:
+ assert_raises(TypeError, _view_is_safe, dtype('O,O,p'), dtype('O,O'))
+ assert_raises(TypeError, _view_is_safe, dtype('O,O,p'), dtype('O,p'))
+ assert_raises(TypeError, _view_is_safe, dtype('O,O,p'), dtype('p,O'))
+
+ # test nonequal itemsizes with missing fields:
+ # these should succeed:
+ _view_is_safe(mtype('-,p,-,p'), mtype('-,p,-,p,-,p'))
+ _view_is_safe(dtype('p,p'), dtype('p,p,p'))
+ # these should fail:
+ assert_raises(TypeError, _view_is_safe, mtype('p,p,-'), mtype('p,p'))
+ assert_raises(TypeError, _view_is_safe, mtype('p,p,-'), mtype('p,-'))
+ assert_raises(TypeError, _view_is_safe, mtype('p,p,-'), mtype('-,p'))
+
+ # scans through positions at which we can view a type
+ def scanView(d1, otype):
+ goodpos = []
+ for shift in range(d1.itemsize - dtype(otype).itemsize+1):
+ d2 = dtype({'names': ['f0'], 'formats': [otype],
+ 'offsets': [shift], 'itemsize': d1.itemsize})
+ try:
+ _view_is_safe(d1, d2)
+ except TypeError:
+ pass
+ else:
+ goodpos.append(shift)
+ return goodpos
+
+ # test partial overlap with object field
+ assert_equal(scanView(dtype('p,O,p,p,O,O'), 'p'),
+ [0] + list(range(2*psize, 3*psize+1)))
+ assert_equal(scanView(dtype('p,O,p,p,O,O'), 'O'),
+ [psize, 4*psize, 5*psize])
+
+ # test partial overlap with missing field
+ assert_equal(scanView(mtype('p,-,p,p,-,-'), 'p'),
+ [0] + list(range(2*psize, 3*psize+1)))
+
+ # test nested structures with objects:
+ nestedO = dtype([('f0', 'p'), ('f1', 'p,O,p')])
+ assert_equal(scanView(nestedO, 'p'), list(range(psize+1)) + [3*psize])
+ assert_equal(scanView(nestedO, 'O'), [2*psize])
+
+ # test nested structures with missing fields:
+ nestedM = dtype([('f0', 'p'), ('f1', mtype('p,-,p'))])
+ assert_equal(scanView(nestedM, 'p'), list(range(psize+1)) + [3*psize])
+
+ # test subarrays with objects
+ subarrayO = dtype('p,(2,3)O,p')
+ assert_equal(scanView(subarrayO, 'p'), [0, 7*psize])
+ assert_equal(scanView(subarrayO, 'O'),
+ list(range(psize, 6*psize+1, psize)))
+
+ #test dtype with overlapping fields
+ overlapped = dtype({'names': ['f0', 'f1', 'f2', 'f3'],
+ 'formats': ['p', 'p', 'p', 'p'],
+ 'offsets': [0, 1, 3*psize-1, 3*psize],
+ 'itemsize': 4*psize})
+ assert_equal(scanView(overlapped, 'p'), [0, 1, 3*psize-1, 3*psize])
+
if __name__ == "__main__":
run_module_suite()
diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py
index 2f86b7ccb..b07b3c876 100644
--- a/numpy/core/tests/test_records.py
+++ b/numpy/core/tests/test_records.py
@@ -219,6 +219,15 @@ class TestRecord(TestCase):
assert_equal(a, pickle.loads(pickle.dumps(a)))
assert_equal(a[0], pickle.loads(pickle.dumps(a[0])))
+ def test_objview_record(self):
+ # https://github.com/numpy/numpy/issues/2599
+ dt = np.dtype([('foo', 'i8'), ('bar', 'O')])
+ r = np.zeros((1,3), dtype=dt).view(np.recarray)
+ r.foo = np.array([1, 2, 3]) # TypeError?
+
+ # https://github.com/numpy/numpy/issues/3256
+ ra = np.recarray((2,), dtype=[('x', object), ('y', float), ('z', int)])
+ ra[['x','y']] #TypeError?
def test_find_duplicate():
l1 = [1, 2, 3, 4, 5, 6]
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index 9e8511a01..399470214 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -1739,10 +1739,9 @@ class TestRegression(TestCase):
assert_equal(np.add.accumulate(x[:-1, 0]), [])
def test_objectarray_setfield(self):
- # Setfield directly manipulates the raw array data,
- # so is invalid for object arrays.
+ # Setfield should not overwrite Object fields with non-Object data
x = np.array([1, 2, 3], dtype=object)
- assert_raises(RuntimeError, x.setfield, 4, np.int32, 0)
+ assert_raises(TypeError, x.setfield, 4, np.int32, 0)
def test_setting_rank0_string(self):
"Ticket #1736"
diff --git a/numpy/lib/tests/test_recfunctions.py b/numpy/lib/tests/test_recfunctions.py
index 51a2077eb..13e75cbd0 100644
--- a/numpy/lib/tests/test_recfunctions.py
+++ b/numpy/lib/tests/test_recfunctions.py
@@ -700,6 +700,36 @@ class TestJoinBy2(TestCase):
assert_equal(test.dtype, control.dtype)
assert_equal(test, control)
+class TestAppendFieldsObj(TestCase):
+ """
+ Test append_fields with arrays containing objects
+ """
+ # https://github.com/numpy/numpy/issues/2346
+
+ def setUp(self):
+ from datetime import date
+ self.data = dict(obj=date(2000, 1, 1))
+
+ def test_append_to_objects(self):
+ "Test append_fields when the base array contains objects"
+ obj = self.data['obj']
+ x = np.array([(obj, 1.), (obj, 2.)],
+ dtype=[('A', object), ('B', float)])
+ y = np.array([10, 20], dtype=int)
+ test = append_fields(x, 'C', data=y, usemask=False)
+ control = np.array([(obj, 1.0, 10), (obj, 2.0, 20)],
+ dtype=[('A', object), ('B', float), ('C', int)])
+ assert_equal(test, control)
+
+ def test_append_with_objects(self):
+ "Test append_fields when the appended data contains objects"
+ obj = self.data['obj']
+ x = np.array([(10, 1.), (20, 2.)], dtype=[('A', int), ('B', float)])
+ y = np.array([obj, obj], dtype=object)
+ test = append_fields(x, 'C', data=y, dtypes=object, usemask=False)
+ control = np.array([(10, 1.0, obj), (20, 2.0, obj)],
+ dtype=[('A', int), ('B', float), ('C', object)])
+ assert_equal(test, control)
if __name__ == '__main__':
run_module_suite()