author | Allan Haldane <allan.haldane@gmail.com> | 2015-03-05 16:58:18 -0500
committer | Allan Haldane <allan.haldane@gmail.com> | 2015-06-17 13:51:43 -0400
commit | 3c1a13dea6a7e189675977ad65ea230ce4816061 (patch)
tree | b9a55bd90db5f57470fdd4d3267bc9862e5dbae1 /numpy/core/_internal.py
parent | 8c86a0a879a9f6d8bc9b225e95512fd7f2fca964 (diff)
download | numpy-3c1a13dea6a7e189675977ad65ea230ce4816061.tar.gz
ENH: simplify field indexing of structured arrays
This commit simplifies the code in array_subscript and
array_assign_subscript related to field access. This fixes #4806,
and also removes a potential segfault, e.g. if the array is indexed using
a sequence-like object that raises an exception in __getitem__.
Also fixes #5631, related to creation of structured dtypes
with no fields (an unusual and probably useless edge case).
Also moves all imports in _internal.py to the top.
Fixes #4806.
Fixes #5631.
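
For context (not part of the commit), here is a minimal Python sketch of the structured-array field-indexing paths this change touches. The field names and the failing-index class below are made up for illustration; the behavior notes follow the commit message and the new `_index_fields` docstring.

```python
import numpy as np

# A small structured array with hypothetical fields.
a = np.zeros(3, dtype=[('x', 'f8'), ('y', 'i4')])

# Single-field indexing yields a writeable view into `a`.
a['x'][:] = 1.0

# Multi-field indexing yields a copy restricted to those fields
# (returning a view instead is only planned for a future release).
sub = a[['x', 'y']]
print(sub.dtype.names)        # ('x', 'y')

# A sequence-like index whose __getitem__ raises is the kind of input
# the commit message says could previously segfault; after the fix it
# should surface as an ordinary Python exception instead.
class RaisingIndex(object):
    def __len__(self):
        return 1
    def __getitem__(self, i):
        raise RuntimeError("broken __getitem__")

try:
    a[RaisingIndex()]
except Exception as exc:
    print(type(exc).__name__)
```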
Diffstat (limited to 'numpy/core/_internal.py')
-rw-r--r-- | numpy/core/_internal.py | 75 |
1 file changed, 49 insertions, 26 deletions
```diff
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index e80c22dfe..a20bf10e4 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -10,7 +10,10 @@ import re
 import sys
 import warnings
 
-from numpy.compat import asbytes, bytes
+from numpy.compat import asbytes, bytes, basestring
+from .multiarray import dtype, array, ndarray
+import ctypes
+from .numerictypes import object_
 
 if (sys.byteorder == 'little'):
     _nbo = asbytes('<')
@@ -18,7 +21,6 @@ else:
     _nbo = asbytes('>')
 
 def _makenames_list(adict, align):
-    from .multiarray import dtype
     allfields = []
     fnames = list(adict.keys())
     for fname in fnames:
@@ -52,7 +54,6 @@ def _makenames_list(adict, align):
 # a dictionary without "names" and "formats"
 # fields is used as a data-type descriptor.
 def _usefields(adict, align):
-    from .multiarray import dtype
     try:
         names = adict[-1]
     except KeyError:
@@ -130,7 +131,6 @@ def _array_descr(descriptor):
 # so don't remove the name here, or you'll
 # break backward compatibilty.
 def _reconstruct(subtype, shape, dtype):
-    from .multiarray import ndarray
     return ndarray.__new__(subtype, shape, dtype)
 
 
@@ -194,12 +194,10 @@ def _commastring(astr):
     return result
 
 def _getintp_ctype():
-    from .multiarray import dtype
     val = _getintp_ctype.cache
     if val is not None:
         return val
     char = dtype('p').char
-    import ctypes
     if (char == 'i'):
         val = ctypes.c_int
     elif char == 'l':
@@ -224,7 +222,6 @@ class _missing_ctypes(object):
 class _ctypes(object):
     def __init__(self, array, ptr=None):
         try:
-            import ctypes
             self._ctypes = ctypes
         except ImportError:
             self._ctypes = _missing_ctypes()
@@ -287,23 +284,55 @@ def _newnames(datatype, order):
         return tuple(list(order) + nameslist)
     raise ValueError("unsupported order value: %s" % (order,))
 
-# Given an array with fields and a sequence of field names
-# construct a new array with just those fields copied over
-def _index_fields(ary, fields):
-    from .multiarray import empty, dtype, array
+def _index_fields(ary, names):
+    """ Given a structured array and a sequence of field names
+    construct new array with just those fields.
+
+    Parameters
+    ----------
+    ary : ndarray
+        Structured array being subscripted
+    names : string or list of strings
+        Either a single field name, or a list of field names
+
+    Returns
+    -------
+    sub_ary : ndarray
+        If `names` is a single field name, the return value is identical to
+        ary.getfield, a writeable view into `ary`. If `names` is a list of
+        field names the return value is a copy of `ary` containing only those
+        fields. This is planned to return a view in the future.
+
+    Raises
+    ------
+    ValueError
+        If `ary` does not contain a field given in `names`.
+
+    """
     dt = ary.dtype
 
-    names = [name for name in fields if name in dt.names]
-    formats = [dt.fields[name][0] for name in fields if name in dt.names]
-    offsets = [dt.fields[name][1] for name in fields if name in dt.names]
+    #use getfield to index a single field
+    if isinstance(names, basestring):
+        try:
+            return ary.getfield(dt.fields[names][0], dt.fields[names][1])
+        except KeyError:
+            raise ValueError("no field of name %s" % names)
+
+    for name in names:
+        if name not in dt.fields:
+            raise ValueError("no field of name %s" % name)
 
-    view_dtype = {'names':names, 'formats':formats, 'offsets':offsets, 'itemsize':dt.itemsize}
-    view = ary.view(dtype=view_dtype)
+    formats = [dt.fields[name][0] for name in names]
+    offsets = [dt.fields[name][1] for name in names]
+
+    view_dtype = {'names': names, 'formats': formats,
+                  'offsets': offsets, 'itemsize': dt.itemsize}
+
+    # return copy for now (future plan to return ary.view(dtype=view_dtype))
+    copy_dtype = {'names': view_dtype['names'],
+                  'formats': view_dtype['formats']}
+    return array(ary.view(dtype=view_dtype), dtype=copy_dtype, copy=True)
 
-    # Return a copy for now until behavior is fully deprecated
-    # in favor of returning view
-    copy_dtype = {'names':view_dtype['names'], 'formats':view_dtype['formats']}
-    return array(view, dtype=copy_dtype, copy=True)
 
 def _get_all_field_offsets(dtype, base_offset=0):
     """ Returns the types and offsets of all fields in a (possibly structured)
@@ -363,8 +392,6 @@ def _check_field_overlap(new_fields, old_fields):
         If the new fields are incompatible with the old fields
 
     """
-    from .numerictypes import object_
-    from .multiarray import dtype
 
     #first go byte by byte and check we do not access bytes not in old_fields
     new_bytes = set()
@@ -527,8 +554,6 @@ _pep3118_standard_map = {
 _pep3118_standard_typechars = ''.join(_pep3118_standard_map.keys())
 
 def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
-    from numpy.core.multiarray import dtype
-
     fields = {}
     offset = 0
     explicit_name = False
@@ -694,8 +719,6 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
 
 def _add_trailing_padding(value, padding):
     """Inject the specified number of padding bytes at the end of a dtype"""
-    from numpy.core.multiarray import dtype
-
     if value.fields is None:
         vfields = {'f0': (value, 0)}
     else:
```
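
A rough usage sketch of the rewritten `_index_fields` helper, assuming the patched `numpy/core/_internal.py` from this commit (the function is private and is presumably only reached from the C-level `array_subscript` path; the array and field names below are made up):

```python
import numpy as np
from numpy.core._internal import _index_fields

a = np.zeros(4, dtype=[('a', 'i4'), ('b', 'f8')])

# A single field name goes through ary.getfield and returns a writeable view.
v = _index_fields(a, 'a')
v[:] = 7
print(a['a'])            # [7 7 7 7]

# A list of names returns a copy restricted to those fields
# (the docstring notes a view is planned for the future).
sub = _index_fields(a, ['a', 'b'])
print(sub.dtype.names)   # ('a', 'b')

# An unknown field name now raises ValueError instead of being silently
# dropped, as the old list-comprehension filtering did.
try:
    _index_fields(a, ['a', 'nope'])
except ValueError as exc:
    print(exc)           # no field of name nope
```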