diff options
author | Bryan Van de Ven <bryan@laptop.local> | 2012-03-22 10:24:52 -0500 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2012-04-04 10:15:45 -0600 |
commit | 313fe46046a7192cbdba2e679a104777301bc7cf (patch) | |
tree | acdfcd44ceeaed40da585c373f2bfca94bf30a2e | |
parent | 0d1b60136862dd831ca586516d47561181e321ec (diff) | |
download | numpy-313fe46046a7192cbdba2e679a104777301bc7cf.tar.gz |
ENH: Add 'sorter' argument to searchsorted.
The new argument allows one to search an argsorted array by passing
in the result of argsorting the array as the 'sorter' argument. For
example
searchsorted(a, sorter=a.argsort)
-rw-r--r-- | doc/release/2.0.0-notes.rst | 6 | ||||
-rw-r--r-- | numpy/add_newdocs.py | 2 | ||||
-rw-r--r-- | numpy/core/fromnumeric.py | 25 | ||||
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 198 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 13 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 22 |
6 files changed, 232 insertions, 34 deletions
diff --git a/doc/release/2.0.0-notes.rst b/doc/release/2.0.0-notes.rst index 2c80c84a6..e9fd7f90a 100644 --- a/doc/release/2.0.0-notes.rst +++ b/doc/release/2.0.0-notes.rst @@ -130,6 +130,12 @@ Support for mask-based NA values in the polynomial package fits The fitting functions recognize and remove masked data from the fit. +New argument to searchsorted +---------------------------- + +The function searchsorted now accepts a 'sorter' argument that is a +permuation array that sorts the array to search. + Changes ======= diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index 6b7d98c4d..d6bd881cc 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -4082,7 +4082,7 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('round', add_newdoc('numpy.core.multiarray', 'ndarray', ('searchsorted', """ - a.searchsorted(v, side='left') + a.searchsorted(v, side='left', sorter=None) Find indices where elements of v should be inserted in a to maintain order. diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py index 8757b85bf..dae109a98 100644 --- a/numpy/core/fromnumeric.py +++ b/numpy/core/fromnumeric.py @@ -755,24 +755,30 @@ def argmin(a, axis=None): return argmin(axis) -def searchsorted(a, v, side='left'): +def searchsorted(a, v, side='left', sorter=None): """ Find indices where elements should be inserted to maintain order. - Find the indices into a sorted array `a` such that, if the corresponding - elements in `v` were inserted before the indices, the order of `a` would - be preserved. + Find the indices into a sorted array `a` such that, if the + corresponding elements in `v` were inserted before the indices, the + order of `a` would be preserved. Parameters ---------- a : 1-D array_like - Input array, sorted in ascending order. + Input array. If `sorter` is None, then it must be sorted in + ascending order, otherwise `sorter` must be an array of indices + that sort it. v : array_like Values to insert into `a`. side : {'left', 'right'}, optional - If 'left', the index of the first suitable location found is given. If - 'right', return the last such index. If there is no suitable + If 'left', the index of the first suitable location found is given. + If 'right', return the last such index. If there is no suitable index, return either 0 or N (where N is the length of `a`). + sorter : 1-D array_like, optional + .. versionadded:: 1.7.0 + Optional array of integer indices that sort array a into ascending + order. They are typically the result of argsort. Returns ------- @@ -804,8 +810,8 @@ def searchsorted(a, v, side='left'): try: searchsorted = a.searchsorted except AttributeError: - return _wrapit(a, 'searchsorted', v, side) - return searchsorted(v, side) + return _wrapit(a, 'searchsorted', v, side, sorter) + return searchsorted(v, side, sorter) def resize(a, new_shape): @@ -2693,4 +2699,3 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, return _methods._var(a, axis=axis, dtype=dtype, out=out, ddof=ddof, skipna=skipna, keepdims=keepdims) - diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 00398e49b..c7eb5f87b 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -1580,6 +1580,95 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret) } } +/** @brief Use bisection of sorted array to find first entries >= keys. + * + * For each key use bisection to find the first index i s.t. key <= arr[i]. + * When there is no such index i, set i = len(arr). Return the results in ret. + * All arrays are assumed contiguous on entry and both arr and key must be of + * the same comparable type. + * + * @param arr contiguous sorted array to be searched. + * @param key contiguous array of keys. + * @param ret contiguous array of intp for returned indices. + * @return void + */ +static void +local_argsearch_left(PyArrayObject *arr, PyArrayObject *key, + PyArrayObject *sorter, PyArrayObject *ret) +{ + PyArray_CompareFunc *compare = PyArray_DESCR(key)->f->compare; + npy_intp nelts = PyArray_DIMS(arr)[PyArray_NDIM(arr) - 1]; + npy_intp nkeys = PyArray_SIZE(key); + char *parr = PyArray_DATA(arr); + char *pkey = PyArray_DATA(key); + npy_intp *psorter = PyArray_DATA(sorter); + npy_intp *pret = (npy_intp *)PyArray_DATA(ret); + int elsize = PyArray_DESCR(arr)->elsize; + npy_intp i; + + for (i = 0; i < nkeys; ++i) { + npy_intp imin = 0; + npy_intp imax = nelts; + while (imin < imax) { + npy_intp imid = imin + ((imax - imin) >> 1); + if (compare(parr + elsize*psorter[imid], pkey, key) < 0) { + imin = imid + 1; + } + else { + imax = imid; + } + } + *pret = imin; + pret += 1; + pkey += elsize; + } +} + + +/** @brief Use bisection of sorted array to find first entries > keys. + * + * For each key use bisection to find the first index i s.t. key < arr[i]. + * When there is no such index i, set i = len(arr). Return the results in ret. + * All arrays are assumed contiguous on entry and both arr and key must be of + * the same comparable type. + * + * @param arr contiguous sorted array to be searched. + * @param key contiguous array of keys. + * @param ret contiguous array of intp for returned indices. + * @return void + */ +static void +local_argsearch_right(PyArrayObject *arr, PyArrayObject *key, + PyArrayObject *sorter, PyArrayObject *ret) +{ + PyArray_CompareFunc *compare = PyArray_DESCR(key)->f->compare; + npy_intp nelts = PyArray_DIMS(arr)[PyArray_NDIM(arr) - 1]; + npy_intp nkeys = PyArray_SIZE(key); + char *parr = PyArray_DATA(arr); + char *pkey = PyArray_DATA(key); + npy_intp *psorter = PyArray_DATA(sorter); + npy_intp *pret = (npy_intp *)PyArray_DATA(ret); + int elsize = PyArray_DESCR(arr)->elsize; + npy_intp i; + + for(i = 0; i < nkeys; ++i) { + npy_intp imin = 0; + npy_intp imax = nelts; + while (imin < imax) { + npy_intp imid = imin + ((imax - imin) >> 1); + if (compare(parr + elsize*psorter[imid], pkey, key) <= 0) { + imin = imid + 1; + } + else { + imax = imid; + } + } + *pret = imin; + pret += 1; + pkey += elsize; + } +} + /*NUMPY_API * * Search the sorted array op1 for the location of the items in op2. The @@ -1596,6 +1685,8 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret) * side : {NPY_SEARCHLEFT, NPY_SEARCHRIGHT} * If NPY_SEARCHLEFT, return first valid insertion indexes * If NPY_SEARCHRIGHT, return last valid insertion indexes + * perm : PyObject * + * Permutation array that sorts op1 (optional) * * Returns * ------- @@ -1608,10 +1699,15 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret) * Binary search is used to find the indexes. */ NPY_NO_EXPORT PyObject * -PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side) +PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, + NPY_SEARCHSIDE side, PyObject *perm) { PyArrayObject *ap1 = NULL; PyArrayObject *ap2 = NULL; + PyArrayObject *ap3 = NULL; + PyArrayObject *sorter = NULL; + PyObject *max = NULL; + PyObject *min = NULL; PyArrayObject *ret = NULL; PyArray_Descr *dtype; NPY_BEGIN_THREADS_DEF; @@ -1625,8 +1721,9 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side) /* need ap1 as contiguous array and of right type */ Py_INCREF(dtype); ap1 = (PyArrayObject *)PyArray_CheckFromAny((PyObject *)op1, dtype, - 1, 1, NPY_ARRAY_DEFAULT | - NPY_ARRAY_NOTSWAPPED, NULL); + 1, 1, + NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED, + NULL); if (ap1 == NULL) { Py_DECREF(dtype); return NULL; @@ -1634,11 +1731,62 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side) /* need ap2 as contiguous array and of right type */ ap2 = (PyArrayObject *)PyArray_CheckFromAny(op2, dtype, - 0, 0, NPY_ARRAY_DEFAULT | - NPY_ARRAY_NOTSWAPPED, NULL); + 0, 0, + NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED, + NULL); if (ap2 == NULL) { goto fail; } + /* check that comparison function exists */ + if (PyArray_DESCR(ap2)->f->compare == NULL) { + PyErr_SetString(PyExc_TypeError, + "compare not supported for type"); + goto fail; + } + + if (perm) { + /* need ap3 as contiguous array and of right type */ + ap3 = (PyArrayObject *)PyArray_CheckFromAny(perm, NULL, + 1, 1, + NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED, + NULL); + if (ap3 == NULL) { + PyErr_SetString(PyExc_TypeError, + "could not parse sorter argument"); + goto fail; + } + if (!PyArray_ISINTEGER(ap3)) { + PyErr_SetString(PyExc_TypeError, + "sorter must only contain integers"); + goto fail; + } + /* convert to known integer size */ + sorter = (PyArrayObject *)PyArray_FromArray(ap3, + PyArray_DescrFromType(NPY_INTP), + NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED); + if (sorter == NULL) { + PyErr_SetString(PyExc_ValueError, + "could not parse sorter argument"); + goto fail; + } + if (PyArray_SIZE(sorter) != PyArray_SIZE(ap1)) { + PyErr_SetString(PyExc_ValueError, + "sorter.size must equal a.size"); + goto fail; + } + max = PyArray_Max(sorter, 0, NULL); + if (PyLong_AsLong(max) >= PyArray_SIZE(ap1)) { + PyErr_SetString(PyExc_ValueError, + "sorter.max() must be less than a.size"); + goto fail; + } + min = PyArray_Min(sorter, 0, NULL); + if (PyLong_AsLong(min) < 0) { + PyErr_SetString(PyExc_ValueError, + "sorter elements must be non-negative"); + goto fail; + } + } /* ret is a contiguous array of intp type to hold returned indices */ ret = (PyArrayObject *)PyArray_New(Py_TYPE(ap2), PyArray_NDIM(ap2), @@ -1647,22 +1795,32 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side) if (ret == NULL) { goto fail; } - /* check that comparison function exists */ - if (PyArray_DESCR(ap2)->f->compare == NULL) { - PyErr_SetString(PyExc_TypeError, - "compare not supported for type"); - goto fail; - } - if (side == NPY_SEARCHLEFT) { - NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); - local_search_left(ap1, ap2, ret); - NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); + if (ap3 == NULL) { + if (side == NPY_SEARCHLEFT) { + NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); + local_search_left(ap1, ap2, ret); + NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); + } + else if (side == NPY_SEARCHRIGHT) { + NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); + local_search_right(ap1, ap2, ret); + NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); + } } - else if (side == NPY_SEARCHRIGHT) { - NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); - local_search_right(ap1, ap2, ret); - NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); + else { + if (side == NPY_SEARCHLEFT) { + NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); + local_argsearch_left(ap1, ap2, sorter, ret); + NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); + } + else if (side == NPY_SEARCHRIGHT) { + NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); + local_argsearch_right(ap1, ap2, sorter, ret); + NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); + } + Py_DECREF(ap3); + Py_DECREF(sorter); } Py_DECREF(ap1); Py_DECREF(ap2); @@ -1671,6 +1829,8 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side) fail: Py_XDECREF(ap1); Py_XDECREF(ap2); + Py_XDECREF(ap3); + Py_XDECREF(sorter); Py_XDECREF(ret); return NULL; } diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 6de8b10b2..ddb1cc052 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -1361,16 +1361,21 @@ array_argsort(PyArrayObject *self, PyObject *args, PyObject *kwds) static PyObject * array_searchsorted(PyArrayObject *self, PyObject *args, PyObject *kwds) { - static char *kwlist[] = {"keys", "side", NULL}; + static char *kwlist[] = {"keys", "side", "sorter", NULL}; PyObject *keys; + PyObject *sorter; NPY_SEARCHSIDE side = NPY_SEARCHLEFT; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O&:searchsorted", + sorter = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O&O:searchsorted", kwlist, &keys, - PyArray_SearchsideConverter, &side)) { + PyArray_SearchsideConverter, &side, &sorter)) { return NULL; } - return PyArray_Return((PyArrayObject *)PyArray_SearchSorted(self, keys, side)); + if (sorter == Py_None) { + sorter = NULL; + } + return PyArray_Return((PyArrayObject *)PyArray_SearchSorted(self, keys, side, sorter)); } static void diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 053115b3c..b3b09a7ea 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -746,6 +746,28 @@ class TestMethods(TestCase): assert_equal([a.searchsorted(a[i], 'left') for i in ind], ind) assert_equal([a.searchsorted(a[i], 'right') for i in ind], ind + 1) + def test_searchsorted_with_sorter(self): + a = np.array([5,2,1,3,4]) + s = np.argsort(a) + assert_raises(TypeError, np.searchsorted, a, 0, sorter=(1,(2,3))) + assert_raises(TypeError, np.searchsorted, a, 0, sorter=[1.1]) + assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4]) + assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4,6]) + assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4,5,5]) + assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1,2,3,4]) + + a = np.random.rand(100) + s = a.argsort() + b = np.sort(a) + k = np.linspace(0, 1, 20) + assert_equal(b.searchsorted(k), a.searchsorted(k, sorter=s)) + + a = np.array([0, 1, 2, 3, 5]*20) + s = a.argsort() + k = [0, 1, 2, 3, 5] + assert_equal(a.searchsorted(k, side='l', sorter=s), [0, 20, 40, 60, 80]) + assert_equal(a.searchsorted(k, side='r', sorter=s), [20, 40, 60, 80, 100]) + def test_flatten(self): x0 = np.array([[1,2,3],[4,5,6]], np.int32) x1 = np.array([[[1,2],[3,4]],[[5,6],[7,8]]], np.int32) |