diff options
-rw-r--r-- | numpy/core/src/multiarray/descriptor.c | 95 |
1 files changed, 33 insertions, 62 deletions
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index 0063131e4..8654d7e3d 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -39,24 +39,29 @@ static PyObject *typeDict = NULL; /* Must be explicitly loaded */ +/* + * Generate a vague error message when a function returned NULL but forgot + * to set an exception. We should aim to remove this eventually. + */ +static void +_report_generic_error(void) { + PyErr_SetString(PyExc_TypeError, "data type not understood"); +} + static PyArray_Descr * _use_inherit(PyArray_Descr *type, PyObject *newobj, int *errflag); static PyArray_Descr * +_convert_from_any(PyObject *obj, int align); + +static PyArray_Descr * _arraydescr_run_converter(PyObject *arg, int align) { - PyArray_Descr *type = NULL; - if (align) { - if (PyArray_DescrAlignConverter(arg, &type) == NPY_FAIL) { - return NULL; - } - } - else { - if (PyArray_DescrConverter(arg, &type) == NPY_FAIL) { - return NULL; - } + PyArray_Descr *type = _convert_from_any(arg, align); + /* TODO: fix the `_convert` functions so that this can never happen */ + if (type == NULL && !PyErr_Occurred()) { + _report_generic_error(); } - assert(type != NULL); return type; } @@ -1399,20 +1404,12 @@ _convert_from_type(PyObject *obj) { } } -/* - * Generate a vague error message when a function returned NULL but forgot - * to set an exception. We should aim to remove this eventually. - */ -static void -_report_generic_error(void) { - PyErr_SetString(PyExc_TypeError, "data type not understood"); -} static PyArray_Descr * -_convert_from_bytes(PyObject *obj); +_convert_from_bytes(PyObject *obj, int align); static PyArray_Descr * -_convert_from_any(PyObject *obj) +_convert_from_any(PyObject *obj, int align) { /* default */ if (obj == Py_None) { @@ -1439,24 +1436,24 @@ _convert_from_any(PyObject *obj) } return NULL; } - PyArray_Descr *ret = _convert_from_any(obj2); + PyArray_Descr *ret = _convert_from_any(obj2, align); Py_DECREF(obj2); return ret; } else if (PyBytes_Check(obj)) { - return _convert_from_bytes(obj); + return _convert_from_bytes(obj, align); } else if (PyTuple_Check(obj)) { /* or a tuple */ - return _convert_from_tuple(obj, 0); + return _convert_from_tuple(obj, align); } else if (PyList_Check(obj)) { /* or a list */ - return _convert_from_array_descr(obj, 0); + return _convert_from_array_descr(obj, align); } else if (PyDict_Check(obj) || PyDictProxy_Check(obj)) { /* or a dictionary */ - return _convert_from_dict(obj, 0); + return _convert_from_dict(obj, align); } else if (PyArray_Check(obj)) { _report_generic_error(); @@ -1503,16 +1500,13 @@ _convert_from_any(PyObject *obj) NPY_NO_EXPORT int PyArray_DescrConverter(PyObject *obj, PyArray_Descr **at) { - *at = _convert_from_any(obj); - if (*at == NULL && !PyErr_Occurred()) { - _report_generic_error(); - } + *at = _arraydescr_run_converter(obj, 0); return (*at) ? NPY_SUCCEED : NPY_FAIL; } /** Convert a bytestring specification into a dtype */ static PyArray_Descr * -_convert_from_bytes(PyObject *obj) +_convert_from_bytes(PyObject *obj, int align) { /* Check for a string typecode. */ char *type = NULL; @@ -1528,7 +1522,7 @@ _convert_from_bytes(PyObject *obj) /* check for commas present or first (or second) element a digit */ if (_check_for_commastring(type, len)) { - return _convert_from_commastring(obj, 0); + return _convert_from_commastring(obj, align); } /* Process the endian character. '|' is replaced by '='*/ @@ -1661,7 +1655,11 @@ _convert_from_bytes(PyObject *obj) } } } - return _convert_from_any(item); + /* + * Probably only ever dispatches to `_convert_from_type`, but who + * knows what users are injecting into `np.typeDict`. + */ + return _convert_from_any(item, align); } if (PyDataType_ISUNSIZED(ret) && ret->elsize != elsize) { @@ -2858,35 +2856,8 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args) NPY_NO_EXPORT int PyArray_DescrAlignConverter(PyObject *obj, PyArray_Descr **at) { - if (PyDict_Check(obj) || PyDictProxy_Check(obj)) { - *at = _convert_from_dict(obj, 1); - } - else if (PyBytes_Check(obj)) { - *at = _convert_from_commastring(obj, 1); - } - else if (PyUnicode_Check(obj)) { - PyObject *tmp; - tmp = PyUnicode_AsASCIIString(obj); - *at = _convert_from_commastring(tmp, 1); - Py_DECREF(tmp); - } - else if (PyTuple_Check(obj)) { - *at = _convert_from_tuple(obj, 1); - } - else if (PyList_Check(obj)) { - *at = _convert_from_array_descr(obj, 1); - } - else { - return PyArray_DescrConverter(obj, at); - } - if (*at == NULL) { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_ValueError, - "data-type-descriptor not understood"); - } - return NPY_FAIL; - } - return NPY_SUCCEED; + *at = _arraydescr_run_converter(obj, 1); + return (*at) ? NPY_SUCCEED : NPY_FAIL; } /*NUMPY_API |