summaryrefslogtreecommitdiff
path: root/numpy/core/_internal.py
diff options
context:
space:
mode:
authorAllan Haldane <allan.haldane@gmail.com>2015-03-05 16:58:18 -0500
committerAllan Haldane <allan.haldane@gmail.com>2015-06-17 13:51:43 -0400
commit3c1a13dea6a7e189675977ad65ea230ce4816061 (patch)
treeb9a55bd90db5f57470fdd4d3267bc9862e5dbae1 /numpy/core/_internal.py
parent8c86a0a879a9f6d8bc9b225e95512fd7f2fca964 (diff)
downloadnumpy-3c1a13dea6a7e189675977ad65ea230ce4816061.tar.gz
ENH: simplify field indexing of structured arrays
This commit simplifies the code in array_subscript and array_assign_subscript related to field access. This fixes #4806, and also removes a potential segfault, e.g. if the array is indexed using a sequence-like object that raises an exception in getitem. It also fixes #5631, related to the creation of structured dtypes with no fields (an unusual and probably useless edge case), and moves all imports in _internal.py to the top. Fixes #4806. Fixes #5631.
Diffstat (limited to 'numpy/core/_internal.py')
-rw-r--r--numpy/core/_internal.py75
1 files changed, 49 insertions, 26 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index e80c22dfe..a20bf10e4 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -10,7 +10,10 @@ import re
import sys
import warnings
-from numpy.compat import asbytes, bytes
+from numpy.compat import asbytes, bytes, basestring
+from .multiarray import dtype, array, ndarray
+import ctypes
+from .numerictypes import object_
if (sys.byteorder == 'little'):
_nbo = asbytes('<')
@@ -18,7 +21,6 @@ else:
_nbo = asbytes('>')
def _makenames_list(adict, align):
- from .multiarray import dtype
allfields = []
fnames = list(adict.keys())
for fname in fnames:
@@ -52,7 +54,6 @@ def _makenames_list(adict, align):
# a dictionary without "names" and "formats"
# fields is used as a data-type descriptor.
def _usefields(adict, align):
- from .multiarray import dtype
try:
names = adict[-1]
except KeyError:
@@ -130,7 +131,6 @@ def _array_descr(descriptor):
# so don't remove the name here, or you'll
# break backward compatibilty.
def _reconstruct(subtype, shape, dtype):
- from .multiarray import ndarray
return ndarray.__new__(subtype, shape, dtype)
@@ -194,12 +194,10 @@ def _commastring(astr):
return result
def _getintp_ctype():
- from .multiarray import dtype
val = _getintp_ctype.cache
if val is not None:
return val
char = dtype('p').char
- import ctypes
if (char == 'i'):
val = ctypes.c_int
elif char == 'l':
@@ -224,7 +222,6 @@ class _missing_ctypes(object):
class _ctypes(object):
def __init__(self, array, ptr=None):
try:
- import ctypes
self._ctypes = ctypes
except ImportError:
self._ctypes = _missing_ctypes()
@@ -287,23 +284,55 @@ def _newnames(datatype, order):
return tuple(list(order) + nameslist)
raise ValueError("unsupported order value: %s" % (order,))
-# Given an array with fields and a sequence of field names
-# construct a new array with just those fields copied over
-def _index_fields(ary, fields):
- from .multiarray import empty, dtype, array
+def _index_fields(ary, names):
+ """ Given a structured array and a sequence of field names
+ construct new array with just those fields.
+
+ Parameters
+ ----------
+ ary : ndarray
+ Structured array being subscripted
+ names : string or list of strings
+ Either a single field name, or a list of field names
+
+ Returns
+ -------
+ sub_ary : ndarray
+ If `names` is a single field name, the return value is identical to
+ ary.getfield, a writeable view into `ary`. If `names` is a list of
+ field names the return value is a copy of `ary` containing only those
+ fields. This is planned to return a view in the future.
+
+ Raises
+ ------
+ ValueError
+ If `ary` does not contain a field given in `names`.
+
+ """
dt = ary.dtype
- names = [name for name in fields if name in dt.names]
- formats = [dt.fields[name][0] for name in fields if name in dt.names]
- offsets = [dt.fields[name][1] for name in fields if name in dt.names]
+ #use getfield to index a single field
+ if isinstance(names, basestring):
+ try:
+ return ary.getfield(dt.fields[names][0], dt.fields[names][1])
+ except KeyError:
+ raise ValueError("no field of name %s" % names)
+
+ for name in names:
+ if name not in dt.fields:
+ raise ValueError("no field of name %s" % name)
- view_dtype = {'names':names, 'formats':formats, 'offsets':offsets, 'itemsize':dt.itemsize}
- view = ary.view(dtype=view_dtype)
+ formats = [dt.fields[name][0] for name in names]
+ offsets = [dt.fields[name][1] for name in names]
+
+ view_dtype = {'names': names, 'formats': formats,
+ 'offsets': offsets, 'itemsize': dt.itemsize}
+
+ # return copy for now (future plan to return ary.view(dtype=view_dtype))
+ copy_dtype = {'names': view_dtype['names'],
+ 'formats': view_dtype['formats']}
+ return array(ary.view(dtype=view_dtype), dtype=copy_dtype, copy=True)
- # Return a copy for now until behavior is fully deprecated
- # in favor of returning view
- copy_dtype = {'names':view_dtype['names'], 'formats':view_dtype['formats']}
- return array(view, dtype=copy_dtype, copy=True)
def _get_all_field_offsets(dtype, base_offset=0):
""" Returns the types and offsets of all fields in a (possibly structured)
@@ -363,8 +392,6 @@ def _check_field_overlap(new_fields, old_fields):
If the new fields are incompatible with the old fields
"""
- from .numerictypes import object_
- from .multiarray import dtype
#first go byte by byte and check we do not access bytes not in old_fields
new_bytes = set()
@@ -527,8 +554,6 @@ _pep3118_standard_map = {
_pep3118_standard_typechars = ''.join(_pep3118_standard_map.keys())
def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
- from numpy.core.multiarray import dtype
-
fields = {}
offset = 0
explicit_name = False
@@ -694,8 +719,6 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
def _add_trailing_padding(value, padding):
"""Inject the specified number of padding bytes at the end of a dtype"""
- from numpy.core.multiarray import dtype
-
if value.fields is None:
vfields = {'f0': (value, 0)}
else: