summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/release/2.0.0-notes.rst6
-rw-r--r--numpy/add_newdocs.py2
-rw-r--r--numpy/core/fromnumeric.py25
-rw-r--r--numpy/core/src/multiarray/item_selection.c198
-rw-r--r--numpy/core/src/multiarray/methods.c13
-rw-r--r--numpy/core/tests/test_multiarray.py22
6 files changed, 232 insertions, 34 deletions
diff --git a/doc/release/2.0.0-notes.rst b/doc/release/2.0.0-notes.rst
index 2c80c84a6..e9fd7f90a 100644
--- a/doc/release/2.0.0-notes.rst
+++ b/doc/release/2.0.0-notes.rst
@@ -130,6 +130,12 @@ Support for mask-based NA values in the polynomial package fits
The fitting functions recognize and remove masked data from the fit.
+New argument to searchsorted
+----------------------------
+
+The function searchsorted now accepts a 'sorter' argument that is a
+permuation array that sorts the array to search.
+
Changes
=======
diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py
index 6b7d98c4d..d6bd881cc 100644
--- a/numpy/add_newdocs.py
+++ b/numpy/add_newdocs.py
@@ -4082,7 +4082,7 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('round',
add_newdoc('numpy.core.multiarray', 'ndarray', ('searchsorted',
"""
- a.searchsorted(v, side='left')
+ a.searchsorted(v, side='left', sorter=None)
Find indices where elements of v should be inserted in a to maintain order.
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py
index 8757b85bf..dae109a98 100644
--- a/numpy/core/fromnumeric.py
+++ b/numpy/core/fromnumeric.py
@@ -755,24 +755,30 @@ def argmin(a, axis=None):
return argmin(axis)
-def searchsorted(a, v, side='left'):
+def searchsorted(a, v, side='left', sorter=None):
"""
Find indices where elements should be inserted to maintain order.
- Find the indices into a sorted array `a` such that, if the corresponding
- elements in `v` were inserted before the indices, the order of `a` would
- be preserved.
+ Find the indices into a sorted array `a` such that, if the
+ corresponding elements in `v` were inserted before the indices, the
+ order of `a` would be preserved.
Parameters
----------
a : 1-D array_like
- Input array, sorted in ascending order.
+ Input array. If `sorter` is None, then it must be sorted in
+ ascending order, otherwise `sorter` must be an array of indices
+ that sort it.
v : array_like
Values to insert into `a`.
side : {'left', 'right'}, optional
- If 'left', the index of the first suitable location found is given. If
- 'right', return the last such index. If there is no suitable
+ If 'left', the index of the first suitable location found is given.
+ If 'right', return the last such index. If there is no suitable
index, return either 0 or N (where N is the length of `a`).
+ sorter : 1-D array_like, optional
+ .. versionadded:: 1.7.0
+ Optional array of integer indices that sort array a into ascending
+ order. They are typically the result of argsort.
Returns
-------
@@ -804,8 +810,8 @@ def searchsorted(a, v, side='left'):
try:
searchsorted = a.searchsorted
except AttributeError:
- return _wrapit(a, 'searchsorted', v, side)
- return searchsorted(v, side)
+ return _wrapit(a, 'searchsorted', v, side, sorter)
+ return searchsorted(v, side, sorter)
def resize(a, new_shape):
@@ -2693,4 +2699,3 @@ def var(a, axis=None, dtype=None, out=None, ddof=0,
return _methods._var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
skipna=skipna, keepdims=keepdims)
-
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 00398e49b..c7eb5f87b 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -1580,6 +1580,95 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret)
}
}
+/** @brief Use bisection of sorted array to find first entries >= keys.
+ *
+ * For each key use bisection to find the first index i s.t. key <= arr[i].
+ * When there is no such index i, set i = len(arr). Return the results in ret.
+ * All arrays are assumed contiguous on entry and both arr and key must be of
+ * the same comparable type.
+ *
+ * @param arr contiguous sorted array to be searched.
+ * @param key contiguous array of keys.
+ * @param ret contiguous array of intp for returned indices.
+ * @return void
+ */
+static void
+local_argsearch_left(PyArrayObject *arr, PyArrayObject *key,
+ PyArrayObject *sorter, PyArrayObject *ret)
+{
+ PyArray_CompareFunc *compare = PyArray_DESCR(key)->f->compare;
+ npy_intp nelts = PyArray_DIMS(arr)[PyArray_NDIM(arr) - 1];
+ npy_intp nkeys = PyArray_SIZE(key);
+ char *parr = PyArray_DATA(arr);
+ char *pkey = PyArray_DATA(key);
+ npy_intp *psorter = PyArray_DATA(sorter);
+ npy_intp *pret = (npy_intp *)PyArray_DATA(ret);
+ int elsize = PyArray_DESCR(arr)->elsize;
+ npy_intp i;
+
+ for (i = 0; i < nkeys; ++i) {
+ npy_intp imin = 0;
+ npy_intp imax = nelts;
+ while (imin < imax) {
+ npy_intp imid = imin + ((imax - imin) >> 1);
+ if (compare(parr + elsize*psorter[imid], pkey, key) < 0) {
+ imin = imid + 1;
+ }
+ else {
+ imax = imid;
+ }
+ }
+ *pret = imin;
+ pret += 1;
+ pkey += elsize;
+ }
+}
+
+
+/** @brief Use bisection of sorted array to find first entries > keys.
+ *
+ * For each key use bisection to find the first index i s.t. key < arr[i].
+ * When there is no such index i, set i = len(arr). Return the results in ret.
+ * All arrays are assumed contiguous on entry and both arr and key must be of
+ * the same comparable type.
+ *
+ * @param arr contiguous sorted array to be searched.
+ * @param key contiguous array of keys.
+ * @param ret contiguous array of intp for returned indices.
+ * @return void
+ */
+static void
+local_argsearch_right(PyArrayObject *arr, PyArrayObject *key,
+ PyArrayObject *sorter, PyArrayObject *ret)
+{
+ PyArray_CompareFunc *compare = PyArray_DESCR(key)->f->compare;
+ npy_intp nelts = PyArray_DIMS(arr)[PyArray_NDIM(arr) - 1];
+ npy_intp nkeys = PyArray_SIZE(key);
+ char *parr = PyArray_DATA(arr);
+ char *pkey = PyArray_DATA(key);
+ npy_intp *psorter = PyArray_DATA(sorter);
+ npy_intp *pret = (npy_intp *)PyArray_DATA(ret);
+ int elsize = PyArray_DESCR(arr)->elsize;
+ npy_intp i;
+
+ for(i = 0; i < nkeys; ++i) {
+ npy_intp imin = 0;
+ npy_intp imax = nelts;
+ while (imin < imax) {
+ npy_intp imid = imin + ((imax - imin) >> 1);
+ if (compare(parr + elsize*psorter[imid], pkey, key) <= 0) {
+ imin = imid + 1;
+ }
+ else {
+ imax = imid;
+ }
+ }
+ *pret = imin;
+ pret += 1;
+ pkey += elsize;
+ }
+}
+
/*NUMPY_API
*
* Search the sorted array op1 for the location of the items in op2. The
@@ -1596,6 +1685,8 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret)
* side : {NPY_SEARCHLEFT, NPY_SEARCHRIGHT}
* If NPY_SEARCHLEFT, return first valid insertion indexes
* If NPY_SEARCHRIGHT, return last valid insertion indexes
+ * perm : PyObject *
+ * Permutation array that sorts op1 (optional)
*
* Returns
* -------
@@ -1608,10 +1699,15 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret)
* Binary search is used to find the indexes.
*/
NPY_NO_EXPORT PyObject *
-PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side)
+PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
+ NPY_SEARCHSIDE side, PyObject *perm)
{
PyArrayObject *ap1 = NULL;
PyArrayObject *ap2 = NULL;
+ PyArrayObject *ap3 = NULL;
+ PyArrayObject *sorter = NULL;
+ PyObject *max = NULL;
+ PyObject *min = NULL;
PyArrayObject *ret = NULL;
PyArray_Descr *dtype;
NPY_BEGIN_THREADS_DEF;
@@ -1625,8 +1721,9 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side)
/* need ap1 as contiguous array and of right type */
Py_INCREF(dtype);
ap1 = (PyArrayObject *)PyArray_CheckFromAny((PyObject *)op1, dtype,
- 1, 1, NPY_ARRAY_DEFAULT |
- NPY_ARRAY_NOTSWAPPED, NULL);
+ 1, 1,
+ NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED,
+ NULL);
if (ap1 == NULL) {
Py_DECREF(dtype);
return NULL;
@@ -1634,11 +1731,62 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side)
/* need ap2 as contiguous array and of right type */
ap2 = (PyArrayObject *)PyArray_CheckFromAny(op2, dtype,
- 0, 0, NPY_ARRAY_DEFAULT |
- NPY_ARRAY_NOTSWAPPED, NULL);
+ 0, 0,
+ NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED,
+ NULL);
if (ap2 == NULL) {
goto fail;
}
+ /* check that comparison function exists */
+ if (PyArray_DESCR(ap2)->f->compare == NULL) {
+ PyErr_SetString(PyExc_TypeError,
+ "compare not supported for type");
+ goto fail;
+ }
+
+ if (perm) {
+ /* need ap3 as contiguous array and of right type */
+ ap3 = (PyArrayObject *)PyArray_CheckFromAny(perm, NULL,
+ 1, 1,
+ NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED,
+ NULL);
+ if (ap3 == NULL) {
+ PyErr_SetString(PyExc_TypeError,
+ "could not parse sorter argument");
+ goto fail;
+ }
+ if (!PyArray_ISINTEGER(ap3)) {
+ PyErr_SetString(PyExc_TypeError,
+ "sorter must only contain integers");
+ goto fail;
+ }
+ /* convert to known integer size */
+ sorter = (PyArrayObject *)PyArray_FromArray(ap3,
+ PyArray_DescrFromType(NPY_INTP),
+ NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED);
+ if (sorter == NULL) {
+ PyErr_SetString(PyExc_ValueError,
+ "could not parse sorter argument");
+ goto fail;
+ }
+ if (PyArray_SIZE(sorter) != PyArray_SIZE(ap1)) {
+ PyErr_SetString(PyExc_ValueError,
+ "sorter.size must equal a.size");
+ goto fail;
+ }
+ max = PyArray_Max(sorter, 0, NULL);
+ if (PyLong_AsLong(max) >= PyArray_SIZE(ap1)) {
+ PyErr_SetString(PyExc_ValueError,
+ "sorter.max() must be less than a.size");
+ goto fail;
+ }
+ min = PyArray_Min(sorter, 0, NULL);
+ if (PyLong_AsLong(min) < 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "sorter elements must be non-negative");
+ goto fail;
+ }
+ }
/* ret is a contiguous array of intp type to hold returned indices */
ret = (PyArrayObject *)PyArray_New(Py_TYPE(ap2), PyArray_NDIM(ap2),
@@ -1647,22 +1795,32 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side)
if (ret == NULL) {
goto fail;
}
- /* check that comparison function exists */
- if (PyArray_DESCR(ap2)->f->compare == NULL) {
- PyErr_SetString(PyExc_TypeError,
- "compare not supported for type");
- goto fail;
- }
- if (side == NPY_SEARCHLEFT) {
- NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
- local_search_left(ap1, ap2, ret);
- NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
+ if (ap3 == NULL) {
+ if (side == NPY_SEARCHLEFT) {
+ NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
+ local_search_left(ap1, ap2, ret);
+ NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
+ }
+ else if (side == NPY_SEARCHRIGHT) {
+ NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
+ local_search_right(ap1, ap2, ret);
+ NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
+ }
}
- else if (side == NPY_SEARCHRIGHT) {
- NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
- local_search_right(ap1, ap2, ret);
- NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
+ else {
+ if (side == NPY_SEARCHLEFT) {
+ NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
+ local_argsearch_left(ap1, ap2, sorter, ret);
+ NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
+ }
+ else if (side == NPY_SEARCHRIGHT) {
+ NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
+ local_argsearch_right(ap1, ap2, sorter, ret);
+ NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
+ }
+ Py_DECREF(ap3);
+ Py_DECREF(sorter);
}
Py_DECREF(ap1);
Py_DECREF(ap2);
@@ -1671,6 +1829,8 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side)
fail:
Py_XDECREF(ap1);
Py_XDECREF(ap2);
+ Py_XDECREF(ap3);
+ Py_XDECREF(sorter);
Py_XDECREF(ret);
return NULL;
}
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 6de8b10b2..ddb1cc052 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -1361,16 +1361,21 @@ array_argsort(PyArrayObject *self, PyObject *args, PyObject *kwds)
static PyObject *
array_searchsorted(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
- static char *kwlist[] = {"keys", "side", NULL};
+ static char *kwlist[] = {"keys", "side", "sorter", NULL};
PyObject *keys;
+ PyObject *sorter;
NPY_SEARCHSIDE side = NPY_SEARCHLEFT;
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O&:searchsorted",
+ sorter = NULL;
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O&O:searchsorted",
kwlist, &keys,
- PyArray_SearchsideConverter, &side)) {
+ PyArray_SearchsideConverter, &side, &sorter)) {
return NULL;
}
- return PyArray_Return((PyArrayObject *)PyArray_SearchSorted(self, keys, side));
+ if (sorter == Py_None) {
+ sorter = NULL;
+ }
+ return PyArray_Return((PyArrayObject *)PyArray_SearchSorted(self, keys, side, sorter));
}
static void
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 053115b3c..b3b09a7ea 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -746,6 +746,28 @@ class TestMethods(TestCase):
assert_equal([a.searchsorted(a[i], 'left') for i in ind], ind)
assert_equal([a.searchsorted(a[i], 'right') for i in ind], ind + 1)
+ def test_searchsorted_with_sorter(self):
+ a = np.array([5,2,1,3,4])
+ s = np.argsort(a)
+ assert_raises(TypeError, np.searchsorted, a, 0, sorter=(1,(2,3)))
+ assert_raises(TypeError, np.searchsorted, a, 0, sorter=[1.1])
+ assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4])
+ assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4,6])
+ assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4,5,5])
+ assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1,2,3,4])
+
+ a = np.random.rand(100)
+ s = a.argsort()
+ b = np.sort(a)
+ k = np.linspace(0, 1, 20)
+ assert_equal(b.searchsorted(k), a.searchsorted(k, sorter=s))
+
+ a = np.array([0, 1, 2, 3, 5]*20)
+ s = a.argsort()
+ k = [0, 1, 2, 3, 5]
+ assert_equal(a.searchsorted(k, side='l', sorter=s), [0, 20, 40, 60, 80])
+ assert_equal(a.searchsorted(k, side='r', sorter=s), [20, 40, 60, 80, 100])
+
def test_flatten(self):
x0 = np.array([[1,2,3],[4,5,6]], np.int32)
x1 = np.array([[[1,2],[3,4]],[[5,6],[7,8]]], np.int32)