summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/npysort/radixsort.cpp82
-rw-r--r--numpy/ma/tests/test_core.py4
2 files changed, 46 insertions, 40 deletions
diff --git a/numpy/core/src/npysort/radixsort.cpp b/numpy/core/src/npysort/radixsort.cpp
index 017ea43b6..5393869ee 100644
--- a/numpy/core/src/npysort/radixsort.cpp
+++ b/numpy/core/src/npysort/radixsort.cpp
@@ -14,9 +14,9 @@
*/
// Reference: https://github.com/eloj/radix-sorting#-key-derivation
-template <class T>
-T
-KEY_OF(T x)
+template <class T, class UT>
+UT
+KEY_OF(UT x)
{
// Floating-point is currently disabled.
// Floating-point tests succeed for double and float on macOS but not on
@@ -27,12 +27,12 @@ KEY_OF(T x)
// For floats, we invert the key if the sign bit is set, else we invert
// the sign bit.
return ((x) ^ (-((x) >> (sizeof(T) * 8 - 1)) |
- ((T)1 << (sizeof(T) * 8 - 1))));
+ ((UT)1 << (sizeof(T) * 8 - 1))));
}
else if (std::is_signed<T>::value) {
// For signed ints, we flip the sign bit so the negatives are below the
// positives.
- return ((x) ^ ((T)1 << (sizeof(T) * 8 - 1)));
+ return ((x) ^ ((UT)1 << (sizeof(UT) * 8 - 1)));
}
else {
return x;
@@ -46,24 +46,24 @@ nth_byte(T key, npy_intp l)
return (key >> (l << 3)) & 0xFF;
}
-template <class T>
-static T *
-radixsort0(T *start, T *aux, npy_intp num)
+template <class T, class UT>
+static UT *
+radixsort0(UT *start, UT *aux, npy_intp num)
{
- npy_intp cnt[sizeof(T)][1 << 8] = {{0}};
- T key0 = KEY_OF(start[0]);
+ npy_intp cnt[sizeof(UT)][1 << 8] = {{0}};
+ UT key0 = KEY_OF<T>(start[0]);
for (npy_intp i = 0; i < num; i++) {
- T k = KEY_OF(start[i]);
+ UT k = KEY_OF<T>(start[i]);
- for (size_t l = 0; l < sizeof(T); l++) {
+ for (size_t l = 0; l < sizeof(UT); l++) {
cnt[l][nth_byte(k, l)]++;
}
}
size_t ncols = 0;
- npy_ubyte cols[sizeof(T)];
- for (size_t l = 0; l < sizeof(T); l++) {
+ npy_ubyte cols[sizeof(UT)];
+ for (size_t l = 0; l < sizeof(UT); l++) {
if (cnt[l][nth_byte(key0, l)] != num) {
cols[ncols++] = l;
}
@@ -79,9 +79,9 @@ radixsort0(T *start, T *aux, npy_intp num)
}
for (size_t l = 0; l < ncols; l++) {
- T *temp;
+ UT *temp;
for (npy_intp i = 0; i < num; i++) {
- T k = KEY_OF(start[i]);
+ UT k = KEY_OF<T>(start[i]);
npy_intp dst = cnt[cols[l]][nth_byte(k, cols[l])]++;
aux[dst] = start[i];
}
@@ -94,18 +94,18 @@ radixsort0(T *start, T *aux, npy_intp num)
return start;
}
-template <class T>
+template <class T, class UT>
static int
-radixsort_(T *start, npy_intp num)
+radixsort_(UT *start, npy_intp num)
{
if (num < 2) {
return 0;
}
npy_bool all_sorted = 1;
- T k1 = KEY_OF(start[0]), k2;
+ UT k1 = KEY_OF<T>(start[0]);
for (npy_intp i = 1; i < num; i++) {
- k2 = KEY_OF(start[i]);
+ UT k2 = KEY_OF<T>(start[i]);
if (k1 > k2) {
all_sorted = 0;
break;
@@ -117,14 +117,14 @@ radixsort_(T *start, npy_intp num)
return 0;
}
- T *aux = (T *)malloc(num * sizeof(T));
+ UT *aux = (UT *)malloc(num * sizeof(UT));
if (aux == nullptr) {
return -NPY_ENOMEM;
}
- T *sorted = radixsort0(start, aux, num);
+ UT *sorted = radixsort0<T>(start, aux, num);
if (sorted != start) {
- memcpy(start, sorted, num * sizeof(T));
+ memcpy(start, sorted, num * sizeof(UT));
}
free(aux);
@@ -135,27 +135,28 @@ template <class T>
static int
radixsort(void *start, npy_intp num)
{
- return radixsort_((T *)start, num);
+ using UT = typename std::make_unsigned<T>::type;
+ return radixsort_<T>((UT *)start, num);
}
-template <class T>
+template <class T, class UT>
static npy_intp *
-aradixsort0(T *start, npy_intp *aux, npy_intp *tosort, npy_intp num)
+aradixsort0(UT *start, npy_intp *aux, npy_intp *tosort, npy_intp num)
{
- npy_intp cnt[sizeof(T)][1 << 8] = {{0}};
- T key0 = KEY_OF(start[0]);
+ npy_intp cnt[sizeof(UT)][1 << 8] = {{0}};
+ UT key0 = KEY_OF<T>(start[0]);
for (npy_intp i = 0; i < num; i++) {
- T k = KEY_OF(start[i]);
+ UT k = KEY_OF<T>(start[i]);
- for (size_t l = 0; l < sizeof(T); l++) {
+ for (size_t l = 0; l < sizeof(UT); l++) {
cnt[l][nth_byte(k, l)]++;
}
}
size_t ncols = 0;
- npy_ubyte cols[sizeof(T)];
- for (size_t l = 0; l < sizeof(T); l++) {
+ npy_ubyte cols[sizeof(UT)];
+ for (size_t l = 0; l < sizeof(UT); l++) {
if (cnt[l][nth_byte(key0, l)] != num) {
cols[ncols++] = l;
}
@@ -173,7 +174,7 @@ aradixsort0(T *start, npy_intp *aux, npy_intp *tosort, npy_intp num)
for (size_t l = 0; l < ncols; l++) {
npy_intp *temp;
for (npy_intp i = 0; i < num; i++) {
- T k = KEY_OF(start[tosort[i]]);
+ UT k = KEY_OF<T>(start[tosort[i]]);
npy_intp dst = cnt[cols[l]][nth_byte(k, cols[l])]++;
aux[dst] = tosort[i];
}
@@ -186,22 +187,22 @@ aradixsort0(T *start, npy_intp *aux, npy_intp *tosort, npy_intp num)
return tosort;
}
-template <class T>
+template <class T, class UT>
static int
-aradixsort_(T *start, npy_intp *tosort, npy_intp num)
+aradixsort_(UT *start, npy_intp *tosort, npy_intp num)
{
npy_intp *sorted;
npy_intp *aux;
- T k1, k2;
+ UT k1, k2;
npy_bool all_sorted = 1;
if (num < 2) {
return 0;
}
- k1 = KEY_OF(start[tosort[0]]);
+ k1 = KEY_OF<T>(start[tosort[0]]);
for (npy_intp i = 1; i < num; i++) {
- k2 = KEY_OF(start[tosort[i]]);
+ k2 = KEY_OF<T>(start[tosort[i]]);
if (k1 > k2) {
all_sorted = 0;
break;
@@ -218,7 +219,7 @@ aradixsort_(T *start, npy_intp *tosort, npy_intp num)
return -NPY_ENOMEM;
}
- sorted = aradixsort0(start, aux, tosort, num);
+ sorted = aradixsort0<T>(start, aux, tosort, num);
if (sorted != tosort) {
memcpy(tosort, sorted, num * sizeof(npy_intp));
}
@@ -231,7 +232,8 @@ template <class T>
static int
aradixsort(void *start, npy_intp *tosort, npy_intp num)
{
- return aradixsort_((T *)start, tosort, num);
+ using UT = typename std::make_unsigned<T>::type;
+ return aradixsort_<T>((UT *)start, tosort, num);
}
extern "C" {
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index bf95c999a..c8f7f4269 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -3422,6 +3422,10 @@ class TestMaskedArrayMethods:
assert_equal(sortedx._data, [1, 2, -2, -1, 0])
assert_equal(sortedx._mask, [1, 1, 0, 0, 0])
+ x = array([0, -1], dtype=np.int8)
+ sortedx = sort(x, kind="stable")
+ assert_equal(sortedx, array([-1, 0], dtype=np.int8))
+
def test_stable_sort(self):
x = array([1, 2, 3, 1, 2, 3], dtype=np.uint8)
expected = array([0, 3, 1, 4, 2, 5])