summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwiebe@enthought.com>2011-06-21 19:54:34 -0500
committerMark Wiebe <mwiebe@enthought.com>2011-06-21 19:54:34 -0500
commit81f156ae5da85f31909e0933b28f29b2d4c48ab7 (patch)
tree2a9a3e7de0d9a2c63851195a11a06537fa4f75da
parentf57f1ef6e86ca85586fb566a16e5d4cdd7a35066 (diff)
downloadnumpy-81f156ae5da85f31909e0933b28f29b2d4c48ab7.tar.gz
ENH: dtype: Allow unions and out-of-order fields
This will require validation that object dtypes don't overlap with other fields in a follow-on commit.
-rw-r--r--numpy/core/src/multiarray/descriptor.c68
-rw-r--r--numpy/core/tests/test_dtype.py25
2 files changed, 68 insertions, 25 deletions
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index c97f8e4cc..a0c34e3a2 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -361,14 +361,25 @@ _convert_from_array_descr(PyObject *obj, int align)
/* Process rest */
if (PyTuple_GET_SIZE(item) == 2) {
- ret = PyArray_DescrConverter(PyTuple_GET_ITEM(item, 1), &conv);
+ if (align) {
+ ret = PyArray_DescrAlignConverter(PyTuple_GET_ITEM(item, 1),
+ &conv);
+ }
+ else {
+ ret = PyArray_DescrConverter(PyTuple_GET_ITEM(item, 1), &conv);
+ }
if (ret == PY_FAIL) {
PyObject_Print(PyTuple_GET_ITEM(item, 1), stderr, 0);
}
}
else if (PyTuple_GET_SIZE(item) == 3) {
newobj = PyTuple_GetSlice(item, 1, 3);
- ret = PyArray_DescrConverter(newobj, &conv);
+ if (align) {
+ ret = PyArray_DescrAlignConverter(newobj, &conv);
+ }
+ else {
+ ret = PyArray_DescrConverter(newobj, &conv);
+ }
Py_DECREF(newobj);
}
else {
@@ -505,7 +516,12 @@ _convert_from_list(PyObject *obj, int align)
for (i = 0; i < n; i++) {
tup = PyTuple_New(2);
key = PyUString_FromFormat("f%d", i);
- ret = PyArray_DescrConverter(PyList_GET_ITEM(obj, i), &conv);
+ if (align) {
+ ret = PyArray_DescrAlignConverter(PyList_GET_ITEM(obj, i), &conv);
+ }
+ else {
+ ret = PyArray_DescrConverter(PyList_GET_ITEM(obj, i), &conv);
+ }
if (ret == PY_FAIL) {
Py_DECREF(tup);
Py_DECREF(key);
@@ -585,8 +601,10 @@ _convert_from_commastring(PyObject *obj, int align)
return NULL;
}
if (PyList_GET_SIZE(listobj) == 1) {
- if (PyArray_DescrConverter(
- PyList_GET_ITEM(listobj, 0), &res) == NPY_FAIL) {
+ int retcode;
+ retcode = PyArray_DescrConverter(PyList_GET_ITEM(listobj, 0),
+ &res);
+ if (retcode == NPY_FAIL) {
res = NULL;
}
}
@@ -760,21 +778,21 @@ _convert_from_dict(PyObject *obj, int align)
totalsize = 0;
for (i = 0; i < n; i++) {
- PyObject *tup, *descr, *index, *item, *name, *off;
+ PyObject *tup, *descr, *index, *title, *name, *off;
int len, ret, _align = 1;
PyArray_Descr *newdescr;
/* Build item to insert (descr, offset, [title])*/
len = 2;
- item = NULL;
+ title = NULL;
index = PyInt_FromLong(i);
if (titles) {
- item=PyObject_GetItem(titles, index);
- if (item && item != Py_None) {
+ title=PyObject_GetItem(titles, index);
+ if (title && title != Py_None) {
len = 3;
}
else {
- Py_XDECREF(item);
+ Py_XDECREF(title);
}
PyErr_Clear();
}
@@ -783,7 +801,12 @@ _convert_from_dict(PyObject *obj, int align)
if (!descr) {
goto fail;
}
- ret = PyArray_DescrConverter(descr, &newdescr);
+ if (align) {
+ ret = PyArray_DescrAlignConverter(descr, &newdescr);
+ }
+ else {
+ ret = PyArray_DescrConverter(descr, &newdescr);
+ }
Py_DECREF(descr);
if (ret == PY_FAIL) {
Py_DECREF(tup);
@@ -812,13 +835,8 @@ _convert_from_dict(PyObject *obj, int align)
(int)offset, (int)newdescr->alignment);
ret = PY_FAIL;
}
- else if (offset < totalsize) {
- PyErr_SetString(PyExc_ValueError,
- "invalid offset (must be ordered)");
- ret = PY_FAIL;
- }
- else if (offset > totalsize) {
- totalsize = offset;
+ else if (offset + newdescr->elsize > totalsize) {
+ totalsize = offset + newdescr->elsize;
}
}
else {
@@ -827,9 +845,10 @@ _convert_from_dict(PyObject *obj, int align)
totalsize = (totalsize + _align - 1) & (-_align);
}
PyTuple_SET_ITEM(tup, 1, PyInt_FromLong(totalsize));
+ totalsize += newdescr->elsize;
}
if (len == 3) {
- PyTuple_SET_ITEM(tup, 2, item);
+ PyTuple_SET_ITEM(tup, 2, title);
}
name = PyObject_GetItem(names, index);
if (!name) {
@@ -856,16 +875,16 @@ _convert_from_dict(PyObject *obj, int align)
Py_DECREF(name);
if (len == 3) {
#if defined(NPY_PY3K)
- if (PyUString_Check(item)) {
+ if (PyUString_Check(title)) {
#else
- if (PyUString_Check(item) || PyUnicode_Check(item)) {
+ if (PyUString_Check(title) || PyUnicode_Check(title)) {
#endif
- if (PyDict_GetItem(fields, item) != NULL) {
+ if (PyDict_GetItem(fields, title) != NULL) {
PyErr_SetString(PyExc_ValueError,
"title already used as a name or title.");
ret=PY_FAIL;
}
- PyDict_SetItem(fields, item, tup);
+ PyDict_SetItem(fields, title, tup);
}
}
Py_DECREF(tup);
@@ -873,10 +892,9 @@ _convert_from_dict(PyObject *obj, int align)
goto fail;
}
dtypeflags |= (newdescr->flags & NPY_FROM_FIELDS);
- totalsize += newdescr->elsize;
}
- new = PyArray_DescrNewFromType(PyArray_VOID);
+ new = PyArray_DescrNewFromType(NPY_VOID);
if (new == NULL) {
goto fail;
}
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index 740100d22..18f0a973b 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -98,6 +98,31 @@ class TestRecord(TestCase):
dt = np.dtype({'f0': ('i4', 0), 'f1':('u1', 4)}, align=True)
assert_equal(dt.itemsize, 8)
+ def test_union_struct(self):
+ # Should be able to create union dtypes
+ dt = np.dtype({'names':['f0','f1','f2'], 'formats':['<u4', '<u2','<u2'],
+ 'offsets':[0,0,2]}, align=True)
+ assert_equal(dt.itemsize, 4)
+ a = np.array([3], dtype='<u4').view(dt)
+ a['f1'] = 10
+ a['f2'] = 36
+ assert_equal(a['f0'], 10 + 36*256*256)
+ # Should be able to specify fields out of order
+ dt = np.dtype({'names':['f0','f1','f2'], 'formats':['<u4', '<u2','<u2'],
+ 'offsets':[4,0,2]}, align=True)
+ assert_equal(dt.itemsize, 8)
+ dt2 = np.dtype({'names':['f2','f0','f1'],
+ 'formats':['<u2', '<u4','<u2'],
+ 'offsets':[2,4,0]}, align=True)
+ vals = [(0,1,2), (3,-1,4)]
+ vals2 = [(2,0,1), (4,3,-1)]
+ a = np.array(vals, dt)
+ b = np.array(vals2, dt2)
+ assert_equal(a.astype(dt2), b)
+ assert_equal(b.astype(dt), a)
+ assert_equal(a.view(dt2), b)
+ assert_equal(b.view(dt), a)
+
class TestSubarray(TestCase):
def test_single_subarray(self):
a = np.dtype((np.int, (2)))