diff options
-rw-r--r-- | doc/release/2.0.0-notes.rst | 6 | ||||
-rw-r--r-- | numpy/add_newdocs.py | 2 | ||||
-rw-r--r-- | numpy/core/fromnumeric.py | 25 | ||||
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 198 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 13 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 22 |
6 files changed, 232 insertions, 34 deletions
diff --git a/doc/release/2.0.0-notes.rst b/doc/release/2.0.0-notes.rst index 2c80c84a6..e9fd7f90a 100644 --- a/doc/release/2.0.0-notes.rst +++ b/doc/release/2.0.0-notes.rst @@ -130,6 +130,12 @@ Support for mask-based NA values in the polynomial package fits The fitting functions recognize and remove masked data from the fit. +New argument to searchsorted +---------------------------- + +The function searchsorted now accepts a 'sorter' argument that is a +permutation array that sorts the array to search. + Changes ======= diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index 6b7d98c4d..d6bd881cc 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -4082,7 +4082,7 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('round', add_newdoc('numpy.core.multiarray', 'ndarray', ('searchsorted', """ - a.searchsorted(v, side='left') + a.searchsorted(v, side='left', sorter=None) Find indices where elements of v should be inserted in a to maintain order. diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py index 8757b85bf..dae109a98 100644 --- a/numpy/core/fromnumeric.py +++ b/numpy/core/fromnumeric.py @@ -755,24 +755,30 @@ def argmin(a, axis=None): return argmin(axis) -def searchsorted(a, v, side='left'): +def searchsorted(a, v, side='left', sorter=None): """ Find indices where elements should be inserted to maintain order. - Find the indices into a sorted array `a` such that, if the corresponding - elements in `v` were inserted before the indices, the order of `a` would - be preserved. + Find the indices into a sorted array `a` such that, if the + corresponding elements in `v` were inserted before the indices, the + order of `a` would be preserved. Parameters ---------- a : 1-D array_like - Input array, sorted in ascending order. + Input array. If `sorter` is None, then it must be sorted in + ascending order, otherwise `sorter` must be an array of indices + that sort it. v : array_like Values to insert into `a`. 
side : {'left', 'right'}, optional - If 'left', the index of the first suitable location found is given. If - 'right', return the last such index. If there is no suitable + If 'left', the index of the first suitable location found is given. + If 'right', return the last such index. If there is no suitable index, return either 0 or N (where N is the length of `a`). + sorter : 1-D array_like, optional + .. versionadded:: 1.7.0 + Optional array of integer indices that sort array a into ascending + order. They are typically the result of argsort. Returns ------- @@ -804,8 +810,8 @@ def searchsorted(a, v, side='left'): try: searchsorted = a.searchsorted except AttributeError: - return _wrapit(a, 'searchsorted', v, side) - return searchsorted(v, side) + return _wrapit(a, 'searchsorted', v, side, sorter) + return searchsorted(v, side, sorter) def resize(a, new_shape): @@ -2693,4 +2699,3 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, return _methods._var(a, axis=axis, dtype=dtype, out=out, ddof=ddof, skipna=skipna, keepdims=keepdims) - diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 00398e49b..c7eb5f87b 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -1580,6 +1580,95 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret) } } +/** @brief Use bisection of sorted array to find first entries >= keys. + * + * For each key use bisection to find the first index i s.t. key <= arr[i]. + * When there is no such index i, set i = len(arr). Return the results in ret. + * All arrays are assumed contiguous on entry and both arr and key must be of + * the same comparable type. + * + * @param arr contiguous sorted array to be searched. + * @param key contiguous array of keys. + * @param ret contiguous array of intp for returned indices. 
+ * @return void + */ +static void +local_argsearch_left(PyArrayObject *arr, PyArrayObject *key, + PyArrayObject *sorter, PyArrayObject *ret) +{ + PyArray_CompareFunc *compare = PyArray_DESCR(key)->f->compare; + npy_intp nelts = PyArray_DIMS(arr)[PyArray_NDIM(arr) - 1]; + npy_intp nkeys = PyArray_SIZE(key); + char *parr = PyArray_DATA(arr); + char *pkey = PyArray_DATA(key); + npy_intp *psorter = PyArray_DATA(sorter); + npy_intp *pret = (npy_intp *)PyArray_DATA(ret); + int elsize = PyArray_DESCR(arr)->elsize; + npy_intp i; + + for (i = 0; i < nkeys; ++i) { + npy_intp imin = 0; + npy_intp imax = nelts; + while (imin < imax) { + npy_intp imid = imin + ((imax - imin) >> 1); + if (compare(parr + elsize*psorter[imid], pkey, key) < 0) { + imin = imid + 1; + } + else { + imax = imid; + } + } + *pret = imin; + pret += 1; + pkey += elsize; + } +} + + +/** @brief Use bisection of sorted array to find first entries > keys. + * + * For each key use bisection to find the first index i s.t. key < arr[i]. + * When there is no such index i, set i = len(arr). Return the results in ret. + * All arrays are assumed contiguous on entry and both arr and key must be of + * the same comparable type. + * + * @param arr contiguous sorted array to be searched. + * @param key contiguous array of keys. + * @param ret contiguous array of intp for returned indices. 
+ * @return void + */ +static void +local_argsearch_right(PyArrayObject *arr, PyArrayObject *key, + PyArrayObject *sorter, PyArrayObject *ret) +{ + PyArray_CompareFunc *compare = PyArray_DESCR(key)->f->compare; + npy_intp nelts = PyArray_DIMS(arr)[PyArray_NDIM(arr) - 1]; + npy_intp nkeys = PyArray_SIZE(key); + char *parr = PyArray_DATA(arr); + char *pkey = PyArray_DATA(key); + npy_intp *psorter = PyArray_DATA(sorter); + npy_intp *pret = (npy_intp *)PyArray_DATA(ret); + int elsize = PyArray_DESCR(arr)->elsize; + npy_intp i; + + for(i = 0; i < nkeys; ++i) { + npy_intp imin = 0; + npy_intp imax = nelts; + while (imin < imax) { + npy_intp imid = imin + ((imax - imin) >> 1); + if (compare(parr + elsize*psorter[imid], pkey, key) <= 0) { + imin = imid + 1; + } + else { + imax = imid; + } + } + *pret = imin; + pret += 1; + pkey += elsize; + } +} + /*NUMPY_API * * Search the sorted array op1 for the location of the items in op2. The @@ -1596,6 +1685,8 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret) * side : {NPY_SEARCHLEFT, NPY_SEARCHRIGHT} * If NPY_SEARCHLEFT, return first valid insertion indexes * If NPY_SEARCHRIGHT, return last valid insertion indexes + * perm : PyObject * + * Permutation array that sorts op1 (optional) * * Returns * ------- @@ -1608,10 +1699,15 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret) * Binary search is used to find the indexes. 
*/ NPY_NO_EXPORT PyObject * -PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side) +PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, + NPY_SEARCHSIDE side, PyObject *perm) { PyArrayObject *ap1 = NULL; PyArrayObject *ap2 = NULL; + PyArrayObject *ap3 = NULL; + PyArrayObject *sorter = NULL; + PyObject *max = NULL; + PyObject *min = NULL; PyArrayObject *ret = NULL; PyArray_Descr *dtype; NPY_BEGIN_THREADS_DEF; @@ -1625,8 +1721,9 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side) /* need ap1 as contiguous array and of right type */ Py_INCREF(dtype); ap1 = (PyArrayObject *)PyArray_CheckFromAny((PyObject *)op1, dtype, - 1, 1, NPY_ARRAY_DEFAULT | - NPY_ARRAY_NOTSWAPPED, NULL); + 1, 1, + NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED, + NULL); if (ap1 == NULL) { Py_DECREF(dtype); return NULL; @@ -1634,11 +1731,62 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side) /* need ap2 as contiguous array and of right type */ ap2 = (PyArrayObject *)PyArray_CheckFromAny(op2, dtype, - 0, 0, NPY_ARRAY_DEFAULT | - NPY_ARRAY_NOTSWAPPED, NULL); + 0, 0, + NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED, + NULL); if (ap2 == NULL) { goto fail; } + /* check that comparison function exists */ + if (PyArray_DESCR(ap2)->f->compare == NULL) { + PyErr_SetString(PyExc_TypeError, + "compare not supported for type"); + goto fail; + } + + if (perm) { + /* need ap3 as contiguous array and of right type */ + ap3 = (PyArrayObject *)PyArray_CheckFromAny(perm, NULL, + 1, 1, + NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED, + NULL); + if (ap3 == NULL) { + PyErr_SetString(PyExc_TypeError, + "could not parse sorter argument"); + goto fail; + } + if (!PyArray_ISINTEGER(ap3)) { + PyErr_SetString(PyExc_TypeError, + "sorter must only contain integers"); + goto fail; + } + /* convert to known integer size */ + sorter = (PyArrayObject *)PyArray_FromArray(ap3, + PyArray_DescrFromType(NPY_INTP), + NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED); + if (sorter 
== NULL) { + PyErr_SetString(PyExc_ValueError, + "could not parse sorter argument"); + goto fail; + } + if (PyArray_SIZE(sorter) != PyArray_SIZE(ap1)) { + PyErr_SetString(PyExc_ValueError, + "sorter.size must equal a.size"); + goto fail; + } + max = PyArray_Max(sorter, 0, NULL); + if (PyLong_AsLong(max) >= PyArray_SIZE(ap1)) { + PyErr_SetString(PyExc_ValueError, + "sorter.max() must be less than a.size"); + goto fail; + } + min = PyArray_Min(sorter, 0, NULL); + if (PyLong_AsLong(min) < 0) { + PyErr_SetString(PyExc_ValueError, + "sorter elements must be non-negative"); + goto fail; + } + } /* ret is a contiguous array of intp type to hold returned indices */ ret = (PyArrayObject *)PyArray_New(Py_TYPE(ap2), PyArray_NDIM(ap2), @@ -1647,22 +1795,32 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side) if (ret == NULL) { goto fail; } - /* check that comparison function exists */ - if (PyArray_DESCR(ap2)->f->compare == NULL) { - PyErr_SetString(PyExc_TypeError, - "compare not supported for type"); - goto fail; - } - if (side == NPY_SEARCHLEFT) { - NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); - local_search_left(ap1, ap2, ret); - NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); + if (ap3 == NULL) { + if (side == NPY_SEARCHLEFT) { + NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); + local_search_left(ap1, ap2, ret); + NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); + } + else if (side == NPY_SEARCHRIGHT) { + NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); + local_search_right(ap1, ap2, ret); + NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); + } } - else if (side == NPY_SEARCHRIGHT) { - NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); - local_search_right(ap1, ap2, ret); - NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); + else { + if (side == NPY_SEARCHLEFT) { + NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); + local_argsearch_left(ap1, ap2, sorter, ret); + NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); + } + else if (side == NPY_SEARCHRIGHT) { + 
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2)); + local_argsearch_right(ap1, ap2, sorter, ret); + NPY_END_THREADS_DESCR(PyArray_DESCR(ap2)); + } + Py_DECREF(ap3); + Py_DECREF(sorter); } Py_DECREF(ap1); Py_DECREF(ap2); @@ -1671,6 +1829,8 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side) fail: Py_XDECREF(ap1); Py_XDECREF(ap2); + Py_XDECREF(ap3); + Py_XDECREF(sorter); Py_XDECREF(ret); return NULL; } diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 6de8b10b2..ddb1cc052 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -1361,16 +1361,21 @@ array_argsort(PyArrayObject *self, PyObject *args, PyObject *kwds) static PyObject * array_searchsorted(PyArrayObject *self, PyObject *args, PyObject *kwds) { - static char *kwlist[] = {"keys", "side", NULL}; + static char *kwlist[] = {"keys", "side", "sorter", NULL}; PyObject *keys; + PyObject *sorter; NPY_SEARCHSIDE side = NPY_SEARCHLEFT; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O&:searchsorted", + sorter = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O&O:searchsorted", kwlist, &keys, - PyArray_SearchsideConverter, &side)) { + PyArray_SearchsideConverter, &side, &sorter)) { return NULL; } - return PyArray_Return((PyArrayObject *)PyArray_SearchSorted(self, keys, side)); + if (sorter == Py_None) { + sorter = NULL; + } + return PyArray_Return((PyArrayObject *)PyArray_SearchSorted(self, keys, side, sorter)); } static void diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 053115b3c..b3b09a7ea 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -746,6 +746,28 @@ class TestMethods(TestCase): assert_equal([a.searchsorted(a[i], 'left') for i in ind], ind) assert_equal([a.searchsorted(a[i], 'right') for i in ind], ind + 1) + def test_searchsorted_with_sorter(self): + a = np.array([5,2,1,3,4]) + s = np.argsort(a) + 
assert_raises(TypeError, np.searchsorted, a, 0, sorter=(1,(2,3))) + assert_raises(TypeError, np.searchsorted, a, 0, sorter=[1.1]) + assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4]) + assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4,6]) + assert_raises(ValueError, np.searchsorted, a, 0, sorter=[1,2,3,4,5,5]) + assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1,2,3,4]) + + a = np.random.rand(100) + s = a.argsort() + b = np.sort(a) + k = np.linspace(0, 1, 20) + assert_equal(b.searchsorted(k), a.searchsorted(k, sorter=s)) + + a = np.array([0, 1, 2, 3, 5]*20) + s = a.argsort() + k = [0, 1, 2, 3, 5] + assert_equal(a.searchsorted(k, side='l', sorter=s), [0, 20, 40, 60, 80]) + assert_equal(a.searchsorted(k, side='r', sorter=s), [20, 40, 60, 80, 100]) + def test_flatten(self): x0 = np.array([[1,2,3],[4,5,6]], np.int32) x1 = np.array([[[1,2],[3,4]],[[5,6],[7,8]]], np.int32) |