diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2015-02-24 13:07:56 -0500 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2015-02-24 13:07:56 -0500 |
commit | 4d4fb0d12b9ddfa76e485ea3caa6fac25e7183da (patch) | |
tree | 907da5b3488dd83e06fbd0dbd339283e4baae39a /numpy/core | |
parent | 8c0a50c5b1b331738a8bfb6a94ccc588851d1a0c (diff) | |
parent | a0aaf752c75c310d4eeae1af1064ec8046e6d5e9 (diff) | |
download | numpy-4d4fb0d12b9ddfa76e485ea3caa6fac25e7183da.tar.gz |
Merge pull request #5494 from jaimefrio/argsort_refactor
MAINT: refactor PyArray_ArgSort and PyArrayArgPartition
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 298 | ||||
-rw-r--r-- | numpy/core/src/npysort/heapsort.c.src | 52 | ||||
-rw-r--r-- | numpy/core/src/npysort/mergesort.c.src | 66 | ||||
-rw-r--r-- | numpy/core/src/npysort/quicksort.c.src | 78 | ||||
-rw-r--r-- | numpy/core/src/private/npy_sort.h | 3 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 23 |
6 files changed, 283 insertions, 237 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 7016760b3..e12cd1311 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -1421,146 +1421,54 @@ PyArray_Partition(PyArrayObject *op, PyArrayObject * ktharray, int axis, NPY_SEL } -static char *global_data; - -static int -argsort_static_compare(const void *ip1, const void *ip2) -{ - int isize = PyArray_DESCR(global_obj)->elsize; - const npy_intp *ipa = ip1; - const npy_intp *ipb = ip2; - return PyArray_DESCR(global_obj)->f->compare(global_data + (isize * *ipa), - global_data + (isize * *ipb), - global_obj); -} - /*NUMPY_API * ArgSort an array */ NPY_NO_EXPORT PyObject * PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND which) { - PyArrayObject *ap = NULL, *ret = NULL, *store, *op2; - npy_intp *ip; - npy_intp i, j, n, m, orign; - int argsort_elsize; - char *store_ptr; - int res = 0; - int (*sort)(void *, size_t, size_t, npy_comparator); - - n = PyArray_NDIM(op); - if ((n == 0) || (PyArray_SIZE(op) == 1)) { - ret = (PyArrayObject *)PyArray_New(Py_TYPE(op), PyArray_NDIM(op), - PyArray_DIMS(op), - NPY_INTP, - NULL, NULL, 0, 0, - (PyObject *)op); - if (ret == NULL) { - return NULL; - } - *((npy_intp *)PyArray_DATA(ret)) = 0; - return (PyObject *)ret; - } + PyArrayObject *op2; + PyArray_ArgSortFunc *argsort; + PyObject *ret; - /* Creates new reference op2 */ - if ((op2=(PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) { + if (which < 0 || which >= NPY_NSORTS) { + PyErr_SetString(PyExc_ValueError, + "not a valid sort kind"); return NULL; } - /* Determine if we should use new algorithm or not */ - if (PyArray_DESCR(op2)->f->argsort[which] != NULL) { - ret = (PyArrayObject *)_new_argsortlike(op2, axis, - PyArray_DESCR(op2)->f->argsort[which], - NULL, NULL, 0); - Py_DECREF(op2); - return (PyObject *)ret; - } - - if (PyArray_DESCR(op2)->f->compare == NULL) { - PyErr_SetString(PyExc_TypeError, - "type does not have compare function"); - Py_DECREF(op2); - op = NULL; - goto fail; - } - - switch (which) { - case NPY_QUICKSORT : - sort = npy_quicksort; - break; - case NPY_HEAPSORT : - sort = npy_heapsort; - break; - case NPY_MERGESORT : - sort = npy_mergesort; - break; - default: - PyErr_SetString(PyExc_TypeError, - "requested sort kind is not supported"); - Py_DECREF(op2); - op = NULL; - goto fail; - } - /* ap will contain the reference to op2 */ - SWAPAXES(ap, op2); - op = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)ap, - NPY_NOTYPE, - 1, 0); - Py_DECREF(ap); - if (op == NULL) { - return NULL; - } - ret = (PyArrayObject *)PyArray_New(Py_TYPE(op), PyArray_NDIM(op), - PyArray_DIMS(op), NPY_INTP, - NULL, NULL, 0, 0, (PyObject *)op); - if (ret == NULL) { - goto fail; - } - ip = (npy_intp *)PyArray_DATA(ret); - argsort_elsize = PyArray_DESCR(op)->elsize; - m = PyArray_DIMS(op)[PyArray_NDIM(op)-1]; - if (m == 0) { - goto finish; - } - n = PyArray_SIZE(op)/m; - store_ptr = global_data; - global_data = PyArray_DATA(op); - store = global_obj; - global_obj = op; - for (i = 0; i < n; i++, ip += m, global_data += m*argsort_elsize) { - for (j = 0; j < m; j++) { - ip[j] = j; + argsort = PyArray_DESCR(op)->f->argsort[which]; + if (argsort == NULL) { + if (PyArray_DESCR(op)->f->compare) { + switch (which) { + default: + case NPY_QUICKSORT: + argsort = npy_aquicksort; + break; + case NPY_HEAPSORT: + argsort = npy_aheapsort; + break; + case NPY_MERGESORT: + argsort = npy_amergesort; + break; + } } - res = sort((char *)ip, m, sizeof(npy_intp), argsort_static_compare); - if (res < 0) { - break; + else { + PyErr_SetString(PyExc_TypeError, + "type does not have compare function"); + return NULL; } } - global_data = store_ptr; - global_obj = store; - if (PyErr_Occurred()) { - goto fail; - } - else if (res == -NPY_ENOMEM) { - PyErr_NoMemory(); - goto fail; - } - else if (res == -NPY_ECOMP) { - PyErr_SetString(PyExc_TypeError, - "sort comparison failed"); - goto fail; + op2 = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0); + if (op2 == NULL) { + return NULL; } - finish: - Py_DECREF(op); - SWAPBACK(op, ret); - return (PyObject *)op; + ret = _new_argsortlike(op2, axis, argsort, NULL, NULL, 0); - fail: - Py_XDECREF(op); - Py_XDECREF(ret); - return NULL; + Py_DECREF(op2); + return ret; } @@ -1568,136 +1476,52 @@ PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND which) * ArgPartition an array */ NPY_NO_EXPORT PyObject * -PyArray_ArgPartition(PyArrayObject *op, PyArrayObject * ktharray, int axis, NPY_SELECTKIND which) +PyArray_ArgPartition(PyArrayObject *op, PyArrayObject *ktharray, int axis, + NPY_SELECTKIND which) { - PyArrayObject *ap = NULL, *ret = NULL, *store, *op2; - npy_intp *ip; - npy_intp i, j, n, m, orign; - int argsort_elsize; - char *store_ptr; - int res = 0; - int (*sort)(void *, size_t, size_t, npy_comparator); - PyArray_ArgPartitionFunc * argpart = - get_argpartition_func(PyArray_TYPE(op), which); - - n = PyArray_NDIM(op); - if ((n == 0) || (PyArray_SIZE(op) == 1)) { - ret = (PyArrayObject *)PyArray_New(Py_TYPE(op), PyArray_NDIM(op), - PyArray_DIMS(op), - NPY_INTP, - NULL, NULL, 0, 0, - (PyObject *)op); - if (ret == NULL) { - return NULL; - } - *((npy_intp *)PyArray_DATA(ret)) = 0; - return (PyObject *)ret; - } + PyArrayObject *op2, *kthrvl; + PyArray_ArgPartitionFunc *argpart; + PyArray_ArgSortFunc *argsort; + PyObject *ret; - /* Creates new reference op2 */ - if ((op2=(PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) { + if (which < 0 || which >= NPY_NSELECTS) { + PyErr_SetString(PyExc_ValueError, + "not a valid partition kind"); return NULL; } - /* Determine if we should use new algorithm or not */ - if (argpart) { - PyArrayObject * kthrvl = partition_prep_kth_array(ktharray, op2, axis); - if (kthrvl == NULL) { - Py_DECREF(op2); + argpart = get_argpartition_func(PyArray_TYPE(op), which); + if (argpart == NULL) { + /* Use sorting, slower but equivalent */ + if (PyArray_DESCR(op)->f->compare) { + argsort = npy_aquicksort; + } + else { + PyErr_SetString(PyExc_TypeError, + "type does not have compare function"); return NULL; } - - ret = (PyArrayObject *)_new_argsortlike(op2, axis, NULL, - argpart, - PyArray_DATA(kthrvl), - PyArray_SIZE(kthrvl)); - Py_DECREF(kthrvl); - Py_DECREF(op2); - return (PyObject *)ret; - } - - if (PyArray_DESCR(op2)->f->compare == NULL) { - PyErr_SetString(PyExc_TypeError, - "type does not have compare function"); - Py_DECREF(op2); - op = NULL; - goto fail; } - /* select not implemented, use quicksort, slower but equivalent */ - switch (which) { - case NPY_INTROSELECT : - sort = npy_quicksort; - break; - default: - PyErr_SetString(PyExc_TypeError, - "requested sort kind is not supported"); - Py_DECREF(op2); - op = NULL; - goto fail; + op2 = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0); + if (op2 == NULL) { + return NULL; } - /* ap will contain the reference to op2 */ - SWAPAXES(ap, op2); - op = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)ap, - NPY_NOTYPE, - 1, 0); - Py_DECREF(ap); - if (op == NULL) { + /* Process ktharray even if using sorting to do bounds checking */ + kthrvl = partition_prep_kth_array(ktharray, op2, axis); + if (kthrvl == NULL) { + Py_DECREF(op2); return NULL; } - ret = (PyArrayObject *)PyArray_New(Py_TYPE(op), PyArray_NDIM(op), - PyArray_DIMS(op), NPY_INTP, - NULL, NULL, 0, 0, (PyObject *)op); - if (ret == NULL) { - goto fail; - } - ip = (npy_intp *)PyArray_DATA(ret); - argsort_elsize = PyArray_DESCR(op)->elsize; - m = PyArray_DIMS(op)[PyArray_NDIM(op)-1]; - if (m == 0) { - goto finish; - } - n = PyArray_SIZE(op)/m; - store_ptr = global_data; - global_data = PyArray_DATA(op); - store = global_obj; - global_obj = op; - /* we don't need to care about kth here as we are using a full sort */ - for (i = 0; i < n; i++, ip += m, global_data += m*argsort_elsize) { - for (j = 0; j < m; j++) { - ip[j] = j; - } - res = sort((char *)ip, m, sizeof(npy_intp), argsort_static_compare); - if (res < 0) { - break; - } - } - global_data = store_ptr; - global_obj = store; - if (PyErr_Occurred()) { - goto fail; - } - else if (res == -NPY_ENOMEM) { - PyErr_NoMemory(); - goto fail; - } - else if (res == -NPY_ECOMP) { - PyErr_SetString(PyExc_TypeError, - "sort comparison failed"); - goto fail; - } + ret = _new_argsortlike(op2, axis, argsort, argpart, + PyArray_DATA(kthrvl), PyArray_SIZE(kthrvl)); - finish: - Py_DECREF(op); - SWAPBACK(op, ret); - return (PyObject *)op; + Py_DECREF(kthrvl); + Py_DECREF(op2); - fail: - Py_XDECREF(op); - Py_XDECREF(ret); - return NULL; + return ret; } diff --git a/numpy/core/src/npysort/heapsort.c.src b/numpy/core/src/npysort/heapsort.c.src index cfdd3fd2a..bec8cc032 100644 --- a/numpy/core/src/npysort/heapsort.c.src +++ b/numpy/core/src/npysort/heapsort.c.src @@ -343,3 +343,55 @@ npy_heapsort(void *base, size_t num, size_t size, npy_comparator cmp) free(tmp); return 0; } + + +int +npy_aheapsort(char *v, npy_intp *tosort, npy_intp n, PyArrayObject *arr) +{ + npy_intp elsize = PyArray_ITEMSIZE(arr); + PyArray_CompareFunc *cmp = PyArray_DESCR(arr)->f->compare; + npy_intp *a, i, j, l, tmp; + + /* The array needs to be offset by one for heapsort indexing */ + a = tosort - 1; + + for (l = n >> 1; l > 0; --l) { + tmp = a[l]; + for (i = l, j = l<<1; j <= n;) { + if (j < n && cmp(v + a[j]*elsize, v + a[j+1]*elsize, arr) < 0) { + ++j; + } + if (cmp(v + tmp*elsize, v + a[j]*elsize, arr) < 0) { + a[i] = a[j]; + i = j; + j += j; + } + else { + break; + } + } + a[i] = tmp; + } + + for (; n > 1;) { + tmp = a[n]; + a[n] = a[1]; + n -= 1; + for (i = 1, j = 2; j <= n;) { + if (j < n && cmp(v + a[j]*elsize, v + a[j+1]*elsize, arr) < 0) { + ++j; + } + if (cmp(v + tmp*elsize, v + a[j]*elsize, arr) < 0) { + a[i] = a[j]; + i = j; + j += j; + } + else { + break; + } + } + a[i] = tmp; + } + + return 0; +} diff --git a/numpy/core/src/npysort/mergesort.c.src b/numpy/core/src/npysort/mergesort.c.src index c99c0e614..ecf39720a 100644 --- a/numpy/core/src/npysort/mergesort.c.src +++ b/numpy/core/src/npysort/mergesort.c.src @@ -426,3 +426,69 @@ fail_1: fail_0: return err; } + + +static void +npy_amergesort0(npy_intp *pl, npy_intp *pr, char *v, npy_intp *pw, + npy_intp elsize, PyArray_CompareFunc *cmp, PyArrayObject *arr) +{ + char *vp; + npy_intp vi, *pi, *pj, *pk, *pm; + + if (pr - pl > SMALL_MERGESORT) { + /* merge sort */ + pm = pl + ((pr - pl) >> 1); + npy_amergesort0(pl, pm, v, pw, elsize, cmp, arr); + npy_amergesort0(pm, pr, v, pw, elsize, cmp, arr); + for (pi = pw, pj = pl; pj < pm;) { + *pi++ = *pj++; + } + pi = pw + (pm - pl); + pj = pw; + pk = pl; + while (pj < pi && pm < pr) { + if (cmp(v + (*pm)*elsize, v + (*pj)*elsize, arr) < 0) { + *pk++ = *pm++; + } + else { + *pk++ = *pj++; + } + } + while (pj < pi) { + *pk++ = *pj++; + } + } + else { + /* insertion sort */ + for (pi = pl + 1; pi < pr; ++pi) { + vi = *pi; + vp = v + vi*elsize; + pj = pi; + pk = pi - 1; + while (pj > pl && cmp(vp, v + (*pk)*elsize, arr) < 0) { + *pj-- = *pk--; + } + *pj = vi; + } + } +} + + +int +npy_amergesort(char *v, npy_intp *tosort, npy_intp num, PyArrayObject *arr) +{ + npy_intp elsize = PyArray_ITEMSIZE(arr); + PyArray_CompareFunc *cmp = PyArray_DESCR(arr)->f->compare; + npy_intp *pl, *pr, *pw; + + pl = tosort; + pr = pl + num; + pw = malloc((num >> 1) * sizeof(npy_intp)); + if (pw == NULL) { + return -NPY_ENOMEM; + } + npy_amergesort0(pl, pr, v, pw, elsize, cmp, arr); + free(pw); + + return 0; +} diff --git a/numpy/core/src/npysort/quicksort.c.src b/numpy/core/src/npysort/quicksort.c.src index a27530eb4..fcde70957 100644 --- a/numpy/core/src/npysort/quicksort.c.src +++ b/numpy/core/src/npysort/quicksort.c.src @@ -361,3 +361,81 @@ npy_quicksort(void *base, size_t num, size_t size, npy_comparator cmp) qsort(base, num, size, cmp); return 0; } + + +int +npy_aquicksort(char *v, npy_intp* tosort, npy_intp num, PyArrayObject *arr) +{ + npy_intp elsize = PyArray_ITEMSIZE(arr); + PyArray_CompareFunc *cmp = PyArray_DESCR(arr)->f->compare; + char *vp; + npy_intp *pl = tosort; + npy_intp *pr = tosort + num - 1; + npy_intp *stack[PYA_QS_STACK]; + npy_intp **sptr = stack; + npy_intp *pm, *pi, *pj, *pk, vi; + + for (;;) { + while ((pr - pl) > SMALL_QUICKSORT) { + /* quicksort partition */ + pm = pl + ((pr - pl) >> 1); + if (cmp(v + (*pm)*elsize, v + (*pl)*elsize, arr) < 0) { + INTP_SWAP(*pm, *pl); + } + if (cmp(v + (*pr)*elsize, v + (*pm)*elsize, arr) < 0) { + INTP_SWAP(*pr, *pm); + } + if (cmp(v + (*pm)*elsize, v + (*pl)*elsize, arr) < 0) { + INTP_SWAP(*pm, *pl); + } + vp = v + (*pm)*elsize; + pi = pl; + pj = pr - 1; + INTP_SWAP(*pm,*pj); + for (;;) { + do { + ++pi; + } while (cmp(v + (*pi)*elsize, vp, arr) < 0); + do { + --pj; + } while (cmp(vp, v + (*pj)*elsize, arr) < 0); + if (pi >= pj) { + break; + } + INTP_SWAP(*pi,*pj); + } + pk = pr - 1; + INTP_SWAP(*pi,*pk); + /* push largest partition on stack */ + if (pi - pl < pr - pi) { + *sptr++ = pi + 1; + *sptr++ = pr; + pr = pi - 1; + } + else { + *sptr++ = pl; + *sptr++ = pi - 1; + pl = pi + 1; + } + } + + /* insertion sort */ + for (pi = pl + 1; pi <= pr; ++pi) { + vi = *pi; + vp = v + vi*elsize; + pj = pi; + pk = pi - 1; + while (pj > pl && cmp(vp, v + (*pk)*elsize, arr) < 0) { + *pj-- = *pk--; + } + *pj = vi; + } + if (sptr == stack) { + break; + } + pr = *(--sptr); + pl = *(--sptr); + } + + return 0; +} diff --git a/numpy/core/src/private/npy_sort.h b/numpy/core/src/private/npy_sort.h index 46825a0c5..7c5701f88 100644 --- a/numpy/core/src/private/npy_sort.h +++ b/numpy/core/src/private/npy_sort.h @@ -190,5 +190,8 @@ int amergesort_timedelta(npy_timedelta *vec, npy_intp *ind, npy_intp cnt, void * int npy_quicksort(void *base, size_t num, size_t size, npy_comparator cmp); int npy_heapsort(void *base, size_t num, size_t size, npy_comparator cmp); int npy_mergesort(void *base, size_t num, size_t size, npy_comparator cmp); +int npy_aquicksort(char *vec, npy_intp *ind, npy_intp cnt, PyArrayObject *arr); +int npy_aheapsort(char *vec, npy_intp *ind, npy_intp cnt, PyArrayObject *arr); +int npy_amergesort(char *vec, npy_intp *ind, npy_intp cnt, PyArrayObject *arr); #endif diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index b0d677052..398852d89 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1417,6 +1417,29 @@ class TestMethods(TestCase): b = a.searchsorted(a, 'r', s) assert_equal(b, out + 1) + def test_argpartition_out_of_range(self): + # Test out of range values in kth raise an error, gh-5469 + d = np.arange(10) + assert_raises(ValueError, d.argpartition, 10) + assert_raises(ValueError, d.argpartition, -11) + # Test also for generic type argpartition, which uses sorting + # and used to not bound check kth + d_obj = np.arange(10, dtype=object) + assert_raises(ValueError, d_obj.argpartition, 10) + assert_raises(ValueError, d_obj.argpartition, -11) + + @dec.knownfailureif(True, "Ticket #5469 fixed for argpartition only") + def test_partition_out_of_range(self): + # Test out of range values in kth raise an error, gh-5469 + d = np.arange(10) + assert_raises(ValueError, d.partition, 10) + assert_raises(ValueError, d.partition, -11) + # Test also for generic type partition, which uses sorting + # and used to not bound check kth + d_obj = np.arange(10, dtype=object) + assert_raises(ValueError, d_obj.partition, 10) + assert_raises(ValueError, d_obj.partition, -11) + def test_partition(self): d = np.arange(10) assert_raises(TypeError, np.partition, d, 2, kind=1) |