summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorHameer Abbasi <hameerabbasi@yahoo.com>2019-09-04 10:26:10 +0200
committerHameer Abbasi <hameerabbasi@yahoo.com>2019-09-04 10:44:46 +0200
commit003a3d4c5056d698e62f24c392c101bf687e00bd (patch)
treeac48c7060a7a5c5f8ce82ece7fc3c2204c81fded /numpy
parente3f4c536a0014789dbd0321926b4f62c39d73719 (diff)
downloadnumpy-003a3d4c5056d698e62f24c392c101bf687e00bd.tar.gz
BUG: Fix lexsort + radix sort.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/npysort/radixsort.c.src4
-rw-r--r--numpy/core/tests/test_multiarray.py22
2 files changed, 17 insertions, 9 deletions
diff --git a/numpy/core/src/npysort/radixsort.c.src b/numpy/core/src/npysort/radixsort.c.src
index c90b06974..72887d7e4 100644
--- a/numpy/core/src/npysort/radixsort.c.src
+++ b/numpy/core/src/npysort/radixsort.c.src
@@ -198,9 +198,9 @@ aradixsort_@suff@(void *start, npy_intp* tosort, npy_intp num, void *NPY_UNUSED(
return 0;
}
- k1 = KEY_OF(arr[0]);
+ k1 = KEY_OF(arr[tosort[0]]);
for (npy_intp i = 1; i < num; i++) {
- k2 = KEY_OF(arr[i]);
+ k2 = KEY_OF(arr[tosort[i]]);
if (k1 > k2) {
all_sorted = 0;
break;
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 2593045ed..1b698b517 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4587,18 +4587,26 @@ class TestTake(object):
assert_equal(y, np.array([1, 2, 3]))
class TestLexsort(object):
- def test_basic(self):
- a = [1, 2, 1, 3, 1, 5]
- b = [0, 4, 5, 6, 2, 3]
+ @pytest.mark.parametrize('dtype',[
+ np.uint8, np.uint16, np.uint32, np.uint64,
+ np.int8, np.int16, np.int32, np.int64,
+ np.float16, np.float32, np.float64
+ ])
+ def test_basic(self, dtype):
+ a = np.array([1, 2, 1, 3, 1, 5], dtype=dtype)
+ b = np.array([0, 4, 5, 6, 2, 3], dtype=dtype)
idx = np.lexsort((b, a))
expected_idx = np.array([0, 4, 2, 1, 3, 5])
assert_array_equal(idx, expected_idx)
+ assert_array_equal(a[idx], np.sort(a))
- x = np.vstack((b, a))
- idx = np.lexsort(x)
- assert_array_equal(idx, expected_idx)
+ def test_mixed(self):
+ a = np.array([1, 2, 1, 3, 1, 5])
+ b = np.array([0, 4, 5, 6, 2, 3], dtype='datetime64[D]')
- assert_array_equal(x[1][idx], np.sort(x[1]))
+ idx = np.lexsort((b, a))
+ expected_idx = np.array([0, 4, 2, 1, 3, 5])
+ assert_array_equal(idx, expected_idx)
def test_datetime(self):
a = np.array([0,0,0], dtype='datetime64[D]')