diff options
-rw-r--r-- | numpy/core/src/multiarray/arraytypes.c.src | 99 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 29 |
2 files changed, 108 insertions, 20 deletions
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index b865e0832..88d04f099 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -1749,42 +1749,103 @@ VOID_nonzero (char *ip, PyArrayObject *ap) /****************** compare **********************************/ +/**begin repeat + * #TYPE = BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, + * LONGLONG, ULONGLONG# + * #type = byte, ubyte, short, ushort, int, uint, long, ulong, + * longlong, ulonglong# + */ + static int -BOOL_compare(Bool *ip1, Bool *ip2, PyArrayObject *NPY_UNUSED(ap)) +@TYPE@_compare (@type@ *pa, @type@ *pb, PyArrayObject *NPY_UNUSED(ap)) { - return (*ip1 ? (*ip2 ? 0 : 1) : (*ip2 ? -1 : 0)); + const @type@ a = *pa; + const @type@ b = *pb; + + return a < b ? -1 : a == b ? 0 : 1; } +/**end repeat**/ + + /**begin repeat -#fname=BYTE,UBYTE,SHORT,USHORT,INT,UINT,LONG,ULONG,LONGLONG,ULONGLONG,FLOAT,DOUBLE,LONGDOUBLE# -#type=byte, ubyte, short, ushort, int, uint, long, ulong, longlong, ulonglong, float, double, longdouble# -*/ + * #TYPE = FLOAT, DOUBLE, LONGDOUBLE# + * #type = float, double, longdouble# + */ static int -@fname@_compare (@type@ *ip1, @type@ *ip2, PyArrayObject *NPY_UNUSED(ap)) +@TYPE@_compare(@type@ *pa, @type@ *pb) { - return *ip1 < *ip2 ? -1 : *ip1 == *ip2 ? 0 : 1; -} + const @type@ a = *pa; + const @type@ b = *pb; + int ret; -/**end repeat**/ + if (a < b || (b != b && a == a)) { + ret = -1; + } + else if (a > b || (a != a && b == b)) { + ret = 1; + } + else { + ret = 0; + } + return ret; +} -/* compare imaginary part first, then complex if equal imaginary */ -/**begin repeat -#fname=CFLOAT, CDOUBLE, CLONGDOUBLE# -#type= float, double, longdouble# -*/ static int -@fname@_compare (@type@ *ip1, @type@ *ip2, PyArrayObject *NPY_UNUSED(ap)) +C@TYPE@_compare(@type@ *pa, @type@ *pb) { - if (*ip1 == *ip2) { - return ip1[1]<ip2[1] ? -1 : (ip1[1] == ip2[1] ? 0 : 1); + const @type@ ar = pa[0]; + const @type@ ai = pa[1]; + const @type@ br = pb[0]; + const @type@ bi = pb[1]; + int ret; + + if (ar < br) { + if (ai == ai || bi != bi) { + ret = -1; + } + else { + ret = 1; + } + } + else if (ar > br) { + if (bi != bi && ai == ai) { + ret = -1; + } + else { + ret = 1; + } + } + else if (ar == br || (ar != ar && br != br)) { + if (ai < bi || (bi != bi && ai == ai)) { + ret = -1; + } + else if (ai > bi || (ai != ai && bi == bi)) { + ret = 1; + } + else { + ret = 0; + } + } + else if (ar == ar) { + ret = -1; } else { - return *ip1 < *ip2 ? -1 : 1; + ret = 1; } + + return ret; +} +/**end repeat**/ + + +static int +BOOL_compare(Bool *ip1, Bool *ip2, PyArrayObject *NPY_UNUSED(ap)) +{ + return (*ip1 ? (*ip2 ? 0 : 1) : (*ip2 ? -1 : 0)); } - /**end repeat**/ static int OBJECT_compare(PyObject **ip1, PyObject **ip2, PyArrayObject *NPY_UNUSED(ap)) diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index ab221af4a..e5372dab8 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -293,7 +293,7 @@ class TestMethods(TestCase): # check complex msg = "Test complex sort order with nans" a = np.zeros(9, dtype=np.complex128) - a.real += [np.nan, np.nan,np. nan, 1, 0, 1, 1, 0, 0] + a.real += [np.nan, np.nan, np.nan, 1, 0, 1, 1, 0, 0] a.imag += [np.nan, 1, 0, np.nan, np.nan, 1, 0, 1, 0] b = sort(a) assert_equal(b, a[::-1], msg) @@ -485,6 +485,33 @@ class TestMethods(TestCase): a = np.array(['aaaaaaaaa' for i in range(100)], dtype=np.unicode) assert_equal(a.argsort(kind='m'), r) + def test_searchsorted(self): + # test for floats and complex containing nans. The logic is the + # same for all float types so only test double types for now. + # The search sorted routines use the compare functions for the + # array type, so this checks if that is consistent with the sort + # order. + + # check double + a = np.array([np.nan, 1, 0]) + a = np.array([0, 1, np.nan]) + msg = "Test real searchsorted with nans, side='l'" + b = a.searchsorted(a, side='l') + assert_equal(b, np.arange(3), msg) + msg = "Test real searchsorted with nans, side='r'" + b = a.searchsorted(a, side='r') + assert_equal(b, np.arange(1,4), msg) + # check double complex + a = np.zeros(9, dtype=np.complex128) + a.real += [0, 0, 1, 1, 0, 1, np.nan, np.nan, np.nan] + a.imag += [0, 1, 0, 1, np.nan, np.nan, 0, 1, np.nan] + msg = "Test complex searchsorted with nans, side='l'" + b = a.searchsorted(a, side='l') + assert_equal(b, np.arange(9), msg) + msg = "Test complex searchsorted with nans, side='r'" + b = a.searchsorted(a, side='r') + assert_equal(b, np.arange(1,10), msg) + def test_flatten(self): x0 = np.array([[1,2,3],[4,5,6]], np.int32) x1 = np.array([[[1,2],[3,4]],[[5,6],[7,8]]], np.int32) |