summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/_sortmodule.c.src29
-rw-r--r--numpy/core/tests/test_multiarray.py18
2 files changed, 44 insertions, 3 deletions
diff --git a/numpy/core/src/_sortmodule.c.src b/numpy/core/src/_sortmodule.c.src
index 924cfc0d5..288a0cc1a 100644
--- a/numpy/core/src/_sortmodule.c.src
+++ b/numpy/core/src/_sortmodule.c.src
@@ -45,7 +45,6 @@
*****************************************************************************
*/
-
/**begin repeat
*
* #TYPE = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
@@ -60,6 +59,7 @@ NPY_INLINE static int
}
/**end repeat**/
+
/**begin repeat
*
* #TYPE = FLOAT, DOUBLE, LONGDOUBLE#
@@ -68,10 +68,17 @@ NPY_INLINE static int
NPY_INLINE static int
@TYPE@_LT(@type@ a, @type@ b)
{
- return a < b;
+ return a < b || (b != b && a == a);
}
/**end repeat**/
+
+/*
+ * For inline functions SUN recommends not using a return in the then part
+ * of an if statement. It's a SUN compiler thing, so assign the return value
+ * to a variable instead.
+ */
+
/**begin repeat
*
* #TYPE = CFLOAT, CDOUBLE, CLONGDOUBLE#
@@ -80,10 +87,26 @@ NPY_INLINE static int
NPY_INLINE static int
@TYPE@_LT(@type@ a, @type@ b)
{
- return a.real < b.real || (a.real == b.real && a.imag < b.imag);
+ int ret;
+
+ if (a.real < b.real) {
+ ret = a.imag == a.imag || b.imag != b.imag;
+ }
+ else if (a.real > b.real) {
+ ret = b.imag != b.imag && a.imag == a.imag;
+ }
+ else if (a.real == b.real || (a.real != a.real && b.real != b.real)) {
+ ret = a.imag < b.imag || (b.imag != b.imag && a.imag == a.imag);
+ }
+ else {
+ ret = b.real != b.real;
+ }
+
+ return ret;
}
/**end repeat**/
+
/* The PyObject functions are stubs for later use */
NPY_INLINE static int
PyObject_LT(PyObject *pa, PyObject *pb)
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 7ee4e0c0e..ab221af4a 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -280,6 +280,24 @@ class TestMethods(TestCase):
self.failUnlessRaises(ValueError, lambda: a.transpose(0,1,2))
def test_sort(self):
+ # test ordering for floats and complex containing nans. It is only
+ # necessary to check the lessthan comparison, so sorts that
+ # only follow the insertion sort path are sufficient. We only
+ # test doubles and complex doubles as the logic is the same.
+
+ # check doubles
+ msg = "Test real sort order with nans"
+ a = np.array([np.nan, 1, 0])
+ b = sort(a)
+ assert_equal(b, a[::-1], msg)
+ # check complex
+ msg = "Test complex sort order with nans"
+ a = np.zeros(9, dtype=np.complex128)
+ a.real += [np.nan, np.nan,np. nan, 1, 0, 1, 1, 0, 0]
+ a.imag += [np.nan, 1, 0, np.nan, np.nan, 1, 0, 1, 0]
+ b = sort(a)
+ assert_equal(b, a[::-1], msg)
+
# all c scalar sorts use the same code with different types
# so it suffices to run a quick check with one type. The number
# of sorted items must be greater than ~50 to check the actual