summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- numpy/core/_internal.py 70
-rw-r--r-- numpy/core/src/multiarray/descriptor.c 2
-rw-r--r-- numpy/core/src/multiarray/item_selection.c 17
-rw-r--r-- numpy/core/tests/test_dtype.py 6
4 files changed, 41 insertions, 54 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index 99e64d475..05ba45ad9 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -129,60 +129,38 @@ def _reconstruct(subtype, shape, dtype):
return ndarray.__new__(subtype, shape, dtype)
-# format_re and _split were taken from numarray by J. Todd Miller
+# format_re was originally from numarray by J. Todd Miller
-def _split(input):
- """Split the input formats string into field formats without splitting
- the tuple used to specify multi-dimensional arrays."""
-
- newlist = []
- hold = asbytes('')
-
- listinput = input.split(asbytes(','))
- for element in listinput:
- if hold != asbytes(''):
- item = hold + asbytes(',') + element
- else:
- item = element
- left = item.count(asbytes('('))
- right = item.count(asbytes(')'))
-
- # if the parenthesis is not balanced, hold the string
- if left > right :
- hold = item
-
- # when balanced, append to the output list and reset the hold
- elif left == right:
- newlist.append(item.strip())
- hold = asbytes('')
-
- # too many close parenthesis is unacceptable
- else:
- raise SyntaxError(item)
-
- # if there is string left over in hold
- if hold != asbytes(''):
- raise SyntaxError(hold)
-
- return newlist
-
-format_re = re.compile(asbytes(r'(?P<order1>[<>|=]?)(?P<repeats> *[(]?[ ,0-9]*[)]? *)(?P<order2>[<>|=]?)(?P<dtype>[A-Za-z0-9.]*)'))
+format_re = re.compile(asbytes(r'(?P<order1>[<>|=]?)(?P<repeats> *[(]?[ ,0-9]*[)]? *)(?P<order2>[<>|=]?)(?P<dtype>[A-Za-z0-9.]*(?:\[[a-zA-Z0-9,.]+\])?)'))
+sep_re = re.compile(asbytes(r'\s*,\s*'))
+space_re = re.compile(asbytes(r'\s+$'))
# astr is a string (perhaps comma separated)
_convorder = {asbytes('='): _nbo}
def _commastring(astr):
- res = _split(astr)
- if (len(res)) < 1:
- raise ValueError("unrecognized formant")
+ startindex = 0
result = []
- for k,item in enumerate(res):
- # convert item
+ while startindex < len(astr):
+ mo = format_re.match(astr, pos=startindex)
try:
- (order1, repeats, order2, dtype) = format_re.match(item).groups()
+ (order1, repeats, order2, dtype) = mo.groups()
except (TypeError, AttributeError):
- raise ValueError('format %s is not recognized' % item)
+ raise ValueError('format number %d of "%s" is not recognized' %
+ (len(result)+1, astr))
+ startindex = mo.end()
+ # Separator or ending padding
+ if startindex < len(astr):
+ if space_re.match(astr, pos=startindex):
+ startindex = len(astr)
+ else:
+ mo = sep_re.match(astr, pos=startindex)
+ if not mo:
+ raise ValueError(
+ 'format number %d of "%s" is not recognized' %
+ (len(result)+1, astr))
+ startindex = mo.end()
if order2 == asbytes(''):
order = order1
@@ -192,7 +170,7 @@ def _commastring(astr):
order1 = _convorder.get(order1, order1)
order2 = _convorder.get(order2, order2)
if (order1 != order2):
- raise ValueError('in-consistent byte-order specification %s and %s' % (order1, order2))
+ raise ValueError('inconsistent byte-order specification %s and %s' % (order1, order2))
order = order1
if order in [asbytes('|'), asbytes('='), _nbo]:
@@ -203,7 +181,7 @@ def _commastring(astr):
else:
newitem = (dtype, eval(repeats))
result.append(newitem)
-
+
return result
def _getintp_ctype():
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index 8f3038bf2..f9327ee54 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -599,7 +599,7 @@ _convert_from_commastring(PyObject *obj, int align)
}
listobj = PyObject_CallMethod(_numpy_internal, "_commastring", "O", obj);
Py_DECREF(_numpy_internal);
- if (!listobj) {
+ if (listobj == NULL) {
return NULL;
}
if (!PyList_Check(listobj) || PyList_GET_SIZE(listobj) < 1) {
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index c4154501e..93788b123 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -27,6 +27,7 @@ NPY_NO_EXPORT PyObject *
PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
PyArrayObject *out, NPY_CLIPMODE clipmode)
{
+ PyArray_Descr *dtype;
PyArray_FastTakeFunc *func;
PyArrayObject *obj, *self, *indices;
intp nd, i, j, n, m, max_item, tmp, chunk, nelem;
@@ -64,9 +65,10 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
}
}
if (!out) {
- Py_INCREF(PyArray_DESCR(self));
+ dtype = PyArray_DESCR(self);
+ Py_INCREF(dtype);
obj = (PyArrayObject *)PyArray_NewFromDescr(Py_TYPE(self),
- PyArray_DESCR(self),
+ dtype,
nd, shape,
NULL, NULL, 0,
(PyObject *)self);
@@ -93,9 +95,9 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
*/
flags |= NPY_ARRAY_ENSURECOPY;
}
- Py_INCREF(PyArray_DESCR(self));
- obj = (PyArrayObject *)PyArray_FromArray(out, PyArray_DESCR(self),
- flags);
+ dtype = PyArray_DESCR(self);
+ Py_INCREF(dtype);
+ obj = (PyArrayObject *)PyArray_FromArray(out, dtype, flags);
if (obj == NULL) {
goto fail;
}
@@ -175,12 +177,13 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
}
}
+ PyArray_INCREF(obj);
Py_XDECREF(indices);
Py_XDECREF(self);
if (out != NULL && out != obj) {
+ Py_INCREF(out);
Py_DECREF(obj);
obj = out;
- Py_INCREF(obj);
}
return (PyObject *)obj;
@@ -712,9 +715,9 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
Py_DECREF(ap);
PyDataMem_FREE(mps);
if (out != NULL && out != obj) {
+ Py_INCREF(out);
Py_DECREF(obj);
obj = out;
- Py_INCREF(obj);
}
return (PyObject *)obj;
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index ff13dcad6..0e1bfe182 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -185,6 +185,12 @@ class TestRecord(TestCase):
'formats':['i1', 'O'],
'offsets':[np.dtype('intp').itemsize, 0]})
+ def test_comma_datetime(self):
+ dt = np.dtype('M8[D],datetime64[Y],i8')
+ assert_equal(dt, np.dtype([('f0', 'M8[D]'),
+ ('f1', 'datetime64[Y]'),
+ ('f2', 'i8')]))
+
class TestSubarray(TestCase):
def test_single_subarray(self):
a = np.dtype((np.int, (2)))