summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/release/1.17.0-notes.rst17
-rw-r--r--numpy/core/src/multiarray/descriptor.c117
-rw-r--r--numpy/core/src/multiarray/descriptor.h12
-rw-r--r--numpy/core/src/multiarray/mapping.c100
-rw-r--r--numpy/core/tests/test_dtype.py55
-rw-r--r--numpy/core/tests/test_records.py2
6 files changed, 216 insertions, 87 deletions
diff --git a/doc/release/1.17.0-notes.rst b/doc/release/1.17.0-notes.rst
index 1fa6fb9d6..0007fe1a6 100644
--- a/doc/release/1.17.0-notes.rst
+++ b/doc/release/1.17.0-notes.rst
@@ -240,6 +240,18 @@ Floating point scalars implement ``as_integer_ratio`` to match the builtin float
This returns a (numerator, denominator) pair, which can be used to construct a
`fractions.Fraction`.
+structured ``dtype`` objects can be indexed with multiple fields names
+----------------------------------------------------------------------
+``arr.dtype[['a', 'b']]`` now returns a dtype that is equivalent to
+``arr[['a', 'b']].dtype``, for consistency with
+``arr.dtype['a'] == arr['a'].dtype``.
+
+Like the dtype of structured arrays indexed with a list of fields, this dtype
+has the same `itemsize` as the original, but only keeps a subset of the fields.
+
+This means that `arr[['a', 'b']]` and ``arr.view(arr.dtype[['a', 'b']])`` are
+equivalent.
+
``.npy`` files support unicode field names
------------------------------------------
A new format version of 3.0 has been introduced, which enables structured types
@@ -435,4 +447,9 @@ Additionally, there are some corner cases with behavior changes:
------------------------------------------------------
The interface may use an ``offset`` value that was mistakenly ignored.
+Structured arrays indexed with non-existent fields raise ``KeyError`` not ``ValueError``
+----------------------------------------------------------------------------------------
+``arr['bad_field']`` on a structured type raises ``KeyError``, for consistency
+with ``dict['bad_field']``.
+
.. _`NEP 18` : http://www.numpy.org/neps/nep-0018-array-function-protocol.html
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index e3e0da24a..24716aecf 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -3370,6 +3370,117 @@ _subscript_by_index(PyArray_Descr *self, Py_ssize_t i)
return ret;
}
+static npy_bool
+_is_list_of_strings(PyObject *obj)
+{
+ int seqlen, i;
+ if (!PyList_CheckExact(obj)) {
+ return NPY_FALSE;
+ }
+ seqlen = PyList_GET_SIZE(obj);
+ for (i = 0; i < seqlen; i++) {
+ PyObject *item = PyList_GET_ITEM(obj, i);
+ if (!PyBaseString_Check(item)) {
+ return NPY_FALSE;
+ }
+ }
+
+ return NPY_TRUE;
+}
+
+NPY_NO_EXPORT PyArray_Descr *
+arraydescr_field_subset_view(PyArray_Descr *self, PyObject *ind)
+{
+ int seqlen, i;
+ PyObject *fields = NULL;
+ PyObject *names = NULL;
+ PyArray_Descr *view_dtype;
+
+ seqlen = PySequence_Size(ind);
+ if (seqlen == -1) {
+ return NULL;
+ }
+
+ fields = PyDict_New();
+ if (fields == NULL) {
+ goto fail;
+ }
+ names = PyTuple_New(seqlen);
+ if (names == NULL) {
+ goto fail;
+ }
+
+ for (i = 0; i < seqlen; i++) {
+ PyObject *name;
+ PyObject *tup;
+
+ name = PySequence_GetItem(ind, i);
+ if (name == NULL) {
+ goto fail;
+ }
+
+ /* Let the names tuple steal a reference now, so we don't need to
+ * decref name if an error occurs further on.
+ */
+ PyTuple_SET_ITEM(names, i, name);
+
+ tup = PyDict_GetItem(self->fields, name);
+ if (tup == NULL) {
+ PyErr_SetObject(PyExc_KeyError, name);
+ goto fail;
+ }
+
+ /* disallow use of titles as index */
+ if (PyTuple_Size(tup) == 3) {
+ PyObject *title = PyTuple_GET_ITEM(tup, 2);
+ int titlecmp = PyObject_RichCompareBool(title, name, Py_EQ);
+ if (titlecmp < 0) {
+ goto fail;
+ }
+ if (titlecmp == 1) {
+ /* if title == name, we were given a title, not a field name */
+ PyErr_SetString(PyExc_KeyError,
+ "cannot use field titles in multi-field index");
+ goto fail;
+ }
+ if (PyDict_SetItem(fields, title, tup) < 0) {
+ goto fail;
+ }
+ }
+ /* disallow duplicate field indices */
+ if (PyDict_Contains(fields, name)) {
+ PyObject *msg = NULL;
+ PyObject *fmt = PyUString_FromString(
+ "duplicate field of name {!r}");
+ if (fmt != NULL) {
+ msg = PyObject_CallMethod(fmt, "format", "O", name);
+ Py_DECREF(fmt);
+ }
+ PyErr_SetObject(PyExc_ValueError, msg);
+ Py_XDECREF(msg);
+ goto fail;
+ }
+ if (PyDict_SetItem(fields, name, tup) < 0) {
+ goto fail;
+ }
+ }
+
+ view_dtype = PyArray_DescrNewFromType(NPY_VOID);
+ if (view_dtype == NULL) {
+ goto fail;
+ }
+ view_dtype->elsize = self->elsize;
+ view_dtype->names = names;
+ view_dtype->fields = fields;
+ view_dtype->flags = self->flags;
+ return view_dtype;
+
+fail:
+ Py_XDECREF(fields);
+ Py_XDECREF(names);
+ return NULL;
+}
+
static PyObject *
descr_subscript(PyArray_Descr *self, PyObject *op)
{
@@ -3380,6 +3491,9 @@ descr_subscript(PyArray_Descr *self, PyObject *op)
if (PyBaseString_Check(op)) {
return _subscript_by_name(self, op);
}
+ else if (_is_list_of_strings(op)) {
+ return (PyObject *)arraydescr_field_subset_view(self, op);
+ }
else {
Py_ssize_t i = PyArray_PyIntAsIntp(op);
if (error_converting(i)) {
@@ -3387,7 +3501,8 @@ descr_subscript(PyArray_Descr *self, PyObject *op)
PyObject *err = PyErr_Occurred();
if (PyErr_GivenExceptionMatches(err, PyExc_TypeError)) {
PyErr_SetString(PyExc_TypeError,
- "Field key must be an integer, string, or unicode.");
+ "Field key must be an integer field offset, "
+ "single field name, or list of field names.");
}
return NULL;
}
diff --git a/numpy/core/src/multiarray/descriptor.h b/numpy/core/src/multiarray/descriptor.h
index a5f3b8cdf..de7497363 100644
--- a/numpy/core/src/multiarray/descriptor.h
+++ b/numpy/core/src/multiarray/descriptor.h
@@ -14,6 +14,18 @@ _arraydescr_from_dtype_attr(PyObject *obj);
NPY_NO_EXPORT int
is_dtype_struct_simple_unaligned_layout(PyArray_Descr *dtype);
+/*
+ * Filter the fields of a dtype to only those in the list of strings, ind.
+ *
+ * No type checking is performed on the input.
+ *
+ * Raises:
+ * ValueError - if a field is repeated
+ * KeyError - if an invalid field name (or any field title) is used
+ */
+NPY_NO_EXPORT PyArray_Descr *
+arraydescr_field_subset_view(PyArray_Descr *self, PyObject *ind);
+
extern NPY_NO_EXPORT char *_datetime_strings[];
#endif
diff --git a/numpy/core/src/multiarray/mapping.c b/numpy/core/src/multiarray/mapping.c
index 7c63087ff..9e54a2c64 100644
--- a/numpy/core/src/multiarray/mapping.c
+++ b/numpy/core/src/multiarray/mapping.c
@@ -15,6 +15,7 @@
#include "common.h"
#include "ctors.h"
+#include "descriptor.h"
#include "iterators.h"
#include "mapping.h"
#include "lowlevel_strided_loops.h"
@@ -1395,9 +1396,9 @@ array_subscript_asarray(PyArrayObject *self, PyObject *op)
/*
* Attempts to subscript an array using a field name or list of field names.
*
- * If an error occurred, return 0 and set view to NULL. If the subscript is not
- * a string or list of strings, return -1 and set view to NULL. Otherwise
- * return 0 and set view to point to a new view into arr for the given fields.
+ * ret = 0, view != NULL: view points to the requested fields of arr
+ * ret = 0, view == NULL: an error occurred
+ * ret = -1, view == NULL: unrecognized input, this is not a field index.
*/
NPY_NO_EXPORT int
_get_field_view(PyArrayObject *arr, PyObject *ind, PyArrayObject **view)
@@ -1440,111 +1441,44 @@ _get_field_view(PyArrayObject *arr, PyObject *ind, PyArrayObject **view)
}
return 0;
}
+
/* next check for a list of field names */
else if (PySequence_Check(ind) && !PyTuple_Check(ind)) {
- int seqlen, i;
- PyObject *name = NULL, *tup;
- PyObject *fields, *names;
+ npy_intp seqlen, i;
PyArray_Descr *view_dtype;
seqlen = PySequence_Size(ind);
- /* quit if have a 0-d array (seqlen==-1) or a 0-len array */
+ /* quit if have a fake sequence-like, which errors on len()*/
if (seqlen == -1) {
PyErr_Clear();
return -1;
}
+ /* 0-len list is handled elsewhere as an integer index */
if (seqlen == 0) {
return -1;
}
- fields = PyDict_New();
- if (fields == NULL) {
- return 0;
- }
- names = PyTuple_New(seqlen);
- if (names == NULL) {
- Py_DECREF(fields);
- return 0;
- }
-
+ /* check the items are strings */
for (i = 0; i < seqlen; i++) {
- name = PySequence_GetItem(ind, i);
- if (name == NULL) {
- /* only happens for strange sequence objects */
+ npy_bool is_string;
+ PyObject *item = PySequence_GetItem(ind, i);
+ if (item == NULL) {
PyErr_Clear();
- Py_DECREF(fields);
- Py_DECREF(names);
return -1;
}
-
- if (!PyBaseString_Check(name)) {
- Py_DECREF(name);
- Py_DECREF(fields);
- Py_DECREF(names);
+ is_string = PyBaseString_Check(item);
+ Py_DECREF(item);
+ if (!is_string) {
return -1;
}
-
- tup = PyDict_GetItem(PyArray_DESCR(arr)->fields, name);
- if (tup == NULL){
- PyObject *errmsg = PyUString_FromString("no field of name ");
- PyUString_ConcatAndDel(&errmsg, name);
- PyErr_SetObject(PyExc_ValueError, errmsg);
- Py_DECREF(errmsg);
- Py_DECREF(fields);
- Py_DECREF(names);
- return 0;
- }
- /* disallow use of titles as index */
- if (PyTuple_Size(tup) == 3) {
- PyObject *title = PyTuple_GET_ITEM(tup, 2);
- int titlecmp = PyObject_RichCompareBool(title, name, Py_EQ);
- if (titlecmp == 1) {
- /* if title == name, we got a title, not a field name */
- PyErr_SetString(PyExc_KeyError,
- "cannot use field titles in multi-field index");
- }
- if (titlecmp != 0 || PyDict_SetItem(fields, title, tup) < 0) {
- Py_DECREF(name);
- Py_DECREF(fields);
- Py_DECREF(names);
- return 0;
- }
- }
- /* disallow duplicate field indices */
- if (PyDict_Contains(fields, name)) {
- PyObject *errmsg = PyUString_FromString(
- "duplicate field of name ");
- PyUString_ConcatAndDel(&errmsg, name);
- PyErr_SetObject(PyExc_ValueError, errmsg);
- Py_DECREF(errmsg);
- Py_DECREF(fields);
- Py_DECREF(names);
- return 0;
- }
- if (PyDict_SetItem(fields, name, tup) < 0) {
- Py_DECREF(name);
- Py_DECREF(fields);
- Py_DECREF(names);
- return 0;
- }
- if (PyTuple_SetItem(names, i, name) < 0) {
- Py_DECREF(fields);
- Py_DECREF(names);
- return 0;
- }
}
- view_dtype = PyArray_DescrNewFromType(NPY_VOID);
+ /* Call into the dtype subscript */
+ view_dtype = arraydescr_field_subset_view(PyArray_DESCR(arr), ind);
if (view_dtype == NULL) {
- Py_DECREF(fields);
- Py_DECREF(names);
return 0;
}
- view_dtype->elsize = PyArray_DESCR(arr)->elsize;
- view_dtype->names = names;
- view_dtype->fields = fields;
- view_dtype->flags = PyArray_DESCR(arr)->flags;
*view = (PyArrayObject*)PyArray_NewFromDescr_int(
Py_TYPE(arr),
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index f4736d694..8f33a8daf 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -316,15 +316,66 @@ class TestRecord(object):
assert_raises(IndexError, lambda: dt[-3])
assert_raises(TypeError, operator.getitem, dt, 3.0)
- assert_raises(TypeError, operator.getitem, dt, [])
assert_equal(dt[1], dt[np.int8(1)])
+ @pytest.mark.parametrize('align_flag',[False, True])
+ def test_multifield_index(self, align_flag):
+ # indexing with a list produces subfields
+ # the align flag should be preserved
+ dt = np.dtype([
+ (('title', 'col1'), '<U20'), ('A', '<f8'), ('B', '<f8')
+ ], align=align_flag)
+
+ dt_sub = dt[['B', 'col1']]
+ assert_equal(
+ dt_sub,
+ np.dtype({
+ 'names': ['B', 'col1'],
+ 'formats': ['<f8', '<U20'],
+ 'offsets': [88, 0],
+ 'titles': [None, 'title'],
+ 'itemsize': 96
+ })
+ )
+ assert_equal(dt_sub.isalignedstruct, align_flag)
+
+ dt_sub = dt[['B']]
+ assert_equal(
+ dt_sub,
+ np.dtype({
+ 'names': ['B'],
+ 'formats': ['<f8'],
+ 'offsets': [88],
+ 'itemsize': 96
+ })
+ )
+ assert_equal(dt_sub.isalignedstruct, align_flag)
+
+ dt_sub = dt[[]]
+ assert_equal(
+ dt_sub,
+ np.dtype({
+ 'names': [],
+ 'formats': [],
+ 'offsets': [],
+ 'itemsize': 96
+ })
+ )
+ assert_equal(dt_sub.isalignedstruct, align_flag)
+
+ assert_raises(TypeError, operator.getitem, dt, ())
+ assert_raises(TypeError, operator.getitem, dt, [1, 2, 3])
+ assert_raises(TypeError, operator.getitem, dt, ['col1', 2])
+ assert_raises(KeyError, operator.getitem, dt, ['fake'])
+ assert_raises(KeyError, operator.getitem, dt, ['title'])
+ assert_raises(ValueError, operator.getitem, dt, ['col1', 'col1'])
+
def test_partial_dict(self):
# 'names' is missing
assert_raises(ValueError, np.dtype,
{'formats': ['i4', 'i4'], 'f0': ('i4', 0), 'f1':('i4', 4)})
-
+
class TestSubarray(object):
def test_single_subarray(self):
diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py
index 878ef2ac9..14413224e 100644
--- a/numpy/core/tests/test_records.py
+++ b/numpy/core/tests/test_records.py
@@ -435,7 +435,7 @@ class TestRecord(object):
def test_missing_field(self):
# https://github.com/numpy/numpy/issues/4806
arr = np.zeros((3,), dtype=[('x', int), ('y', int)])
- assert_raises(ValueError, lambda: arr[['nofield']])
+ assert_raises(KeyError, lambda: arr[['nofield']])
def test_fromarrays_nested_structured_arrays(self):
arrays = [