summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/item_selection.c47
-rw-r--r--numpy/core/tests/test_multiarray.py8
2 files changed, 32 insertions, 23 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index c7eb5f87b..f9b5ff66f 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -1590,9 +1590,9 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret)
* @param arr contiguous sorted array to be searched.
* @param key contiguous array of keys.
* @param ret contiguous array of intp for returned indices.
- * @return void
+ * @return int
*/
-static void
+static int
local_argsearch_left(PyArrayObject *arr, PyArrayObject *key,
PyArrayObject *sorter, PyArrayObject *ret)
{
@@ -1611,7 +1611,12 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key,
npy_intp imax = nelts;
while (imin < imax) {
npy_intp imid = imin + ((imax - imin) >> 1);
- if (compare(parr + elsize*psorter[imid], pkey, key) < 0) {
+ npy_intp indx = psorter[imid];
+
+ if (indx < 0 || indx >= nelts) {
+ return -1;
+ }
+ if (compare(parr + elsize*indx, pkey, key) < 0) {
imin = imid + 1;
}
else {
@@ -1622,6 +1627,7 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key,
pret += 1;
pkey += elsize;
}
+ return 0;
}
@@ -1635,9 +1641,9 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key,
* @param arr contiguous sorted array to be searched.
* @param key contiguous array of keys.
* @param ret contiguous array of intp for returned indices.
- * @return void
+ * @return int
*/
-static void
+static int
local_argsearch_right(PyArrayObject *arr, PyArrayObject *key,
PyArrayObject *sorter, PyArrayObject *ret)
{
@@ -1656,7 +1662,12 @@ local_argsearch_right(PyArrayObject *arr, PyArrayObject *key,
npy_intp imax = nelts;
while (imin < imax) {
npy_intp imid = imin + ((imax - imin) >> 1);
- if (compare(parr + elsize*psorter[imid], pkey, key) <= 0) {
+ npy_intp indx = psorter[imid];
+
+ if (indx < 0 || indx >= nelts) {
+ return -1;
+ }
+ if (compare(parr + elsize*indx, pkey, key) <= 0) {
imin = imid + 1;
}
else {
@@ -1667,6 +1678,7 @@ local_argsearch_right(PyArrayObject *arr, PyArrayObject *key,
pret += 1;
pkey += elsize;
}
+ return 0;
}
/*NUMPY_API
@@ -1774,18 +1786,6 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
"sorter.size must equal a.size");
goto fail;
}
- max = PyArray_Max(sorter, 0, NULL);
- if (PyLong_AsLong(max) >= PyArray_SIZE(ap1)) {
- PyErr_SetString(PyExc_ValueError,
- "sorter.max() must be less than a.size");
- goto fail;
- }
- min = PyArray_Min(sorter, 0, NULL);
- if (PyLong_AsLong(min) < 0) {
- PyErr_SetString(PyExc_ValueError,
- "sorter elements must be non-negative");
- goto fail;
- }
}
/* ret is a contiguous array of intp type to hold returned indices */
@@ -1809,16 +1809,23 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
}
}
else {
+ int err;
+
if (side == NPY_SEARCHLEFT) {
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
- local_argsearch_left(ap1, ap2, sorter, ret);
+ err = local_argsearch_left(ap1, ap2, sorter, ret);
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
}
else if (side == NPY_SEARCHRIGHT) {
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
- local_argsearch_right(ap1, ap2, sorter, ret);
+ err = local_argsearch_right(ap1, ap2, sorter, ret);
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
}
+ if (err < 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "Sorter index out of range.");
+ goto fail;
+ }
Py_DECREF(ap3);
Py_DECREF(sorter);
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index b3b09a7ea..b2b4bab2b 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -752,9 +752,11 @@ class TestMethods(TestCase):
assert_raises(TypeError, np.searchsorted, a, 0, sorter=(1,(2,3)))
assert_raises(TypeError, np.searchsorted, a, 0, sorter=[1.1])
assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4])
- assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4,6])
- assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4,5,5])
- assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1,2,3,4])
+ assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4,5,6])
+
+ # bounds check
+ assert_raises(ValueError, np.searchsorted, a, 4, sorter=[0,1,2,3,5])
+ assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1,0,1,2,3])
a = np.random.rand(100)
s = a.argsort()