summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-02-24 13:07:56 -0500
committerCharles Harris <charlesr.harris@gmail.com>2015-02-24 13:07:56 -0500
commit4d4fb0d12b9ddfa76e485ea3caa6fac25e7183da (patch)
tree907da5b3488dd83e06fbd0dbd339283e4baae39a /numpy/core
parent8c0a50c5b1b331738a8bfb6a94ccc588851d1a0c (diff)
parenta0aaf752c75c310d4eeae1af1064ec8046e6d5e9 (diff)
downloadnumpy-4d4fb0d12b9ddfa76e485ea3caa6fac25e7183da.tar.gz
Merge pull request #5494 from jaimefrio/argsort_refactor
MAINT: refactor PyArray_ArgSort and PyArrayArgPartition
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/item_selection.c298
-rw-r--r--numpy/core/src/npysort/heapsort.c.src52
-rw-r--r--numpy/core/src/npysort/mergesort.c.src66
-rw-r--r--numpy/core/src/npysort/quicksort.c.src78
-rw-r--r--numpy/core/src/private/npy_sort.h3
-rw-r--r--numpy/core/tests/test_multiarray.py23
6 files changed, 283 insertions, 237 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 7016760b3..e12cd1311 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -1421,146 +1421,54 @@ PyArray_Partition(PyArrayObject *op, PyArrayObject * ktharray, int axis, NPY_SEL
}
-static char *global_data;
-
-static int
-argsort_static_compare(const void *ip1, const void *ip2)
-{
- int isize = PyArray_DESCR(global_obj)->elsize;
- const npy_intp *ipa = ip1;
- const npy_intp *ipb = ip2;
- return PyArray_DESCR(global_obj)->f->compare(global_data + (isize * *ipa),
- global_data + (isize * *ipb),
- global_obj);
-}
-
/*NUMPY_API
* ArgSort an array
*/
NPY_NO_EXPORT PyObject *
PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND which)
{
- PyArrayObject *ap = NULL, *ret = NULL, *store, *op2;
- npy_intp *ip;
- npy_intp i, j, n, m, orign;
- int argsort_elsize;
- char *store_ptr;
- int res = 0;
- int (*sort)(void *, size_t, size_t, npy_comparator);
-
- n = PyArray_NDIM(op);
- if ((n == 0) || (PyArray_SIZE(op) == 1)) {
- ret = (PyArrayObject *)PyArray_New(Py_TYPE(op), PyArray_NDIM(op),
- PyArray_DIMS(op),
- NPY_INTP,
- NULL, NULL, 0, 0,
- (PyObject *)op);
- if (ret == NULL) {
- return NULL;
- }
- *((npy_intp *)PyArray_DATA(ret)) = 0;
- return (PyObject *)ret;
- }
+ PyArrayObject *op2;
+ PyArray_ArgSortFunc *argsort;
+ PyObject *ret;
- /* Creates new reference op2 */
- if ((op2=(PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) {
+ if (which < 0 || which >= NPY_NSORTS) {
+ PyErr_SetString(PyExc_ValueError,
+ "not a valid sort kind");
return NULL;
}
- /* Determine if we should use new algorithm or not */
- if (PyArray_DESCR(op2)->f->argsort[which] != NULL) {
- ret = (PyArrayObject *)_new_argsortlike(op2, axis,
- PyArray_DESCR(op2)->f->argsort[which],
- NULL, NULL, 0);
- Py_DECREF(op2);
- return (PyObject *)ret;
- }
-
- if (PyArray_DESCR(op2)->f->compare == NULL) {
- PyErr_SetString(PyExc_TypeError,
- "type does not have compare function");
- Py_DECREF(op2);
- op = NULL;
- goto fail;
- }
-
- switch (which) {
- case NPY_QUICKSORT :
- sort = npy_quicksort;
- break;
- case NPY_HEAPSORT :
- sort = npy_heapsort;
- break;
- case NPY_MERGESORT :
- sort = npy_mergesort;
- break;
- default:
- PyErr_SetString(PyExc_TypeError,
- "requested sort kind is not supported");
- Py_DECREF(op2);
- op = NULL;
- goto fail;
- }
- /* ap will contain the reference to op2 */
- SWAPAXES(ap, op2);
- op = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)ap,
- NPY_NOTYPE,
- 1, 0);
- Py_DECREF(ap);
- if (op == NULL) {
- return NULL;
- }
- ret = (PyArrayObject *)PyArray_New(Py_TYPE(op), PyArray_NDIM(op),
- PyArray_DIMS(op), NPY_INTP,
- NULL, NULL, 0, 0, (PyObject *)op);
- if (ret == NULL) {
- goto fail;
- }
- ip = (npy_intp *)PyArray_DATA(ret);
- argsort_elsize = PyArray_DESCR(op)->elsize;
- m = PyArray_DIMS(op)[PyArray_NDIM(op)-1];
- if (m == 0) {
- goto finish;
- }
- n = PyArray_SIZE(op)/m;
- store_ptr = global_data;
- global_data = PyArray_DATA(op);
- store = global_obj;
- global_obj = op;
- for (i = 0; i < n; i++, ip += m, global_data += m*argsort_elsize) {
- for (j = 0; j < m; j++) {
- ip[j] = j;
+ argsort = PyArray_DESCR(op)->f->argsort[which];
+ if (argsort == NULL) {
+ if (PyArray_DESCR(op)->f->compare) {
+ switch (which) {
+ default:
+ case NPY_QUICKSORT:
+ argsort = npy_aquicksort;
+ break;
+ case NPY_HEAPSORT:
+ argsort = npy_aheapsort;
+ break;
+ case NPY_MERGESORT:
+ argsort = npy_amergesort;
+ break;
+ }
}
- res = sort((char *)ip, m, sizeof(npy_intp), argsort_static_compare);
- if (res < 0) {
- break;
+ else {
+ PyErr_SetString(PyExc_TypeError,
+ "type does not have compare function");
+ return NULL;
}
}
- global_data = store_ptr;
- global_obj = store;
- if (PyErr_Occurred()) {
- goto fail;
- }
- else if (res == -NPY_ENOMEM) {
- PyErr_NoMemory();
- goto fail;
- }
- else if (res == -NPY_ECOMP) {
- PyErr_SetString(PyExc_TypeError,
- "sort comparison failed");
- goto fail;
+ op2 = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0);
+ if (op2 == NULL) {
+ return NULL;
}
- finish:
- Py_DECREF(op);
- SWAPBACK(op, ret);
- return (PyObject *)op;
+ ret = _new_argsortlike(op2, axis, argsort, NULL, NULL, 0);
- fail:
- Py_XDECREF(op);
- Py_XDECREF(ret);
- return NULL;
+ Py_DECREF(op2);
+ return ret;
}
@@ -1568,136 +1476,52 @@ PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND which)
* ArgPartition an array
*/
NPY_NO_EXPORT PyObject *
-PyArray_ArgPartition(PyArrayObject *op, PyArrayObject * ktharray, int axis, NPY_SELECTKIND which)
+PyArray_ArgPartition(PyArrayObject *op, PyArrayObject *ktharray, int axis,
+ NPY_SELECTKIND which)
{
- PyArrayObject *ap = NULL, *ret = NULL, *store, *op2;
- npy_intp *ip;
- npy_intp i, j, n, m, orign;
- int argsort_elsize;
- char *store_ptr;
- int res = 0;
- int (*sort)(void *, size_t, size_t, npy_comparator);
- PyArray_ArgPartitionFunc * argpart =
- get_argpartition_func(PyArray_TYPE(op), which);
-
- n = PyArray_NDIM(op);
- if ((n == 0) || (PyArray_SIZE(op) == 1)) {
- ret = (PyArrayObject *)PyArray_New(Py_TYPE(op), PyArray_NDIM(op),
- PyArray_DIMS(op),
- NPY_INTP,
- NULL, NULL, 0, 0,
- (PyObject *)op);
- if (ret == NULL) {
- return NULL;
- }
- *((npy_intp *)PyArray_DATA(ret)) = 0;
- return (PyObject *)ret;
- }
+ PyArrayObject *op2, *kthrvl;
+ PyArray_ArgPartitionFunc *argpart;
+ PyArray_ArgSortFunc *argsort;
+ PyObject *ret;
- /* Creates new reference op2 */
- if ((op2=(PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) {
+ if (which < 0 || which >= NPY_NSELECTS) {
+ PyErr_SetString(PyExc_ValueError,
+ "not a valid partition kind");
return NULL;
}
- /* Determine if we should use new algorithm or not */
- if (argpart) {
- PyArrayObject * kthrvl = partition_prep_kth_array(ktharray, op2, axis);
- if (kthrvl == NULL) {
- Py_DECREF(op2);
+ argpart = get_argpartition_func(PyArray_TYPE(op), which);
+ if (argpart == NULL) {
+ /* Use sorting, slower but equivalent */
+ if (PyArray_DESCR(op)->f->compare) {
+ argsort = npy_aquicksort;
+ }
+ else {
+ PyErr_SetString(PyExc_TypeError,
+ "type does not have compare function");
return NULL;
}
-
- ret = (PyArrayObject *)_new_argsortlike(op2, axis, NULL,
- argpart,
- PyArray_DATA(kthrvl),
- PyArray_SIZE(kthrvl));
- Py_DECREF(kthrvl);
- Py_DECREF(op2);
- return (PyObject *)ret;
- }
-
- if (PyArray_DESCR(op2)->f->compare == NULL) {
- PyErr_SetString(PyExc_TypeError,
- "type does not have compare function");
- Py_DECREF(op2);
- op = NULL;
- goto fail;
}
- /* select not implemented, use quicksort, slower but equivalent */
- switch (which) {
- case NPY_INTROSELECT :
- sort = npy_quicksort;
- break;
- default:
- PyErr_SetString(PyExc_TypeError,
- "requested sort kind is not supported");
- Py_DECREF(op2);
- op = NULL;
- goto fail;
+ op2 = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0);
+ if (op2 == NULL) {
+ return NULL;
}
- /* ap will contain the reference to op2 */
- SWAPAXES(ap, op2);
- op = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)ap,
- NPY_NOTYPE,
- 1, 0);
- Py_DECREF(ap);
- if (op == NULL) {
+ /* Process ktharray even if using sorting to do bounds checking */
+ kthrvl = partition_prep_kth_array(ktharray, op2, axis);
+ if (kthrvl == NULL) {
+ Py_DECREF(op2);
return NULL;
}
- ret = (PyArrayObject *)PyArray_New(Py_TYPE(op), PyArray_NDIM(op),
- PyArray_DIMS(op), NPY_INTP,
- NULL, NULL, 0, 0, (PyObject *)op);
- if (ret == NULL) {
- goto fail;
- }
- ip = (npy_intp *)PyArray_DATA(ret);
- argsort_elsize = PyArray_DESCR(op)->elsize;
- m = PyArray_DIMS(op)[PyArray_NDIM(op)-1];
- if (m == 0) {
- goto finish;
- }
- n = PyArray_SIZE(op)/m;
- store_ptr = global_data;
- global_data = PyArray_DATA(op);
- store = global_obj;
- global_obj = op;
- /* we don't need to care about kth here as we are using a full sort */
- for (i = 0; i < n; i++, ip += m, global_data += m*argsort_elsize) {
- for (j = 0; j < m; j++) {
- ip[j] = j;
- }
- res = sort((char *)ip, m, sizeof(npy_intp), argsort_static_compare);
- if (res < 0) {
- break;
- }
- }
- global_data = store_ptr;
- global_obj = store;
- if (PyErr_Occurred()) {
- goto fail;
- }
- else if (res == -NPY_ENOMEM) {
- PyErr_NoMemory();
- goto fail;
- }
- else if (res == -NPY_ECOMP) {
- PyErr_SetString(PyExc_TypeError,
- "sort comparison failed");
- goto fail;
- }
+ ret = _new_argsortlike(op2, axis, argsort, argpart,
+ PyArray_DATA(kthrvl), PyArray_SIZE(kthrvl));
- finish:
- Py_DECREF(op);
- SWAPBACK(op, ret);
- return (PyObject *)op;
+ Py_DECREF(kthrvl);
+ Py_DECREF(op2);
- fail:
- Py_XDECREF(op);
- Py_XDECREF(ret);
- return NULL;
+ return ret;
}
diff --git a/numpy/core/src/npysort/heapsort.c.src b/numpy/core/src/npysort/heapsort.c.src
index cfdd3fd2a..bec8cc032 100644
--- a/numpy/core/src/npysort/heapsort.c.src
+++ b/numpy/core/src/npysort/heapsort.c.src
@@ -343,3 +343,55 @@ npy_heapsort(void *base, size_t num, size_t size, npy_comparator cmp)
free(tmp);
return 0;
}
+
+
+int
+npy_aheapsort(char *v, npy_intp *tosort, npy_intp n, PyArrayObject *arr)
+{
+ npy_intp elsize = PyArray_ITEMSIZE(arr);
+ PyArray_CompareFunc *cmp = PyArray_DESCR(arr)->f->compare;
+ npy_intp *a, i, j, l, tmp;
+
+ /* The array needs to be offset by one for heapsort indexing */
+ a = tosort - 1;
+
+ for (l = n >> 1; l > 0; --l) {
+ tmp = a[l];
+ for (i = l, j = l<<1; j <= n;) {
+ if (j < n && cmp(v + a[j]*elsize, v + a[j+1]*elsize, arr) < 0) {
+ ++j;
+ }
+ if (cmp(v + tmp*elsize, v + a[j]*elsize, arr) < 0) {
+ a[i] = a[j];
+ i = j;
+ j += j;
+ }
+ else {
+ break;
+ }
+ }
+ a[i] = tmp;
+ }
+
+ for (; n > 1;) {
+ tmp = a[n];
+ a[n] = a[1];
+ n -= 1;
+ for (i = 1, j = 2; j <= n;) {
+ if (j < n && cmp(v + a[j]*elsize, v + a[j+1]*elsize, arr) < 0) {
+ ++j;
+ }
+ if (cmp(v + tmp*elsize, v + a[j]*elsize, arr) < 0) {
+ a[i] = a[j];
+ i = j;
+ j += j;
+ }
+ else {
+ break;
+ }
+ }
+ a[i] = tmp;
+ }
+
+ return 0;
+}
diff --git a/numpy/core/src/npysort/mergesort.c.src b/numpy/core/src/npysort/mergesort.c.src
index c99c0e614..ecf39720a 100644
--- a/numpy/core/src/npysort/mergesort.c.src
+++ b/numpy/core/src/npysort/mergesort.c.src
@@ -426,3 +426,69 @@ fail_1:
fail_0:
return err;
}
+
+
+static void
+npy_amergesort0(npy_intp *pl, npy_intp *pr, char *v, npy_intp *pw,
+ npy_intp elsize, PyArray_CompareFunc *cmp, PyArrayObject *arr)
+{
+ char *vp;
+ npy_intp vi, *pi, *pj, *pk, *pm;
+
+ if (pr - pl > SMALL_MERGESORT) {
+ /* merge sort */
+ pm = pl + ((pr - pl) >> 1);
+ npy_amergesort0(pl, pm, v, pw, elsize, cmp, arr);
+ npy_amergesort0(pm, pr, v, pw, elsize, cmp, arr);
+ for (pi = pw, pj = pl; pj < pm;) {
+ *pi++ = *pj++;
+ }
+ pi = pw + (pm - pl);
+ pj = pw;
+ pk = pl;
+ while (pj < pi && pm < pr) {
+ if (cmp(v + (*pm)*elsize, v + (*pj)*elsize, arr) < 0) {
+ *pk++ = *pm++;
+ }
+ else {
+ *pk++ = *pj++;
+ }
+ }
+ while (pj < pi) {
+ *pk++ = *pj++;
+ }
+ }
+ else {
+ /* insertion sort */
+ for (pi = pl + 1; pi < pr; ++pi) {
+ vi = *pi;
+ vp = v + vi*elsize;
+ pj = pi;
+ pk = pi - 1;
+ while (pj > pl && cmp(vp, v + (*pk)*elsize, arr) < 0) {
+ *pj-- = *pk--;
+ }
+ *pj = vi;
+ }
+ }
+}
+
+
+int
+npy_amergesort(char *v, npy_intp *tosort, npy_intp num, PyArrayObject *arr)
+{
+ npy_intp elsize = PyArray_ITEMSIZE(arr);
+ PyArray_CompareFunc *cmp = PyArray_DESCR(arr)->f->compare;
+ npy_intp *pl, *pr, *pw;
+
+ pl = tosort;
+ pr = pl + num;
+ pw = malloc((num >> 1) * sizeof(npy_intp));
+ if (pw == NULL) {
+ return -NPY_ENOMEM;
+ }
+ npy_amergesort0(pl, pr, v, pw, elsize, cmp, arr);
+ free(pw);
+
+ return 0;
+}
diff --git a/numpy/core/src/npysort/quicksort.c.src b/numpy/core/src/npysort/quicksort.c.src
index a27530eb4..fcde70957 100644
--- a/numpy/core/src/npysort/quicksort.c.src
+++ b/numpy/core/src/npysort/quicksort.c.src
@@ -361,3 +361,81 @@ npy_quicksort(void *base, size_t num, size_t size, npy_comparator cmp)
qsort(base, num, size, cmp);
return 0;
}
+
+
+int
+npy_aquicksort(char *v, npy_intp* tosort, npy_intp num, PyArrayObject *arr)
+{
+ npy_intp elsize = PyArray_ITEMSIZE(arr);
+ PyArray_CompareFunc *cmp = PyArray_DESCR(arr)->f->compare;
+ char *vp;
+ npy_intp *pl = tosort;
+ npy_intp *pr = tosort + num - 1;
+ npy_intp *stack[PYA_QS_STACK];
+ npy_intp **sptr = stack;
+ npy_intp *pm, *pi, *pj, *pk, vi;
+
+ for (;;) {
+ while ((pr - pl) > SMALL_QUICKSORT) {
+ /* quicksort partition */
+ pm = pl + ((pr - pl) >> 1);
+ if (cmp(v + (*pm)*elsize, v + (*pl)*elsize, arr) < 0) {
+ INTP_SWAP(*pm, *pl);
+ }
+ if (cmp(v + (*pr)*elsize, v + (*pm)*elsize, arr) < 0) {
+ INTP_SWAP(*pr, *pm);
+ }
+ if (cmp(v + (*pm)*elsize, v + (*pl)*elsize, arr) < 0) {
+ INTP_SWAP(*pm, *pl);
+ }
+ vp = v + (*pm)*elsize;
+ pi = pl;
+ pj = pr - 1;
+ INTP_SWAP(*pm,*pj);
+ for (;;) {
+ do {
+ ++pi;
+ } while (cmp(v + (*pi)*elsize, vp, arr) < 0);
+ do {
+ --pj;
+ } while (cmp(vp, v + (*pj)*elsize, arr) < 0);
+ if (pi >= pj) {
+ break;
+ }
+ INTP_SWAP(*pi,*pj);
+ }
+ pk = pr - 1;
+ INTP_SWAP(*pi,*pk);
+ /* push largest partition on stack */
+ if (pi - pl < pr - pi) {
+ *sptr++ = pi + 1;
+ *sptr++ = pr;
+ pr = pi - 1;
+ }
+ else {
+ *sptr++ = pl;
+ *sptr++ = pi - 1;
+ pl = pi + 1;
+ }
+ }
+
+ /* insertion sort */
+ for (pi = pl + 1; pi <= pr; ++pi) {
+ vi = *pi;
+ vp = v + vi*elsize;
+ pj = pi;
+ pk = pi - 1;
+ while (pj > pl && cmp(vp, v + (*pk)*elsize, arr) < 0) {
+ *pj-- = *pk--;
+ }
+ *pj = vi;
+ }
+ if (sptr == stack) {
+ break;
+ }
+ pr = *(--sptr);
+ pl = *(--sptr);
+ }
+
+ return 0;
+}
diff --git a/numpy/core/src/private/npy_sort.h b/numpy/core/src/private/npy_sort.h
index 46825a0c5..7c5701f88 100644
--- a/numpy/core/src/private/npy_sort.h
+++ b/numpy/core/src/private/npy_sort.h
@@ -190,5 +190,8 @@ int amergesort_timedelta(npy_timedelta *vec, npy_intp *ind, npy_intp cnt, void *
int npy_quicksort(void *base, size_t num, size_t size, npy_comparator cmp);
int npy_heapsort(void *base, size_t num, size_t size, npy_comparator cmp);
int npy_mergesort(void *base, size_t num, size_t size, npy_comparator cmp);
+int npy_aquicksort(char *vec, npy_intp *ind, npy_intp cnt, PyArrayObject *arr);
+int npy_aheapsort(char *vec, npy_intp *ind, npy_intp cnt, PyArrayObject *arr);
+int npy_amergesort(char *vec, npy_intp *ind, npy_intp cnt, PyArrayObject *arr);
#endif
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index b0d677052..398852d89 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1417,6 +1417,29 @@ class TestMethods(TestCase):
b = a.searchsorted(a, 'r', s)
assert_equal(b, out + 1)
+ def test_argpartition_out_of_range(self):
+ # Test out of range values in kth raise an error, gh-5469
+ d = np.arange(10)
+ assert_raises(ValueError, d.argpartition, 10)
+ assert_raises(ValueError, d.argpartition, -11)
+ # Test also for generic type argpartition, which uses sorting
+ # and used to not bound check kth
+ d_obj = np.arange(10, dtype=object)
+ assert_raises(ValueError, d_obj.argpartition, 10)
+ assert_raises(ValueError, d_obj.argpartition, -11)
+
+ @dec.knownfailureif(True, "Ticket #5469 fixed for argpartition only")
+ def test_partition_out_of_range(self):
+ # Test out of range values in kth raise an error, gh-5469
+ d = np.arange(10)
+ assert_raises(ValueError, d.partition, 10)
+ assert_raises(ValueError, d.partition, -11)
+ # Test also for generic type partition, which uses sorting
+ # and used to not bound check kth
+ d_obj = np.arange(10, dtype=object)
+ assert_raises(ValueError, d_obj.partition, 10)
+ assert_raises(ValueError, d_obj.partition, -11)
+
def test_partition(self):
d = np.arange(10)
assert_raises(TypeError, np.partition, d, 2, kind=1)