diff options
Diffstat (limited to 'numpy/core/_internal.py')
-rw-r--r-- | numpy/core/_internal.py | 168 |
1 files changed, 168 insertions, 0 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index d32f59390..e80c22dfe 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -305,6 +305,174 @@ def _index_fields(ary, fields): copy_dtype = {'names':view_dtype['names'], 'formats':view_dtype['formats']} return array(view, dtype=copy_dtype, copy=True) +def _get_all_field_offsets(dtype, base_offset=0): + """ Returns the types and offsets of all fields in a (possibly structured) + data type, including nested fields and subarrays. + + Parameters + ---------- + dtype : data-type + Data type to extract fields from. + base_offset : int, optional + Additional offset to add to all field offsets. + + Returns + ------- + fields : list of (data-type, int) pairs + A flat list of (dtype, byte offset) pairs. + + """ + fields = [] + if dtype.fields is not None: + for name in dtype.names: + sub_dtype = dtype.fields[name][0] + sub_offset = dtype.fields[name][1] + base_offset + fields.extend(_get_all_field_offsets(sub_dtype, sub_offset)) + else: + if dtype.shape: + sub_offsets = _get_all_field_offsets(dtype.base, base_offset) + count = 1 + for dim in dtype.shape: + count *= dim + fields.extend((typ, off + dtype.base.itemsize*j) + for j in range(count) for (typ, off) in sub_offsets) + else: + fields.append((dtype, base_offset)) + return fields + +def _check_field_overlap(new_fields, old_fields): + """ Perform object memory overlap tests for two data-types (see + _view_is_safe). + + This function checks that new fields only access memory contained in old + fields, and that non-object fields are not interpreted as objects and vice + versa. + + Parameters + ---------- + new_fields : list of (data-type, int) pairs + Flat list of (dtype, byte offset) pairs for the new data type, as + returned by _get_all_field_offsets. + old_fields: list of (data-type, int) pairs + Flat list of (dtype, byte offset) pairs for the old data type, as + returned by _get_all_field_offsets. + + Raises + ------ + TypeError + If the new fields are incompatible with the old fields + + """ + from .numerictypes import object_ + from .multiarray import dtype + + #first go byte by byte and check we do not access bytes not in old_fields + new_bytes = set() + for tp, off in new_fields: + new_bytes.update(set(range(off, off+tp.itemsize))) + old_bytes = set() + for tp, off in old_fields: + old_bytes.update(set(range(off, off+tp.itemsize))) + if new_bytes.difference(old_bytes): + raise TypeError("view would access data parent array doesn't own") + + #next check that we do not interpret non-Objects as Objects, and vv + obj_offsets = [off for (tp, off) in old_fields if tp.type is object_] + obj_size = dtype(object_).itemsize + + for fld_dtype, fld_offset in new_fields: + if fld_dtype.type is object_: + # check we do not create object views where + # there are no objects. + if fld_offset not in obj_offsets: + raise TypeError("cannot view non-Object data as Object type") + else: + # next check we do not create non-object views + # where there are already objects. + # see validate_object_field_overlap for a similar computation. + for obj_offset in obj_offsets: + if (fld_offset < obj_offset + obj_size and + obj_offset < fld_offset + fld_dtype.itemsize): + raise TypeError("cannot view Object as non-Object type") + +def _getfield_is_safe(oldtype, newtype, offset): + """ Checks safety of getfield for object arrays. + + As in _view_is_safe, we need to check that memory containing objects is not + reinterpreted as a non-object datatype and vice versa. + + Parameters + ---------- + oldtype : data-type + Data type of the original ndarray. + newtype : data-type + Data type of the field being accessed by ndarray.getfield + offset : int + Offset of the field being accessed by ndarray.getfield + + Raises + ------ + TypeError + If the field access is invalid + + """ + new_fields = _get_all_field_offsets(newtype, offset) + old_fields = _get_all_field_offsets(oldtype) + # raises if there is a problem + _check_field_overlap(new_fields, old_fields) + +def _view_is_safe(oldtype, newtype): + """ Checks safety of a view involving object arrays, for example when + doing:: + + np.zeros(10, dtype=oldtype).view(newtype) + + We need to check that + 1) No memory that is not an object will be interpreted as a object, + 2) No memory containing an object will be interpreted as an arbitrary type. + Both cases can cause segfaults, eg in the case the view is written to. + Strategy here is to also disallow views where newtype has any field in a + place oldtype doesn't. + + Parameters + ---------- + oldtype : data-type + Data type of original ndarray + newtype : data-type + Data type of the view + + Raises + ------ + TypeError + If the new type is incompatible with the old type. + + """ + new_fields = _get_all_field_offsets(newtype) + new_size = newtype.itemsize + + old_fields = _get_all_field_offsets(oldtype) + old_size = oldtype.itemsize + + # if the itemsizes are not equal, we need to check that all the + # 'tiled positions' of the object match up. Here, we allow + # for arbirary itemsizes (even those possibly disallowed + # due to stride/data length issues). + if old_size == new_size: + new_num = old_num = 1 + else: + gcd_new_old = _gcd(new_size, old_size) + new_num = old_size // gcd_new_old + old_num = new_size // gcd_new_old + + # get position of fields within the tiling + new_fieldtile = [(tp, off + new_size*j) + for j in range(new_num) for (tp, off) in new_fields] + old_fieldtile = [(tp, off + old_size*j) + for j in range(old_num) for (tp, off) in old_fields] + + # raises if there is a problem + _check_field_overlap(new_fieldtile, old_fieldtile) + # Given a string containing a PEP 3118 format specifier, # construct a Numpy dtype |