summaryrefslogtreecommitdiff
path: root/numpy/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/src')
-rw-r--r--numpy/core/src/arrayobject.c30
1 files changed, 29 insertions, 1 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 32d49eaf2..6c8a64d84 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -2827,10 +2827,10 @@ array_subscript(PyArrayObject *self, PyObject *op)
int nd, fancy;
PyArrayObject *other;
PyArrayMapIterObject *mit;
+ PyObject *obj;
if (PyString_Check(op) || PyUnicode_Check(op)) {
if (self->descr->names) {
- PyObject *obj;
obj = PyDict_GetItem(self->descr->fields, op);
if (obj != NULL) {
PyArray_Descr *descr;
@@ -2852,6 +2852,34 @@ array_subscript(PyArrayObject *self, PyObject *op)
return NULL;
}
+ /* Check for multiple field access
+ */
+ if (self->descr->names && PySequence_Check(op) && !PyTuple_Check(op)) {
+ int seqlen, i;
+ seqlen = PySequence_Size(op);
+ for (i=0; i<seqlen; i++) {
+ obj = PySequence_GetItem(op, i);
+ if (!PyString_Check(obj) && !PyUnicode_Check(obj)) {
+ Py_DECREF(obj);
+ break;
+ }
+ Py_DECREF(obj);
+ }
+ /* extract multiple fields if all elements in sequence
+ are either string or unicode (i.e. no break occurred).
+ */
+ fancy = ((seqlen > 0) && (i == seqlen));
+ if (fancy) {
+ PyObject *_numpy_internal;
+ _numpy_internal = PyImport_ImportModule("numpy.core._internal");
+ if (_numpy_internal == NULL) return NULL;
+ obj = PyObject_CallMethod(_numpy_internal, "_index_fields",
+ "OO", self, op);
+ Py_DECREF(_numpy_internal);
+ return obj;
+ }
+ }
+
if (op == Py_Ellipsis) {
Py_INCREF(self);
return (PyObject *)self;