summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaime <jaime.frio@gmail.com>2015-06-01 09:36:33 -0700
committerJaime <jaime.frio@gmail.com>2015-06-01 09:36:33 -0700
commit9e7a0b224b4a7859efc1bc80b7ccc5c86c83d3df (patch)
tree793e54ea9b353e18b45dcb00066aa5d456999af8
parent647d2a6586a1f98af1b0ceee0ed0552513625dea (diff)
parent1acd14c3ff6d781801fab7b550fc9d295325f7da (diff)
downloadnumpy-9e7a0b224b4a7859efc1bc80b7ccc5c86c83d3df.tar.gz
Merge pull request #5920 from embray/descr-fields-dictproxy
ENH: Allow dictproxy objects to be used in place of dicts when creating a dtype
-rw-r--r--numpy/core/src/multiarray/descriptor.c51
-rw-r--r--numpy/core/tests/test_dtype.py8
2 files changed, 40 insertions, 19 deletions
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index 2bb45a6e0..13e172a0e 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -29,6 +29,8 @@
#define NPY_NEXT_ALIGNED_OFFSET(offset, alignment) \
(((offset) + (alignment) - 1) & (-(alignment)))
+#define PyDictProxy_Check(obj) (Py_TYPE(obj) == &PyDictProxy_Type)
+
static PyObject *typeDict = NULL; /* Must be explicitly loaded */
static PyArray_Descr *
@@ -270,7 +272,7 @@ _convert_from_tuple(PyObject *obj)
type->elsize = itemsize;
}
}
- else if (PyDict_Check(val)) {
+ else if (PyDict_Check(val) || PyDictProxy_Check(val)) {
/* Assume it's a metadata dictionary */
if (PyDict_Merge(type->metadata, val, 0) == -1) {
Py_DECREF(type);
@@ -944,15 +946,21 @@ _convert_from_dict(PyObject *obj, int align)
if (fields == NULL) {
return (PyArray_Descr *)PyErr_NoMemory();
}
- names = PyDict_GetItemString(obj, "names");
- descrs = PyDict_GetItemString(obj, "formats");
+ /* Use PyMapping_GetItemString to support dictproxy objects as well */
+ names = PyMapping_GetItemString(obj, "names");
+ descrs = PyMapping_GetItemString(obj, "formats");
if (!names || !descrs) {
Py_DECREF(fields);
+ PyErr_Clear();
return _use_fields_dict(obj, align);
}
n = PyObject_Length(names);
- offsets = PyDict_GetItemString(obj, "offsets");
- titles = PyDict_GetItemString(obj, "titles");
+ offsets = PyMapping_GetItemString(obj, "offsets");
+ titles = PyMapping_GetItemString(obj, "titles");
+ if (!offsets || !titles) {
+ PyErr_Clear();
+ }
+
if ((n > PyObject_Length(descrs))
|| (offsets && (n > PyObject_Length(offsets)))
|| (titles && (n > PyObject_Length(titles)))) {
@@ -966,8 +974,10 @@ _convert_from_dict(PyObject *obj, int align)
* If a property 'aligned' is in the dict, it overrides the align flag
* to be True if it not already true.
*/
- tmp = PyDict_GetItemString(obj, "aligned");
- if (tmp != NULL) {
+ tmp = PyMapping_GetItemString(obj, "aligned");
+ if (tmp == NULL) {
+ PyErr_Clear();
+ } else {
if (tmp == Py_True) {
align = 1;
}
@@ -1138,8 +1148,10 @@ _convert_from_dict(PyObject *obj, int align)
}
/* Override the itemsize if provided */
- tmp = PyDict_GetItemString(obj, "itemsize");
- if (tmp != NULL) {
+ tmp = PyMapping_GetItemString(obj, "itemsize");
+ if (tmp == NULL) {
+ PyErr_Clear();
+ } else {
itemsize = (int)PyInt_AsLong(tmp);
if (itemsize == -1 && PyErr_Occurred()) {
Py_DECREF(new);
@@ -1168,17 +1180,18 @@ _convert_from_dict(PyObject *obj, int align)
}
/* Add the metadata if provided */
- metadata = PyDict_GetItemString(obj, "metadata");
+ metadata = PyMapping_GetItemString(obj, "metadata");
- if (new->metadata == NULL) {
+ if (metadata == NULL) {
+ PyErr_Clear();
+ }
+ else if (new->metadata == NULL) {
new->metadata = metadata;
Py_XINCREF(new->metadata);
}
- else if (metadata != NULL) {
- if (PyDict_Merge(new->metadata, metadata, 0) == -1) {
- Py_DECREF(new);
- return NULL;
- }
+ else if (PyDict_Merge(new->metadata, metadata, 0) == -1) {
+ Py_DECREF(new);
+ return NULL;
}
return new;
@@ -1446,7 +1459,7 @@ PyArray_DescrConverter(PyObject *obj, PyArray_Descr **at)
}
return NPY_SUCCEED;
}
- else if (PyDict_Check(obj)) {
+ else if (PyDict_Check(obj) || PyDictProxy_Check(obj)) {
/* or a dictionary */
*at = _convert_from_dict(obj,0);
if (*at == NULL) {
@@ -2741,7 +2754,7 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
NPY_NO_EXPORT int
PyArray_DescrAlignConverter(PyObject *obj, PyArray_Descr **at)
{
- if (PyDict_Check(obj)) {
+ if (PyDict_Check(obj) || PyDictProxy_Check(obj)) {
*at = _convert_from_dict(obj, 1);
}
else if (PyBytes_Check(obj)) {
@@ -2777,7 +2790,7 @@ PyArray_DescrAlignConverter(PyObject *obj, PyArray_Descr **at)
NPY_NO_EXPORT int
PyArray_DescrAlignConverter2(PyObject *obj, PyArray_Descr **at)
{
- if (PyDict_Check(obj)) {
+ if (PyDict_Check(obj) || PyDictProxy_Check(obj)) {
*at = _convert_from_dict(obj, 1);
}
else if (PyBytes_Check(obj)) {
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index 9040c262b..b293cdbbc 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -245,6 +245,14 @@ class TestRecord(TestCase):
('f1', 'datetime64[Y]'),
('f2', 'i8')]))
+ def test_from_dictproxy(self):
+ # Tests for PR #5920
+ dt = np.dtype({'names': ['a', 'b'], 'formats': ['i4', 'f4']})
+ assert_dtype_equal(dt, np.dtype(dt.fields))
+ dt2 = np.dtype((np.void, dt.fields))
+ assert_equal(dt2.fields, dt.fields)
+
+
class TestSubarray(TestCase):
def test_single_subarray(self):
a = np.dtype((np.int, (2)))