summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormattip <matti.picus@gmail.com>2019-03-29 13:49:45 +0300
committermattip <matti.picus@gmail.com>2019-03-29 13:49:45 +0300
commitbea49b553fa07c89890188282811f107913e2105 (patch)
tree89d9dc51c34ee0784b32bc60a8b4a952456819d1
parent44a788b9ed45b42d270590663f7cddb191b17f8d (diff)
downloadnumpy-bea49b553fa07c89890188282811f107913e2105.tar.gz
BUG: fixes from review
-rw-r--r--numpy/core/src/multiarray/compiled_base.c85
1 files changed, 40 insertions, 45 deletions
diff --git a/numpy/core/src/multiarray/compiled_base.c b/numpy/core/src/multiarray/compiled_base.c
index fb20d594b..aebab4e43 100644
--- a/numpy/core/src/multiarray/compiled_base.c
+++ b/numpy/core/src/multiarray/compiled_base.c
@@ -832,12 +832,48 @@ fail:
return NULL;
}
-
static const char *EMPTY_SEQUENCE_ERR_MSG = "indices must be integral: the provided " \
"empty sequence was inferred as float. Wrap it with " \
"'np.array(indices, dtype=np.intp)'";
static const char *NON_INTEGRAL_ERROR_MSG = "only int indices permitted";
+
+/* Convert obj to an ndarray with integer dtype or fail */
+static PyArrayObject *
+astype_anyint(PyObject *obj) {
+ PyArrayObject *ret;
+
+ if (!PyArray_Check(obj)) {
+ /* prefer int dtype */
+ PyArray_Descr *dtype_guess = NULL;
+ if (PyArray_DTypeFromObject(obj, NPY_MAXDIMS, &dtype_guess) < 0) {
+ return NULL;
+ }
+ if (dtype_guess == NULL) {
+ if (PySequence_Check(obj) && PySequence_Size(obj) == 0) {
+ PyErr_SetString(PyExc_TypeError, EMPTY_SEQUENCE_ERR_MSG);
+ }
+ return NULL;
+ }
+ ret = (PyArrayObject*)PyArray_FromAny(obj, dtype_guess, 0, 0, 0, NULL);
+ if (ret == NULL) {
+ return NULL;
+ }
+ }
+ else {
+ ret = (PyArrayObject *)obj;
+ Py_INCREF(ret);
+ }
+
+ if (!(PyArray_ISINTEGER(ret) || PyArray_ISBOOL(ret))) {
+ /* ensure dtype is int-based */
+ PyErr_SetString(PyExc_TypeError, NON_INTEGRAL_ERROR_MSG);
+ return NULL;
+ }
+
+ return ret;
+}
+
/*
* Converts a Python sequence into 'count' PyArrayObjects
*
@@ -868,27 +904,9 @@ static int int_sequence_to_arrays(PyObject *seq,
if (item == NULL) {
goto fail;
}
- if (PyArray_DTypeFromObject(item, NPY_MAXDIMS, &dtype) < 0) {
- Py_DECREF(item);
- goto fail;
- }
- if (dtype == NULL) {
- if (PySequence_Check(item) && PySequence_Size(item) == 0) {
- PyErr_SetString(PyExc_TypeError, EMPTY_SEQUENCE_ERR_MSG);
- }
- Py_DECREF(item);
- goto fail;
- }
- if (!(PyDataType_ISINTEGER(dtype) || PyDataType_ISBOOL(dtype))) {
- PyErr_SetString(PyExc_TypeError, NON_INTEGRAL_ERROR_MSG);
- Py_DECREF(dtype);
- Py_DECREF(item);
- goto fail;
- }
- op[i] = (PyArrayObject *)PyArray_FromAny(item, dtype, 0, 0, 0, NULL);
+ op[i] = astype_anyint(item);
Py_DECREF(item);
if (op[i] == NULL) {
- Py_DECREF(dtype);
goto fail;
}
}
@@ -1245,31 +1263,8 @@ arr_unravel_index(PyObject *self, PyObject *args, PyObject *kwds)
unravel_size = PyArray_MultiplyList(dimensions.ptr, dimensions.len);
- if (!PyArray_Check(indices0)) {
- /* prefer int dtype */
- PyArray_Descr *dtype_guess = NULL;
- if (PyArray_DTypeFromObject(indices0, NPY_MAXDIMS, &dtype_guess) < 0) {
- goto fail;
- }
- if (dtype_guess == NULL) {
- if (PySequence_Check(indices0) && PySequence_Size(indices0) == 0) {
- PyErr_SetString(PyExc_TypeError, EMPTY_SEQUENCE_ERR_MSG);
- }
- goto fail;
- }
- indices = (PyArrayObject*)PyArray_FromAny(indices0, dtype_guess, 0, 0, 0, NULL);
- if (indices == NULL) {
- goto fail;
- }
- }
- else {
- indices = (PyArrayObject *)indices0;
- Py_INCREF(indices);
- }
-
- if (!(PyArray_ISINTEGER(indices) || PyArray_ISBOOL(indices))) {
- /* ensure dtype is int-based */
- PyErr_SetString(PyExc_TypeError, NON_INTEGRAL_ERROR_MSG);
+ indices = astype_anyint(indices0);
+ if (indices == NULL) {
goto fail;
}