diff options
-rw-r--r-- | numpy/core/src/_sortmodule.c.src | 29 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 18 |
2 files changed, 44 insertions, 3 deletions
diff --git a/numpy/core/src/_sortmodule.c.src b/numpy/core/src/_sortmodule.c.src index 924cfc0d5..288a0cc1a 100644 --- a/numpy/core/src/_sortmodule.c.src +++ b/numpy/core/src/_sortmodule.c.src @@ -45,7 +45,6 @@ ***************************************************************************** */ - /**begin repeat * * #TYPE = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, @@ -60,6 +59,7 @@ NPY_INLINE static int } /**end repeat**/ + /**begin repeat * * #TYPE = FLOAT, DOUBLE, LONGDOUBLE# @@ -68,10 +68,17 @@ NPY_INLINE static int NPY_INLINE static int @TYPE@_LT(@type@ a, @type@ b) { - return a < b; + return a < b || (b != b && a == a); } /**end repeat**/ + +/* + * For inline functions SUN recommends not using a return in the then part + * of an if statement. It's a SUN compiler thing, so assign the return value + * to a variable instead. + */ + /**begin repeat * * #TYPE = CFLOAT, CDOUBLE, CLONGDOUBLE# @@ -80,10 +87,26 @@ NPY_INLINE static int NPY_INLINE static int @TYPE@_LT(@type@ a, @type@ b) { - return a.real < b.real || (a.real == b.real && a.imag < b.imag); + int ret; + + if (a.real < b.real) { + ret = a.imag == a.imag || b.imag != b.imag; + } + else if (a.real > b.real) { + ret = b.imag != b.imag && a.imag == a.imag; + } + else if (a.real == b.real || (a.real != a.real && b.real != b.real)) { + ret = a.imag < b.imag || (b.imag != b.imag && a.imag == a.imag); + } + else { + ret = b.real != b.real; + } + + return ret; } /**end repeat**/ + /* The PyObject functions are stubs for later use */ NPY_INLINE static int PyObject_LT(PyObject *pa, PyObject *pb) diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 7ee4e0c0e..ab221af4a 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -280,6 +280,24 @@ class TestMethods(TestCase): self.failUnlessRaises(ValueError, lambda: a.transpose(0,1,2)) def test_sort(self): + # test ordering for floats and complex containing nans. It is only + # necessary to check the lessthan comparison, so sorts that + # only follow the insertion sort path are sufficient. We only + # test doubles and complex doubles as the logic is the same. + + # check doubles + msg = "Test real sort order with nans" + a = np.array([np.nan, 1, 0]) + b = sort(a) + assert_equal(b, a[::-1], msg) + # check complex + msg = "Test complex sort order with nans" + a = np.zeros(9, dtype=np.complex128) + a.real += [np.nan, np.nan,np. nan, 1, 0, 1, 1, 0, 0] + a.imag += [np.nan, 1, 0, np.nan, np.nan, 1, 0, 1, 0] + b = sort(a) + assert_equal(b, a[::-1], msg) + # all c scalar sorts use the same code with different types # so it suffices to run a quick check with one type. The number # of sorted items must be greater than ~50 to check the actual |