diff options
author | Mark Wiebe <mwiebe@enthought.com> | 2011-07-18 18:48:44 -0500 |
---|---|---|
committer | Mark Wiebe <mwiebe@enthought.com> | 2011-07-19 14:00:28 -0500 |
commit | f8ff0e326c08e42fc9cb230b56079d05bd1059bc (patch) | |
tree | 2e15ef12c3c57c3e901994ed8f3965d7ac4af6de | |
parent | c9f0f296bea82af0cf475d21b44f10ebd9e99268 (diff) | |
download | numpy-f8ff0e326c08e42fc9cb230b56079d05bd1059bc.tar.gz |
BUG: dtype: comma-list dtype formats didn't accept M8[] parameterized dtypes
-rw-r--r-- | numpy/core/_internal.py | 70 | ||||
-rw-r--r-- | numpy/core/src/multiarray/descriptor.c | 2 | ||||
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 17 | ||||
-rw-r--r-- | numpy/core/tests/test_dtype.py | 6 |
4 files changed, 41 insertions, 54 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 99e64d475..05ba45ad9 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -129,60 +129,38 @@ def _reconstruct(subtype, shape, dtype): return ndarray.__new__(subtype, shape, dtype) -# format_re and _split were taken from numarray by J. Todd Miller +# format_re was originally from numarray by J. Todd Miller -def _split(input): - """Split the input formats string into field formats without splitting - the tuple used to specify multi-dimensional arrays.""" - - newlist = [] - hold = asbytes('') - - listinput = input.split(asbytes(',')) - for element in listinput: - if hold != asbytes(''): - item = hold + asbytes(',') + element - else: - item = element - left = item.count(asbytes('(')) - right = item.count(asbytes(')')) - - # if the parenthesis is not balanced, hold the string - if left > right : - hold = item - - # when balanced, append to the output list and reset the hold - elif left == right: - newlist.append(item.strip()) - hold = asbytes('') - - # too many close parenthesis is unacceptable - else: - raise SyntaxError(item) - - # if there is string left over in hold - if hold != asbytes(''): - raise SyntaxError(hold) - - return newlist - -format_re = re.compile(asbytes(r'(?P<order1>[<>|=]?)(?P<repeats> *[(]?[ ,0-9]*[)]? *)(?P<order2>[<>|=]?)(?P<dtype>[A-Za-z0-9.]*)')) +format_re = re.compile(asbytes(r'(?P<order1>[<>|=]?)(?P<repeats> *[(]?[ ,0-9]*[)]? *)(?P<order2>[<>|=]?)(?P<dtype>[A-Za-z0-9.]*(?:\[[a-zA-Z0-9,.]+\])?)')) +sep_re = re.compile(asbytes(r'\s*,\s*')) +space_re = re.compile(asbytes(r'\s+$')) # astr is a string (perhaps comma separated) _convorder = {asbytes('='): _nbo} def _commastring(astr): - res = _split(astr) - if (len(res)) < 1: - raise ValueError("unrecognized formant") + startindex = 0 result = [] - for k,item in enumerate(res): - # convert item + while startindex < len(astr): + mo = format_re.match(astr, pos=startindex) try: - (order1, repeats, order2, dtype) = format_re.match(item).groups() + (order1, repeats, order2, dtype) = mo.groups() except (TypeError, AttributeError): - raise ValueError('format %s is not recognized' % item) + raise ValueError('format number %d of "%s" is not recognized' % + (len(result)+1, astr)) + startindex = mo.end() + # Separator or ending padding + if startindex < len(astr): + if space_re.match(astr, pos=startindex): + startindex = len(astr) + else: + mo = sep_re.match(astr, pos=startindex) + if not mo: + raise ValueError( + 'format number %d of "%s" is not recognized' % + (len(result)+1, astr)) + startindex = mo.end() if order2 == asbytes(''): order = order1 @@ -192,7 +170,7 @@ def _commastring(astr): order1 = _convorder.get(order1, order1) order2 = _convorder.get(order2, order2) if (order1 != order2): - raise ValueError('in-consistent byte-order specification %s and %s' % (order1, order2)) + raise ValueError('inconsistent byte-order specification %s and %s' % (order1, order2)) order = order1 if order in [asbytes('|'), asbytes('='), _nbo]: @@ -203,7 +181,7 @@ def _commastring(astr): else: newitem = (dtype, eval(repeats)) result.append(newitem) - + return result def _getintp_ctype(): diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index 8f3038bf2..f9327ee54 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -599,7 +599,7 @@ _convert_from_commastring(PyObject *obj, int align) } listobj = PyObject_CallMethod(_numpy_internal, "_commastring", "O", obj); Py_DECREF(_numpy_internal); - if (!listobj) { + if (listobj == NULL) { return NULL; } if (!PyList_Check(listobj) || PyList_GET_SIZE(listobj) < 1) { diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index c4154501e..93788b123 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -27,6 +27,7 @@ NPY_NO_EXPORT PyObject * PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis, PyArrayObject *out, NPY_CLIPMODE clipmode) { + PyArray_Descr *dtype; PyArray_FastTakeFunc *func; PyArrayObject *obj, *self, *indices; intp nd, i, j, n, m, max_item, tmp, chunk, nelem; @@ -64,9 +65,10 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis, } } if (!out) { - Py_INCREF(PyArray_DESCR(self)); + dtype = PyArray_DESCR(self); + Py_INCREF(dtype); obj = (PyArrayObject *)PyArray_NewFromDescr(Py_TYPE(self), - PyArray_DESCR(self), + dtype, nd, shape, NULL, NULL, 0, (PyObject *)self); @@ -93,9 +95,9 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis, */ flags |= NPY_ARRAY_ENSURECOPY; } - Py_INCREF(PyArray_DESCR(self)); - obj = (PyArrayObject *)PyArray_FromArray(out, PyArray_DESCR(self), - flags); + dtype = PyArray_DESCR(self); + Py_INCREF(dtype); + obj = (PyArrayObject *)PyArray_FromArray(out, dtype, flags); if (obj == NULL) { goto fail; } @@ -175,12 +177,13 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis, } } + PyArray_INCREF(obj); Py_XDECREF(indices); Py_XDECREF(self); if (out != NULL && out != obj) { + Py_INCREF(out); Py_DECREF(obj); obj = out; - Py_INCREF(obj); } return (PyObject *)obj; @@ -712,9 +715,9 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out, Py_DECREF(ap); PyDataMem_FREE(mps); if (out != NULL && out != obj) { + Py_INCREF(out); Py_DECREF(obj); obj = out; - Py_INCREF(obj); } return (PyObject *)obj; diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index ff13dcad6..0e1bfe182 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -185,6 +185,12 @@ class TestRecord(TestCase): 'formats':['i1', 'O'], 'offsets':[np.dtype('intp').itemsize, 0]}) + def test_comma_datetime(self): + dt = np.dtype('M8[D],datetime64[Y],i8') + assert_equal(dt, np.dtype([('f0', 'M8[D]'), + ('f1', 'datetime64[Y]'), + ('f2', 'i8')])) + class TestSubarray(TestCase): def test_single_subarray(self): a = np.dtype((np.int, (2))) |