diff options
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 47 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 8 |
2 files changed, 32 insertions, 23 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index c7eb5f87b..f9b5ff66f 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -1590,9 +1590,9 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret) * @param arr contiguous sorted array to be searched. * @param key contiguous array of keys. * @param ret contiguous array of intp for returned indices. - * @return void + * @return int */ -static void +static int local_argsearch_left(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *sorter, PyArrayObject *ret) { @@ -1611,7 +1611,12 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key, npy_intp imax = nelts; while (imin < imax) { npy_intp imid = imin + ((imax - imin) >> 1); - if (compare(parr + elsize*psorter[imid], pkey, key) < 0) { + npy_intp indx = psorter[imid]; + + if (indx < 0 || indx >= nelts) { + return -1; + } + if (compare(parr + elsize*indx, pkey, key) < 0) { imin = imid + 1; } else { @@ -1622,6 +1627,7 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key, pret += 1; pkey += elsize; } + return 0; } @@ -1635,9 +1641,9 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key, * @param arr contiguous sorted array to be searched. * @param key contiguous array of keys. * @param ret contiguous array of intp for returned indices. - * @return void + * @return int */ -static void +static int local_argsearch_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *sorter, PyArrayObject *ret) { @@ -1656,7 +1662,12 @@ local_argsearch_right(PyArrayObject *arr, PyArrayObject *key, npy_intp imax = nelts; while (imin < imax) { npy_intp imid = imin + ((imax - imin) >> 1); - if (compare(parr + elsize*psorter[imid], pkey, key) <= 0) { + npy_intp indx = psorter[imid]; + + if (indx < 0 || indx >= nelts) { + return -1; + } + if (compare(parr + elsize*indx, pkey, key) <= 0) { imin = imid + 1; } else { @@ -1667,6 +1678,7 @@ local_argsearch_right(PyArrayObject *arr, PyArrayObject *key, pret += 1; pkey += elsize; } + return 0; } /*NUMPY_API @@ -1774,18 +1786,6 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, "sorter.size must equal a.size"); goto fail; } - max = PyArray_Max(sorter, 0, NULL); - if (PyLong_AsLong(max) >= PyArray_SIZE(ap1)) { - PyErr_SetString(PyExc_ValueError, - "sorter.max() must be less than a.size"); - goto fail; - } - min = PyArray_Min(sorter, 0, NULL); - if (PyLong_AsLong(min) < 0) { - PyErr_SetString(PyExc_ValueError, - "sorter elements must be non-negative"); - goto fail; - } } /* ret is a contiguous array of intp type to hold returned indices */ @@ -1809,16 +1809,23 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, } } else { + int err; + if (side == NPY_SEARCHLEFT) { NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); - local_argsearch_left(ap1, ap2, sorter, ret); + err = local_argsearch_left(ap1, ap2, sorter, ret); NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); } else if (side == NPY_SEARCHRIGHT) { NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); - local_argsearch_right(ap1, ap2, sorter, ret); + err = local_argsearch_right(ap1, ap2, sorter, ret); NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); } + if (err < 0) { + PyErr_SetString(PyExc_ValueError, + "Sorter index out of range."); + goto fail; + } Py_DECREF(ap3); Py_DECREF(sorter); } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index b3b09a7ea..b2b4bab2b 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -752,9 +752,11 @@ class TestMethods(TestCase): assert_raises(TypeError, np.searchsorted, a, 0, sorter=(1,(2,3))) assert_raises(TypeError, np.searchsorted, a, 0, sorter=[1.1]) assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4]) - assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4,6]) - assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4,5,5]) - assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1,2,3,4]) + assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4,5,6]) + + # bounds check + assert_raises(ValueError, np.searchsorted, a, 4, sorter=[0,1,2,3,5]) + assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1,0,1,2,3]) a = np.random.rand(100) s = a.argsort() |