diff options
author | Mark Wiebe <mwwiebe@gmail.com> | 2011-08-26 09:53:13 -0700 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2011-08-27 07:27:02 -0600 |
commit | 99774bee1e138bbe2c52a4cc6f54975279f436e3 (patch) | |
tree | f9575717b70032c679eef28701372bb63ddb9377 | |
parent | 5c7b9bb8fec7d2abbb13ad56be23afc49296df2f (diff) | |
download | numpy-99774bee1e138bbe2c52a4cc6f54975279f436e3.tar.gz |
ENH: missingdata: Make comparisons with NA return NA(dtype='bool')
-rw-r--r-- | numpy/core/src/multiarray/arrayobject.c | 84 | ||||
-rw-r--r-- | numpy/core/src/multiarray/na_object.c | 11 | ||||
-rw-r--r-- | numpy/core/tests/test_na.py | 53 |
3 files changed, 85 insertions, 63 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c index 90759dc90..f79bf1737 100644 --- a/numpy/core/src/multiarray/arrayobject.c +++ b/numpy/core/src/multiarray/arrayobject.c @@ -1019,7 +1019,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op, int val; /* Cast arrays to a common type */ - if (PyArray_DESCR(self)->type_num != PyArray_DESCR(other)->type_num) { + if (PyArray_TYPE(self) != PyArray_DESCR(other)->type_num) { #if defined(NPY_PY3K) /* * Comparison between Bytes and Unicode is not defined in Py3K; @@ -1029,7 +1029,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op, return Py_NotImplemented; #else PyObject *new; - if (PyArray_DESCR(self)->type_num == PyArray_STRING && + if (PyArray_TYPE(self) == PyArray_STRING && PyArray_DESCR(other)->type_num == PyArray_UNICODE) { PyArray_Descr* unicode = PyArray_DescrNew(PyArray_DESCR(other)); unicode->elsize = PyArray_DESCR(self)->elsize << 2; @@ -1041,7 +1041,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op, Py_INCREF(other); self = (PyArrayObject *)new; } - else if (PyArray_DESCR(self)->type_num == PyArray_UNICODE && + else if (PyArray_TYPE(self) == PyArray_UNICODE && PyArray_DESCR(other)->type_num == PyArray_STRING) { PyArray_Descr* unicode = PyArray_DescrNew(PyArray_DESCR(self)); unicode->elsize = PyArray_DESCR(other)->elsize << 2; @@ -1084,7 +1084,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op, goto finish; } - if (PyArray_DESCR(self)->type_num == NPY_UNICODE) { + if (PyArray_TYPE(self) == NPY_UNICODE) { val = _compare_strings(result, mit, cmp_op, _myunincmp, rstrip); } else { @@ -1224,7 +1224,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) { PyArrayObject *array_other; PyObject *result = NULL; - int typenum; + PyArray_Descr *dtype = NULL; switch (cmp_op) { case Py_LT: @@ -1241,33 +1241,27 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) return Py_False; } /* Make sure 'other' is an array */ - if (PyArray_Check(other)) { - Py_INCREF(other); - array_other = (PyArrayObject *)other; + if (PyArray_TYPE(self) == NPY_OBJECT) { + dtype = PyArray_DTYPE(self); + Py_INCREF(dtype); } - else { - typenum = PyArray_DESCR(self)->type_num; - if (typenum != PyArray_OBJECT) { - typenum = PyArray_NOTYPE; - } - array_other = (PyArrayObject *)PyArray_FromObject(other, - typenum, 0, 0); - /* - * If not successful, indicate that the items cannot be compared - * this way. - */ - if (array_other == NULL) { - Py_XDECREF(array_other); - PyErr_Clear(); - Py_INCREF(Py_NotImplemented); - return Py_NotImplemented; - } + array_other = (PyArrayObject *)PyArray_FromAny(other, dtype, 0, 0, + NPY_ARRAY_ALLOWNA, NULL); + /* + * If not successful, indicate that the items cannot be compared + * this way. + */ + if (array_other == NULL) { + PyErr_Clear(); + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; } + result = PyArray_GenericBinaryFunction(self, (PyObject *)array_other, n_ops.equal); if ((result == Py_NotImplemented) && - (PyArray_DESCR(self)->type_num == PyArray_VOID)) { + (PyArray_TYPE(self) == NPY_VOID)) { int _res; _res = PyObject_RichCompareBool @@ -1304,32 +1298,26 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) return Py_True; } /* Make sure 'other' is an array */ - if (PyArray_Check(other)) { - Py_INCREF(other); - array_other = (PyArrayObject *)other; + if (PyArray_TYPE(self) == NPY_OBJECT) { + dtype = PyArray_DTYPE(self); + Py_INCREF(dtype); } - else { - typenum = PyArray_DESCR(self)->type_num; - if (typenum != PyArray_OBJECT) { - typenum = PyArray_NOTYPE; - } - array_other = (PyArrayObject *)PyArray_FromObject(other, - typenum, 0, 0); - /* - * If not successful, then objects cannot be - * compared this way - */ - if (array_other == NULL || (PyObject *)array_other == Py_None) { - Py_XDECREF(array_other); - PyErr_Clear(); - Py_INCREF(Py_NotImplemented); - return Py_NotImplemented; - } + array_other = (PyArrayObject *)PyArray_FromAny(other, dtype, 0, 0, + NPY_ARRAY_ALLOWNA, NULL); + /* + * If not successful, indicate that the items cannot be compared + * this way. + */ + if (array_other == NULL) { + PyErr_Clear(); + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; } + result = PyArray_GenericBinaryFunction(self, (PyObject *)array_other, n_ops.not_equal); if ((result == Py_NotImplemented) && - (PyArray_DESCR(self)->type_num == PyArray_VOID)) { + (PyArray_TYPE(self) == NPY_VOID)) { int _res; _res = PyObject_RichCompareBool( @@ -1370,7 +1358,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) } if (result == Py_NotImplemented) { /* Try to handle string comparisons */ - if (PyArray_DESCR(self)->type_num == PyArray_OBJECT) { + if (PyArray_TYPE(self) == PyArray_OBJECT) { return result; } array_other = (PyArrayObject *)PyArray_FromObject(other, diff --git a/numpy/core/src/multiarray/na_object.c b/numpy/core/src/multiarray/na_object.c index c356fb421..ed319a348 100644 --- a/numpy/core/src/multiarray/na_object.c +++ b/numpy/core/src/multiarray/na_object.c @@ -155,8 +155,15 @@ na_richcompare(NpyNA_fields *self, PyObject *other, int cmp_op) } /* Otherwise always return the NA singleton */ else { - Py_INCREF(Npy_NA); - return Npy_NA; + PyArray_Descr *bool_dtype; + NpyNA *ret; + bool_dtype = PyArray_DescrFromType(NPY_BOOL); + if (bool_dtype == NULL) { + return NULL; + } + ret = NpyNA_FromDTypeAndPayload(bool_dtype, 0, 0); + Py_DECREF(bool_dtype); + return (PyObject *)ret; } } diff --git a/numpy/core/tests/test_na.py b/numpy/core/tests/test_na.py index 188f67c25..2950598f8 100644 --- a/numpy/core/tests/test_na.py +++ b/numpy/core/tests/test_na.py @@ -78,21 +78,48 @@ def test_na_comparison(): # NA cannot be converted to a boolean assert_raises(ValueError, bool, np.NA) - # Comparison with different objects produces the singleton NA - assert_((np.NA < 3) is np.NA) - assert_((np.NA <= 3) is np.NA) - assert_((np.NA == 3) is np.NA) - assert_((np.NA != 3) is np.NA) - assert_((np.NA >= 3) is np.NA) - assert_((np.NA > 3) is np.NA) + # Comparison results should be np.NA(dtype='bool') + def check_comparison_result(res): + assert_(np.isna(res)) + assert_(res.dtype == np.dtype('bool')) + + # Comparison with different objects produces an NA with boolean type + check_comparison_result(np.NA < 3) + check_comparison_result(np.NA <= 3) + check_comparison_result(np.NA == 3) + check_comparison_result(np.NA != 3) + check_comparison_result(np.NA >= 3) + check_comparison_result(np.NA > 3) # Should work with NA on the other side too - assert_((3 < np.NA) is np.NA) - assert_((3 <= np.NA) is np.NA) - assert_((3 == np.NA) is np.NA) - assert_((3 != np.NA) is np.NA) - assert_((3 >= np.NA) is np.NA) - assert_((3 > np.NA) is np.NA) + check_comparison_result(3 < np.NA) + check_comparison_result(3 <= np.NA) + check_comparison_result(3 == np.NA) + check_comparison_result(3 != np.NA) + check_comparison_result(3 >= np.NA) + check_comparison_result(3 > np.NA) + + # Comparison with an array should produce an array + a = np.array([0,1,2]) < np.NA + assert_equal(np.isna(a), [1,1,1]) + assert_equal(a.dtype, np.dtype('bool')) + a = np.array([0,1,2]) == np.NA + assert_equal(np.isna(a), [1,1,1]) + assert_equal(a.dtype, np.dtype('bool')) + a = np.array([0,1,2]) != np.NA + assert_equal(np.isna(a), [1,1,1]) + assert_equal(a.dtype, np.dtype('bool')) + + # Comparison with an array should work on the other side too + a = np.NA > np.array([0,1,2]) + assert_equal(np.isna(a), [1,1,1]) + assert_equal(a.dtype, np.dtype('bool')) + a = np.NA == np.array([0,1,2]) + assert_equal(np.isna(a), [1,1,1]) + assert_equal(a.dtype, np.dtype('bool')) + a = np.NA != np.array([0,1,2]) + assert_equal(np.isna(a), [1,1,1]) + assert_equal(a.dtype, np.dtype('bool')) def test_na_operations(): # The minimum of the payload is taken |