diff options
author | mattip <matti.picus@gmail.com> | 2019-03-29 13:49:45 +0300 |
---|---|---|
committer | mattip <matti.picus@gmail.com> | 2019-03-29 13:49:45 +0300 |
commit | bea49b553fa07c89890188282811f107913e2105 (patch) | |
tree | 89d9dc51c34ee0784b32bc60a8b4a952456819d1 | |
parent | 44a788b9ed45b42d270590663f7cddb191b17f8d (diff) | |
download | numpy-bea49b553fa07c89890188282811f107913e2105.tar.gz |
BUG: fixes from review
-rw-r--r-- | numpy/core/src/multiarray/compiled_base.c | 85 |
1 files changed, 40 insertions, 45 deletions
diff --git a/numpy/core/src/multiarray/compiled_base.c b/numpy/core/src/multiarray/compiled_base.c index fb20d594b..aebab4e43 100644 --- a/numpy/core/src/multiarray/compiled_base.c +++ b/numpy/core/src/multiarray/compiled_base.c @@ -832,12 +832,48 @@ fail: return NULL; } - static const char *EMPTY_SEQUENCE_ERR_MSG = "indices must be integral: the provided " \ "empty sequence was inferred as float. Wrap it with " \ "'np.array(indices, dtype=np.intp)'"; static const char *NON_INTEGRAL_ERROR_MSG = "only int indices permitted"; + +/* Convert obj to an ndarray with integer dtype or fail */ +static PyArrayObject * +astype_anyint(PyObject *obj) { + PyArrayObject *ret; + + if (!PyArray_Check(obj)) { + /* prefer int dtype */ + PyArray_Descr *dtype_guess = NULL; + if (PyArray_DTypeFromObject(obj, NPY_MAXDIMS, &dtype_guess) < 0) { + return NULL; + } + if (dtype_guess == NULL) { + if (PySequence_Check(obj) && PySequence_Size(obj) == 0) { + PyErr_SetString(PyExc_TypeError, EMPTY_SEQUENCE_ERR_MSG); + } + return NULL; + } + ret = (PyArrayObject*)PyArray_FromAny(obj, dtype_guess, 0, 0, 0, NULL); + if (ret == NULL) { + return NULL; + } + } + else { + ret = (PyArrayObject *)obj; + Py_INCREF(ret); + } + + if (!(PyArray_ISINTEGER(ret) || PyArray_ISBOOL(ret))) { + /* ensure dtype is int-based */ + PyErr_SetString(PyExc_TypeError, NON_INTEGRAL_ERROR_MSG); + return NULL; + } + + return ret; +} + /* * Converts a Python sequence into 'count' PyArrayObjects * @@ -868,27 +904,9 @@ static int int_sequence_to_arrays(PyObject *seq, if (item == NULL) { goto fail; } - if (PyArray_DTypeFromObject(item, NPY_MAXDIMS, &dtype) < 0) { - Py_DECREF(item); - goto fail; - } - if (dtype == NULL) { - if (PySequence_Check(item) && PySequence_Size(item) == 0) { - PyErr_SetString(PyExc_TypeError, EMPTY_SEQUENCE_ERR_MSG); - } - Py_DECREF(item); - goto fail; - } - if (!(PyDataType_ISINTEGER(dtype) || PyDataType_ISBOOL(dtype))) { - PyErr_SetString(PyExc_TypeError, NON_INTEGRAL_ERROR_MSG); - Py_DECREF(dtype); - Py_DECREF(item); - goto fail; - } - op[i] = (PyArrayObject *)PyArray_FromAny(item, dtype, 0, 0, 0, NULL); + op[i] = astype_anyint(item); Py_DECREF(item); if (op[i] == NULL) { - Py_DECREF(dtype); goto fail; } } @@ -1245,31 +1263,8 @@ arr_unravel_index(PyObject *self, PyObject *args, PyObject *kwds) unravel_size = PyArray_MultiplyList(dimensions.ptr, dimensions.len); - if (!PyArray_Check(indices0)) { - /* prefer int dtype */ - PyArray_Descr *dtype_guess = NULL; - if (PyArray_DTypeFromObject(indices0, NPY_MAXDIMS, &dtype_guess) < 0) { - goto fail; - } - if (dtype_guess == NULL) { - if (PySequence_Check(indices0) && PySequence_Size(indices0) == 0) { - PyErr_SetString(PyExc_TypeError, EMPTY_SEQUENCE_ERR_MSG); - } - goto fail; - } - indices = (PyArrayObject*)PyArray_FromAny(indices0, dtype_guess, 0, 0, 0, NULL); - if (indices == NULL) { - goto fail; - } - } - else { - indices = (PyArrayObject *)indices0; - Py_INCREF(indices); - } - - if (!(PyArray_ISINTEGER(indices) || PyArray_ISBOOL(indices))) { - /* ensure dtype is int-based */ - PyErr_SetString(PyExc_TypeError, NON_INTEGRAL_ERROR_MSG); + indices = astype_anyint(indices0); + if (indices == NULL) { goto fail; } |