diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2020-01-05 22:40:23 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2020-01-11 11:18:01 +0000 |
commit | e83bd4658de5d7ffb6a460f68d2638c1b0a35068 (patch) | |
tree | bd714498dac9abc8f3a195946e3e08a4239da8d0 | |
parent | 4cd5a88c130894d3916c1cb174b1e73f10dac232 (diff) | |
download | numpy-e83bd4658de5d7ffb6a460f68d2638c1b0a35068.tar.gz |
MAINT: Work with unicode strings in `dtype('i8,i8')`
Right now, we convert `str` objects to `bytes`, and then work with those.
Since this is a human convenience API, the input really ought to be a string.
A future patch will suggest deprecating `dtype(b'i8,i8')`, but for now it will continue to work.
-rw-r--r-- | numpy/core/_internal.py | 28 |
-rw-r--r-- | numpy/core/src/multiarray/descriptor.c | 46 |
2 files changed, 36 insertions, 38 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 88cf10a38..105b5d24a 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -18,9 +18,9 @@ except ImportError: IS_PYPY = platform.python_implementation() == 'PyPy' if (sys.byteorder == 'little'): - _nbo = b'<' + _nbo = '<' else: - _nbo = b'>' + _nbo = '>' def _makenames_list(adict, align): allfields = [] @@ -143,16 +143,16 @@ def _reconstruct(subtype, shape, dtype): # format_re was originally from numarray by J. Todd Miller -format_re = re.compile(br'(?P<order1>[<>|=]?)' - br'(?P<repeats> *[(]?[ ,0-9]*[)]? *)' - br'(?P<order2>[<>|=]?)' - br'(?P<dtype>[A-Za-z0-9.?]*(?:\[[a-zA-Z0-9,.]+\])?)') -sep_re = re.compile(br'\s*,\s*') -space_re = re.compile(br'\s+$') +format_re = re.compile(r'(?P<order1>[<>|=]?)' + r'(?P<repeats> *[(]?[ ,0-9]*[)]? *)' + r'(?P<order2>[<>|=]?)' + r'(?P<dtype>[A-Za-z0-9.?]*(?:\[[a-zA-Z0-9,.]+\])?)') +sep_re = re.compile(r'\s*,\s*') +space_re = re.compile(r'\s+$') # astr is a string (perhaps comma separated) -_convorder = {b'=': _nbo} +_convorder = {'=': _nbo} def _commastring(astr): startindex = 0 @@ -177,9 +177,9 @@ def _commastring(astr): (len(result)+1, astr)) startindex = mo.end() - if order2 == b'': + if order2 == '': order = order1 - elif order1 == b'': + elif order1 == '': order = order2 else: order1 = _convorder.get(order1, order1) @@ -190,10 +190,10 @@ def _commastring(astr): (order1, order2)) order = order1 - if order in [b'|', b'=', _nbo]: - order = b'' + if order in ['|', '=', _nbo]: + order = '' dtype = order + dtype - if (repeats == b''): + if (repeats == ''): newitem = dtype else: newitem = (dtype, eval(repeats)) diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index 4cf01c9fd..b17718048 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -607,8 +607,14 @@ _convert_from_list(PyObject *obj, int align) * can produce */ PyObject *last_item = 
PyList_GET_ITEM(obj, n-1); - if (PyBytes_Check(last_item) && PyBytes_GET_SIZE(last_item) == 0) { - n = n - 1; + if (PyUnicode_Check(last_item)) { + Py_ssize_t s = PySequence_Size(last_item); + if (s < 0) { + return NULL; + } + if (s == 0) { + n = n - 1; + } } if (n == 0) { PyErr_SetString(PyExc_ValueError, "Expected at least one field name"); @@ -709,7 +715,7 @@ _convert_from_commastring(PyObject *obj, int align) PyArray_Descr *res; PyObject *_numpy_internal; - if (!PyBytes_Check(obj)) { + if (!PyUnicode_Check(obj)) { _report_generic_error(); return NULL; } @@ -1390,7 +1396,7 @@ _convert_from_type(PyObject *obj) { static PyArray_Descr * -_convert_from_bytes(PyObject *obj, int align); +_convert_from_str(PyObject *obj, int align); static PyArray_Descr * _convert_from_any(PyObject *obj, int align) @@ -1408,24 +1414,23 @@ _convert_from_any(PyObject *obj, int align) return _convert_from_type(obj); } /* or a typecode string */ - else if (PyUnicode_Check(obj)) { - /* Allow unicode format strings: convert to bytes */ - PyObject *obj2 = PyUnicode_AsASCIIString(obj); + else if (PyBytes_Check(obj)) { + /* Allow bytes format strings: convert to unicode */ + PyObject *obj2 = PyUnicode_FromEncodedObject(obj, NULL, NULL); if (obj2 == NULL) { /* Convert the exception into a TypeError */ - PyObject *err = PyErr_Occurred(); - if (PyErr_GivenExceptionMatches(err, PyExc_UnicodeEncodeError)) { + if (PyErr_ExceptionMatches(PyExc_UnicodeDecodeError)) { PyErr_SetString(PyExc_TypeError, "data type not understood"); } return NULL; } - PyArray_Descr *ret = _convert_from_any(obj2, align); + PyArray_Descr *ret = _convert_from_str(obj2, align); Py_DECREF(obj2); return ret; } - else if (PyBytes_Check(obj)) { - return _convert_from_bytes(obj, align); + else if (PyUnicode_Check(obj)) { + return _convert_from_str(obj, align); } else if (PyTuple_Check(obj)) { /* or a tuple */ @@ -1489,12 +1494,12 @@ PyArray_DescrConverter(PyObject *obj, PyArray_Descr **at) /** Convert a bytestring specification into 
a dtype */ static PyArray_Descr * -_convert_from_bytes(PyObject *obj, int align) +_convert_from_str(PyObject *obj, int align) { /* Check for a string typecode. */ - char *type = NULL; Py_ssize_t len = 0; - if (PyBytes_AsStringAndSize(obj, &type, &len) < 0) { + char const *type = PyUnicode_AsUTF8AndSize(obj, &len); + if (type == NULL) { return NULL; } @@ -1611,13 +1616,7 @@ _convert_from_bytes(PyObject *obj, int align) if (typeDict == NULL) { goto fail; } - PyObject *item = NULL; - PyObject *tmp = PyUnicode_FromEncodedObject(obj, "ascii", "strict"); - if (tmp == NULL) { - goto fail; - } - item = PyDict_GetItem(typeDict, tmp); - Py_DECREF(tmp); + PyObject *item = PyDict_GetItem(typeDict, obj); if (item == NULL) { goto fail; } @@ -1665,8 +1664,7 @@ _convert_from_bytes(PyObject *obj, int align) return ret; fail: - PyErr_Format(PyExc_TypeError, - "data type \"%s\" not understood", PyBytes_AS_STRING(obj)); + PyErr_Format(PyExc_TypeError, "data type %R not understood", obj); return NULL; } |