summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2008-02-13 21:41:47 +0000
committerCharles Harris <charlesr.harris@gmail.com>2008-02-13 21:41:47 +0000
commite7c535848319bad6783e3ded5bbf0c14658e560e (patch)
tree3da0b4c1dab3c1dd3aeac6fec03cf7e28f7ac94c
parent4ded091d27c9744c2bf09a6882b890f85e574e05 (diff)
downloadnumpy-e7c535848319bad6783e3ded5bbf0c14658e560e.tar.gz
Add type specific heapsort to argsort kind options for strings and unicode.
-rw-r--r--numpy/core/src/_sortmodule.c.src51
-rw-r--r--numpy/core/tests/test_multiarray.py10
2 files changed, 55 insertions, 6 deletions
diff --git a/numpy/core/src/_sortmodule.c.src b/numpy/core/src/_sortmodule.c.src
index 547c560fd..0dc21a556 100644
--- a/numpy/core/src/_sortmodule.c.src
+++ b/numpy/core/src/_sortmodule.c.src
@@ -520,6 +520,53 @@ static int
static int
+@TYPE@_aheapsort(@type@ *v, intp *tosort, intp n, PyArrayObject *arr)
+{
+ intp *a, i,j,l, tmp;
+ size_t len = arr->descr->elsize/sizeof(@type@);
+
+ /* The array needs to be offset by one for heapsort indexing */
+ a = tosort - 1;
+
+ for (l = n>>1; l > 0; --l) {
+ tmp = a[l];
+ for (i = l, j = l<<1; j <= n;) {
+ if (j < n && @lessthan@(v + a[j]*len, v + a[j+1]*len, len))
+ j += 1;
+ if (@lessthan@(v + tmp*len, v + a[j]*len, len)) {
+ a[i] = a[j];
+ i = j;
+ j += j;
+ }
+ else
+ break;
+ }
+ a[i] = tmp;
+ }
+
+ for (; n > 1;) {
+ tmp = a[n];
+ a[n] = a[1];
+ n -= 1;
+ for (i = 1, j = 2; j <= n;) {
+ if (j < n && @lessthan@(v + a[j]*len, v + a[j+1]*len, len))
+ j++;
+ if (@lessthan@(v + tmp*len, v + a[j]*len, len)) {
+ a[i] = a[j];
+ i = j;
+ j += j;
+ }
+ else
+ break;
+ }
+ a[i] = tmp;
+ }
+
+ return 0;
+}
+
+
+static int
@TYPE@_aquicksort(@type@ *v, intp* tosort, intp num, PyArrayObject *arr)
{
@type@ *vp;
@@ -674,6 +721,8 @@ add_sortfuncs(void)
(PyArray_ArgSortFunc *)STRING_amergesort;
descr->f->argsort[PyArray_QUICKSORT] = \
(PyArray_ArgSortFunc *)STRING_aquicksort;
+ descr->f->argsort[PyArray_HEAPSORT] = \
+ (PyArray_ArgSortFunc *)STRING_aheapsort;
descr->f->sort[PyArray_QUICKSORT] = \
(PyArray_SortFunc *)STRING_quicksort;
@@ -682,6 +731,8 @@ add_sortfuncs(void)
(PyArray_ArgSortFunc *)UNICODE_amergesort;
descr->f->argsort[PyArray_QUICKSORT] = \
(PyArray_ArgSortFunc *)UNICODE_aquicksort;
+ descr->f->argsort[PyArray_HEAPSORT] = \
+ (PyArray_ArgSortFunc *)UNICODE_aheapsort;
descr->f->sort[PyArray_QUICKSORT] = \
(PyArray_SortFunc *)UNICODE_quicksort;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 9ed12d97e..8959841ca 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -365,26 +365,24 @@ class TestMethods(NumpyTestCase):
assert_equal(ai.copy().argsort(kind=kind), a, msg)
assert_equal(bi.copy().argsort(kind=kind), b, msg)
- # test string argsorts. Only quick and merge sort are
- # available at this time.
+ # test string argsorts.
s = 'aaaaaaaa'
a = np.array([s + chr(i) for i in range(100)])
b = a[::-1].copy()
r = arange(100)
rr = r[::-1].copy()
- for kind in ['q', 'm'] :
+ for kind in ['q', 'm', 'h'] :
msg = "string argsort, kind=%s" % kind
assert_equal(a.copy().argsort(kind=kind), r, msg)
assert_equal(b.copy().argsort(kind=kind), rr, msg)
- # test unicode argsorts. Only quick and merge sort are
- # available at this time.
+ # test unicode argsorts.
s = 'aaaaaaaa'
a = np.array([s + chr(i) for i in range(100)], dtype=np.unicode)
b = a[::-1].copy()
r = arange(100)
rr = r[::-1].copy()
- for kind in ['q', 'm'] :
+ for kind in ['q', 'm', 'h'] :
msg = "unicode argsort, kind=%s" % kind
assert_equal(a.copy().argsort(kind=kind), r, msg)
assert_equal(b.copy().argsort(kind=kind), rr, msg)