summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2009-08-02 03:33:03 +0000
committerCharles Harris <charlesr.harris@gmail.com>2009-08-02 03:33:03 +0000
commit6d91277b97e2dd0fbad849ae38b9ccc26aa17eed (patch)
treeb0ad4c5388bc2cf274ad718397edb03fc76c4de3
parent15f9f6911af9b3fd9a28e09d344d40bb721e60b9 (diff)
downloadnumpy-6d91277b97e2dd0fbad849ae38b9ccc26aa17eed.tar.gz
Change ndarray type comparison to reflect the sort order with nans.
Searchsorted should now work with arrays containing nans.
-rw-r--r--numpy/core/src/multiarray/arraytypes.c.src99
-rw-r--r--numpy/core/tests/test_multiarray.py29
2 files changed, 108 insertions, 20 deletions
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src
index b865e0832..88d04f099 100644
--- a/numpy/core/src/multiarray/arraytypes.c.src
+++ b/numpy/core/src/multiarray/arraytypes.c.src
@@ -1749,42 +1749,103 @@ VOID_nonzero (char *ip, PyArrayObject *ap)
/****************** compare **********************************/
+/**begin repeat
+ * #TYPE = BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
+ * LONGLONG, ULONGLONG#
+ * #type = byte, ubyte, short, ushort, int, uint, long, ulong,
+ * longlong, ulonglong#
+ */
+
static int
-BOOL_compare(Bool *ip1, Bool *ip2, PyArrayObject *NPY_UNUSED(ap))
+@TYPE@_compare (@type@ *pa, @type@ *pb, PyArrayObject *NPY_UNUSED(ap))
{
- return (*ip1 ? (*ip2 ? 0 : 1) : (*ip2 ? -1 : 0));
+ const @type@ a = *pa;
+ const @type@ b = *pb;
+
+ return a < b ? -1 : a == b ? 0 : 1;
}
+/**end repeat**/
+
+
/**begin repeat
-#fname=BYTE,UBYTE,SHORT,USHORT,INT,UINT,LONG,ULONG,LONGLONG,ULONGLONG,FLOAT,DOUBLE,LONGDOUBLE#
-#type=byte, ubyte, short, ushort, int, uint, long, ulong, longlong, ulonglong, float, double, longdouble#
-*/
+ * #TYPE = FLOAT, DOUBLE, LONGDOUBLE#
+ * #type = float, double, longdouble#
+ */
static int
-@fname@_compare (@type@ *ip1, @type@ *ip2, PyArrayObject *NPY_UNUSED(ap))
+@TYPE@_compare(@type@ *pa, @type@ *pb)
{
- return *ip1 < *ip2 ? -1 : *ip1 == *ip2 ? 0 : 1;
-}
+ const @type@ a = *pa;
+ const @type@ b = *pb;
+ int ret;
-/**end repeat**/
+ if (a < b || (b != b && a == a)) {
+ ret = -1;
+ }
+ else if (a > b || (a != a && b == b)) {
+ ret = 1;
+ }
+ else {
+ ret = 0;
+ }
+ return ret;
+}
-/* compare imaginary part first, then complex if equal imaginary */
-/**begin repeat
-#fname=CFLOAT, CDOUBLE, CLONGDOUBLE#
-#type= float, double, longdouble#
-*/
static int
-@fname@_compare (@type@ *ip1, @type@ *ip2, PyArrayObject *NPY_UNUSED(ap))
+C@TYPE@_compare(@type@ *pa, @type@ *pb)
{
- if (*ip1 == *ip2) {
- return ip1[1]<ip2[1] ? -1 : (ip1[1] == ip2[1] ? 0 : 1);
+ const @type@ ar = pa[0];
+ const @type@ ai = pa[1];
+ const @type@ br = pb[0];
+ const @type@ bi = pb[1];
+ int ret;
+
+ if (ar < br) {
+ if (ai == ai || bi != bi) {
+ ret = -1;
+ }
+ else {
+ ret = 1;
+ }
+ }
+ else if (ar > br) {
+ if (bi != bi && ai == ai) {
+ ret = -1;
+ }
+ else {
+ ret = 1;
+ }
+ }
+ else if (ar == br || (ar != ar && br != br)) {
+ if (ai < bi || (bi != bi && ai == ai)) {
+ ret = -1;
+ }
+ else if (ai > bi || (ai != ai && bi == bi)) {
+ ret = 1;
+ }
+ else {
+ ret = 0;
+ }
+ }
+ else if (ar == ar) {
+ ret = -1;
}
else {
- return *ip1 < *ip2 ? -1 : 1;
+ ret = 1;
}
+
+ return ret;
+}
+/**end repeat**/
+
+
+static int
+BOOL_compare(Bool *ip1, Bool *ip2, PyArrayObject *NPY_UNUSED(ap))
+{
+ return (*ip1 ? (*ip2 ? 0 : 1) : (*ip2 ? -1 : 0));
}
- /**end repeat**/
static int
OBJECT_compare(PyObject **ip1, PyObject **ip2, PyArrayObject *NPY_UNUSED(ap))
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index ab221af4a..e5372dab8 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -293,7 +293,7 @@ class TestMethods(TestCase):
# check complex
msg = "Test complex sort order with nans"
a = np.zeros(9, dtype=np.complex128)
- a.real += [np.nan, np.nan,np. nan, 1, 0, 1, 1, 0, 0]
+ a.real += [np.nan, np.nan, np.nan, 1, 0, 1, 1, 0, 0]
a.imag += [np.nan, 1, 0, np.nan, np.nan, 1, 0, 1, 0]
b = sort(a)
assert_equal(b, a[::-1], msg)
@@ -485,6 +485,33 @@ class TestMethods(TestCase):
a = np.array(['aaaaaaaaa' for i in range(100)], dtype=np.unicode)
assert_equal(a.argsort(kind='m'), r)
+ def test_searchsorted(self):
+ # test for floats and complex containing nans. The logic is the
+ # same for all float types so only test double types for now.
+ # The search sorted routines use the compare functions for the
+ # array type, so this checks if that is consistent with the sort
+ # order.
+
+ # check double
+ a = np.array([np.nan, 1, 0])
+ a = np.array([0, 1, np.nan])
+ msg = "Test real searchsorted with nans, side='l'"
+ b = a.searchsorted(a, side='l')
+ assert_equal(b, np.arange(3), msg)
+ msg = "Test real searchsorted with nans, side='r'"
+ b = a.searchsorted(a, side='r')
+ assert_equal(b, np.arange(1,4), msg)
+ # check double complex
+ a = np.zeros(9, dtype=np.complex128)
+ a.real += [0, 0, 1, 1, 0, 1, np.nan, np.nan, np.nan]
+ a.imag += [0, 1, 0, 1, np.nan, np.nan, 0, 1, np.nan]
+ msg = "Test complex searchsorted with nans, side='l'"
+ b = a.searchsorted(a, side='l')
+ assert_equal(b, np.arange(9), msg)
+ msg = "Test complex searchsorted with nans, side='r'"
+ b = a.searchsorted(a, side='r')
+ assert_equal(b, np.arange(1,10), msg)
+
def test_flatten(self):
x0 = np.array([[1,2,3],[4,5,6]], np.int32)
x1 = np.array([[[1,2],[3,4]],[[5,6],[7,8]]], np.int32)