diff options
author | Hameer Abbasi <hameerabbasi@yahoo.com> | 2019-09-04 10:26:10 +0200 |
---|---|---|
committer | Hameer Abbasi <hameerabbasi@yahoo.com> | 2019-09-04 10:44:46 +0200 |
commit | 003a3d4c5056d698e62f24c392c101bf687e00bd (patch) | |
tree | ac48c7060a7a5c5f8ce82ece7fc3c2204c81fded /numpy | |
parent | e3f4c536a0014789dbd0321926b4f62c39d73719 (diff) | |
download | numpy-003a3d4c5056d698e62f24c392c101bf687e00bd.tar.gz |
BUG: Fix lexsort + radix sort.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/npysort/radixsort.c.src | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 22 |
2 files changed, 17 insertions, 9 deletions
diff --git a/numpy/core/src/npysort/radixsort.c.src b/numpy/core/src/npysort/radixsort.c.src index c90b06974..72887d7e4 100644 --- a/numpy/core/src/npysort/radixsort.c.src +++ b/numpy/core/src/npysort/radixsort.c.src @@ -198,9 +198,9 @@ aradixsort_@suff@(void *start, npy_intp* tosort, npy_intp num, void *NPY_UNUSED( return 0; } - k1 = KEY_OF(arr[0]); + k1 = KEY_OF(arr[tosort[0]]); for (npy_intp i = 1; i < num; i++) { - k2 = KEY_OF(arr[i]); + k2 = KEY_OF(arr[tosort[i]]); if (k1 > k2) { all_sorted = 0; break; diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 2593045ed..1b698b517 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -4587,18 +4587,26 @@ class TestTake(object): assert_equal(y, np.array([1, 2, 3])) class TestLexsort(object): - def test_basic(self): - a = [1, 2, 1, 3, 1, 5] - b = [0, 4, 5, 6, 2, 3] + @pytest.mark.parametrize('dtype',[ + np.uint8, np.uint16, np.uint32, np.uint64, + np.int8, np.int16, np.int32, np.int64, + np.float16, np.float32, np.float64 + ]) + def test_basic(self, dtype): + a = np.array([1, 2, 1, 3, 1, 5], dtype=dtype) + b = np.array([0, 4, 5, 6, 2, 3], dtype=dtype) idx = np.lexsort((b, a)) expected_idx = np.array([0, 4, 2, 1, 3, 5]) assert_array_equal(idx, expected_idx) + assert_array_equal(a[idx], np.sort(a)) - x = np.vstack((b, a)) - idx = np.lexsort(x) - assert_array_equal(idx, expected_idx) + def test_mixed(self): + a = np.array([1, 2, 1, 3, 1, 5]) + b = np.array([0, 4, 5, 6, 2, 3], dtype='datetime64[D]') - assert_array_equal(x[1][idx], np.sort(x[1])) + idx = np.lexsort((b, a)) + expected_idx = np.array([0, 4, 2, 1, 3, 5]) + assert_array_equal(idx, expected_idx) def test_datetime(self): a = np.array([0,0,0], dtype='datetime64[D]') |