diff options
Diffstat (limited to 'numpy/core/src')
-rw-r--r-- | numpy/core/src/arraymethods.c | 25 | ||||
-rw-r--r-- | numpy/core/src/multiarraymodule.c | 134 |
2 files changed, 82 insertions, 77 deletions
diff --git a/numpy/core/src/arraymethods.c b/numpy/core/src/arraymethods.c index 179e48ce9..9733ae75f 100644 --- a/numpy/core/src/arraymethods.c +++ b/numpy/core/src/arraymethods.c @@ -887,29 +887,14 @@ static PyObject * array_searchsorted(PyArrayObject *self, PyObject *args, PyObject *kwds) { PyObject *keys; - char *side = "left"; static char *kwlist[] = {"keys", "side", NULL}; - NPY_SEARCHKIND which; + NPY_SEARCHSIDE side = NPY_SEARCHLEFT; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|s", kwlist, &keys, &side)) - return NULL; - if (strlen(side) < 1) { - PyErr_SetString(PyExc_ValueError, - "Searchsorted: side must be nonempty string"); - return PY_FAIL; - } - - if (side[0] == 'l' || side[0] == 'L') - which = NPY_SEARCHLEFT; - else if (side[0] == 'r' || side[0] == 'R') - which = NPY_SEARCHRIGHT; - else { - PyErr_Format(PyExc_ValueError, - "%s is not a valid value of side", side); - return PY_FAIL; - } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O&", kwlist, &keys, + PyArray_SearchsideConverter, &side)) + return NULL; - return _ARET(PyArray_SearchSorted(self, keys, which)); + return _ARET(PyArray_SearchSorted(self, keys, side)); } static void diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c index 4dd825ea9..1a3029bd9 100644 --- a/numpy/core/src/multiarraymodule.c +++ b/numpy/core/src/multiarraymodule.c @@ -2523,33 +2523,28 @@ PyArray_LexSort(PyObject *sort_keys, int axis) } -/* local_search_left +/** @brief Use bisection of sorted array to find first entries >= keys. * - * Use bisection to find the indices i into ap1 s.t. - * - * ap1[j] < key <= ap1[i] for all 0 <= j < i and all keys in ap2, - * - * When there is no such index i, set i = len(ap1). Return the results in ret. All - * arrays are assumed contiguous on entry and both ap1 and ap2 are assumed to - * be of the same comparable type. - * - * Arguments: - * - * ap1 -- array to be searched, contiguous on entry and assumed sorted. - * ap2 -- array of keys, contiguous on entry. - * ret -- return array of intp, contiguous on entry. + * For each key use bisection to find the first index i s.t. key <= arr[i]. + * When there is no such index i, set i = len(arr). Return the results in ret. + * All arrays are assumed contiguous on entry and both arr and key must be of + * the same comparable type. * + * @param arr contiguous sorted array to be searched. + * @param key contiguous array of keys. + * @param ret contiguous array of intp for returned indices. + * @return void */ static void -local_search_left(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject *ret) -{ - PyArray_CompareFunc *compare = ap2->descr->f->compare; - intp nelts = ap1->dimensions[ap1->nd - 1]; - intp nkeys = PyArray_SIZE(ap2); - char *p1 = ap1->data; - char *p2 = ap2->data; - intp *pr = (intp *)ret->data; - int elsize = ap1->descr->elsize; +local_search_left(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret) +{ + PyArray_CompareFunc *compare = key->descr->f->compare; + intp nelts = arr->dimensions[arr->nd - 1]; + intp nkeys = PyArray_SIZE(key); + char *parr = arr->data; + char *pkey = key->data; + intp *pret = (intp *)ret->data; + int elsize = arr->descr->elsize; intp i; for(i = 0; i < nkeys; ++i) { @@ -2557,45 +2552,40 @@ local_search_left(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject *ret) intp imax = nelts; while (imin < imax) { intp imid = imin + ((imax - imin) >> 2); - if (compare(p1 + elsize*imid, p2, ap2) < 0) + if (compare(parr + elsize*imid, pkey, key) < 0) imin = imid + 1; else imax = imid; } - *pr = imin; - pr += 1; - p2 += elsize; + *pret = imin; + pret += 1; + pkey += elsize; } } -/* local_search_right - * - * Use bisection to find the indices i into ap1 s.t. - * - * ap1[j] <= key < ap1[i] for all 0 <= j < i and all keys in ap2. - * - * When there is no such index i, set i = len(ap1). Return the results in ret. All - * arrays are assumed contiguous on entry and both ap1 and ap2 are assumed to - * be of the same comparable type. +/** @brief Use bisection of sorted array to find first entries > keys. * - * Arguments: - * - * ap1 -- array to be searched, contiguous on entry and assumed sorted. - * ap2 -- array of keys, contiguous on entry. - * ret -- return array of intp, contiguous on entry. + * For each key use bisection to find the first index i s.t. key < arr[i]. + * When there is no such index i, set i = len(arr). Return the results in ret. + * All arrays are assumed contiguous on entry and both arr and key must be of + * the same comparable type. * + * @param arr contiguous sorted array to be searched. + * @param key contiguous array of keys. + * @param ret contiguous array of intp for returned indices. + * @return void */ static void -local_search_right(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject *ret) -{ - PyArray_CompareFunc *compare = ap2->descr->f->compare; - intp nelts = ap1->dimensions[ap1->nd - 1]; - intp nkeys = PyArray_SIZE(ap2); - char *p1 = ap1->data; - char *p2 = ap2->data; - intp *pr = (intp *)ret->data; - int elsize = ap1->descr->elsize; +local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret) +{ + PyArray_CompareFunc *compare = key->descr->f->compare; + intp nelts = arr->dimensions[arr->nd - 1]; + intp nkeys = PyArray_SIZE(key); + char *parr = arr->data; + char *pkey = key->data; + intp *pret = (intp *)ret->data; + int elsize = arr->descr->elsize; intp i; for(i = 0; i < nkeys; ++i) { @@ -2603,22 +2593,52 @@ local_search_right(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject *ret) intp imax = nelts; while (imin < imax) { intp imid = imin + ((imax - imin) >> 2); - if (compare(p1 + elsize*imid, p2, ap2) <= 0) + if (compare(parr + elsize*imid, pkey, key) <= 0) imin = imid + 1; else imax = imid; } - *pr = imin; - pr += 1; - p2 += elsize; + *pret = imin; + pret += 1; + pkey += elsize; + } +} + + +/*MULTIARRAY_API + Convert object to searchsorted side +*/ +static int +PyArray_SearchsideConverter(PyObject *obj, NPY_SEARCHSIDE *side) +{ + char *str = PyString_AsString(obj); + + if (!str) + return PY_FAIL; + if (strlen(str) < 1) { + PyErr_SetString(PyExc_ValueError, + "side must be nonempty string"); + return PY_FAIL; } + + if (str[0] == 'l' || str[0] == 'L') + *side = NPY_SEARCHLEFT; + else if (str[0] == 'r' || str[0] == 'R') + *side = NPY_SEARCHRIGHT; + else { + PyErr_Format(PyExc_ValueError, + "side has invalid value '%s'", str); + return PY_FAIL; + } + return PY_SUCCEED; } + /*MULTIARRAY_API Numeric.searchsorted(a,v) */ static PyObject * -PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHKIND which) +PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHSIDE side) { PyArrayObject *ap1=NULL; PyArrayObject *ap2=NULL; @@ -2657,12 +2677,12 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2, NPY_SEARCHKIND which) goto fail; } - if (which == NPY_SEARCHLEFT) { + if (side == NPY_SEARCHLEFT) { NPY_BEGIN_THREADS_DESCR(ap2->descr) local_search_left(ap1, ap2, ret); NPY_END_THREADS_DESCR(ap2->descr) } - else if (which == NPY_SEARCHRIGHT) { + else if (side == NPY_SEARCHRIGHT) { NPY_BEGIN_THREADS_DESCR(ap2->descr) local_search_right(ap1, ap2, ret); NPY_END_THREADS_DESCR(ap2->descr) |