summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/arrayobject.c84
-rw-r--r--numpy/core/src/multiarray/na_object.c11
-rw-r--r--numpy/core/tests/test_na.py53
3 files changed, 85 insertions, 63 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index 90759dc90..f79bf1737 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -1019,7 +1019,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
int val;
/* Cast arrays to a common type */
- if (PyArray_DESCR(self)->type_num != PyArray_DESCR(other)->type_num) {
+ if (PyArray_TYPE(self) != PyArray_DESCR(other)->type_num) {
#if defined(NPY_PY3K)
/*
* Comparison between Bytes and Unicode is not defined in Py3K;
@@ -1029,7 +1029,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
return Py_NotImplemented;
#else
PyObject *new;
- if (PyArray_DESCR(self)->type_num == PyArray_STRING &&
+ if (PyArray_TYPE(self) == PyArray_STRING &&
PyArray_DESCR(other)->type_num == PyArray_UNICODE) {
PyArray_Descr* unicode = PyArray_DescrNew(PyArray_DESCR(other));
unicode->elsize = PyArray_DESCR(self)->elsize << 2;
@@ -1041,7 +1041,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
Py_INCREF(other);
self = (PyArrayObject *)new;
}
- else if (PyArray_DESCR(self)->type_num == PyArray_UNICODE &&
+ else if (PyArray_TYPE(self) == PyArray_UNICODE &&
PyArray_DESCR(other)->type_num == PyArray_STRING) {
PyArray_Descr* unicode = PyArray_DescrNew(PyArray_DESCR(self));
unicode->elsize = PyArray_DESCR(other)->elsize << 2;
@@ -1084,7 +1084,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
goto finish;
}
- if (PyArray_DESCR(self)->type_num == NPY_UNICODE) {
+ if (PyArray_TYPE(self) == NPY_UNICODE) {
val = _compare_strings(result, mit, cmp_op, _myunincmp, rstrip);
}
else {
@@ -1224,7 +1224,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
{
PyArrayObject *array_other;
PyObject *result = NULL;
- int typenum;
+ PyArray_Descr *dtype = NULL;
switch (cmp_op) {
case Py_LT:
@@ -1241,33 +1241,27 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
return Py_False;
}
/* Make sure 'other' is an array */
- if (PyArray_Check(other)) {
- Py_INCREF(other);
- array_other = (PyArrayObject *)other;
+ if (PyArray_TYPE(self) == NPY_OBJECT) {
+ dtype = PyArray_DTYPE(self);
+ Py_INCREF(dtype);
}
- else {
- typenum = PyArray_DESCR(self)->type_num;
- if (typenum != PyArray_OBJECT) {
- typenum = PyArray_NOTYPE;
- }
- array_other = (PyArrayObject *)PyArray_FromObject(other,
- typenum, 0, 0);
- /*
- * If not successful, indicate that the items cannot be compared
- * this way.
- */
- if (array_other == NULL) {
- Py_XDECREF(array_other);
- PyErr_Clear();
- Py_INCREF(Py_NotImplemented);
- return Py_NotImplemented;
- }
+ array_other = (PyArrayObject *)PyArray_FromAny(other, dtype, 0, 0,
+ NPY_ARRAY_ALLOWNA, NULL);
+ /*
+ * If not successful, indicate that the items cannot be compared
+ * this way.
+ */
+ if (array_other == NULL) {
+ PyErr_Clear();
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
}
+
result = PyArray_GenericBinaryFunction(self,
(PyObject *)array_other,
n_ops.equal);
if ((result == Py_NotImplemented) &&
- (PyArray_DESCR(self)->type_num == PyArray_VOID)) {
+ (PyArray_TYPE(self) == NPY_VOID)) {
int _res;
_res = PyObject_RichCompareBool
@@ -1304,32 +1298,26 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
return Py_True;
}
/* Make sure 'other' is an array */
- if (PyArray_Check(other)) {
- Py_INCREF(other);
- array_other = (PyArrayObject *)other;
+ if (PyArray_TYPE(self) == NPY_OBJECT) {
+ dtype = PyArray_DTYPE(self);
+ Py_INCREF(dtype);
}
- else {
- typenum = PyArray_DESCR(self)->type_num;
- if (typenum != PyArray_OBJECT) {
- typenum = PyArray_NOTYPE;
- }
- array_other = (PyArrayObject *)PyArray_FromObject(other,
- typenum, 0, 0);
- /*
- * If not successful, then objects cannot be
- * compared this way
- */
- if (array_other == NULL || (PyObject *)array_other == Py_None) {
- Py_XDECREF(array_other);
- PyErr_Clear();
- Py_INCREF(Py_NotImplemented);
- return Py_NotImplemented;
- }
+ array_other = (PyArrayObject *)PyArray_FromAny(other, dtype, 0, 0,
+ NPY_ARRAY_ALLOWNA, NULL);
+ /*
+ * If not successful, indicate that the items cannot be compared
+ * this way.
+ */
+ if (array_other == NULL) {
+ PyErr_Clear();
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
}
+
result = PyArray_GenericBinaryFunction(self, (PyObject *)array_other,
n_ops.not_equal);
if ((result == Py_NotImplemented) &&
- (PyArray_DESCR(self)->type_num == PyArray_VOID)) {
+ (PyArray_TYPE(self) == NPY_VOID)) {
int _res;
_res = PyObject_RichCompareBool(
@@ -1370,7 +1358,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
}
if (result == Py_NotImplemented) {
/* Try to handle string comparisons */
- if (PyArray_DESCR(self)->type_num == PyArray_OBJECT) {
+ if (PyArray_TYPE(self) == PyArray_OBJECT) {
return result;
}
array_other = (PyArrayObject *)PyArray_FromObject(other,
diff --git a/numpy/core/src/multiarray/na_object.c b/numpy/core/src/multiarray/na_object.c
index c356fb421..ed319a348 100644
--- a/numpy/core/src/multiarray/na_object.c
+++ b/numpy/core/src/multiarray/na_object.c
@@ -155,8 +155,15 @@ na_richcompare(NpyNA_fields *self, PyObject *other, int cmp_op)
}
/* Otherwise always return the NA singleton */
else {
- Py_INCREF(Npy_NA);
- return Npy_NA;
+ PyArray_Descr *bool_dtype;
+ NpyNA *ret;
+ bool_dtype = PyArray_DescrFromType(NPY_BOOL);
+ if (bool_dtype == NULL) {
+ return NULL;
+ }
+ ret = NpyNA_FromDTypeAndPayload(bool_dtype, 0, 0);
+ Py_DECREF(bool_dtype);
+ return (PyObject *)ret;
}
}
diff --git a/numpy/core/tests/test_na.py b/numpy/core/tests/test_na.py
index 188f67c25..2950598f8 100644
--- a/numpy/core/tests/test_na.py
+++ b/numpy/core/tests/test_na.py
@@ -78,21 +78,48 @@ def test_na_comparison():
# NA cannot be converted to a boolean
assert_raises(ValueError, bool, np.NA)
- # Comparison with different objects produces the singleton NA
- assert_((np.NA < 3) is np.NA)
- assert_((np.NA <= 3) is np.NA)
- assert_((np.NA == 3) is np.NA)
- assert_((np.NA != 3) is np.NA)
- assert_((np.NA >= 3) is np.NA)
- assert_((np.NA > 3) is np.NA)
+ # Comparison results should be np.NA(dtype='bool')
+ def check_comparison_result(res):
+ assert_(np.isna(res))
+ assert_(res.dtype == np.dtype('bool'))
+
+ # Comparison with different objects produces an NA with boolean type
+ check_comparison_result(np.NA < 3)
+ check_comparison_result(np.NA <= 3)
+ check_comparison_result(np.NA == 3)
+ check_comparison_result(np.NA != 3)
+ check_comparison_result(np.NA >= 3)
+ check_comparison_result(np.NA > 3)
# Should work with NA on the other side too
- assert_((3 < np.NA) is np.NA)
- assert_((3 <= np.NA) is np.NA)
- assert_((3 == np.NA) is np.NA)
- assert_((3 != np.NA) is np.NA)
- assert_((3 >= np.NA) is np.NA)
- assert_((3 > np.NA) is np.NA)
+ check_comparison_result(3 < np.NA)
+ check_comparison_result(3 <= np.NA)
+ check_comparison_result(3 == np.NA)
+ check_comparison_result(3 != np.NA)
+ check_comparison_result(3 >= np.NA)
+ check_comparison_result(3 > np.NA)
+
+ # Comparison with an array should produce an array
+ a = np.array([0,1,2]) < np.NA
+ assert_equal(np.isna(a), [1,1,1])
+ assert_equal(a.dtype, np.dtype('bool'))
+ a = np.array([0,1,2]) == np.NA
+ assert_equal(np.isna(a), [1,1,1])
+ assert_equal(a.dtype, np.dtype('bool'))
+ a = np.array([0,1,2]) != np.NA
+ assert_equal(np.isna(a), [1,1,1])
+ assert_equal(a.dtype, np.dtype('bool'))
+
+ # Comparison with an array should work on the other side too
+ a = np.NA > np.array([0,1,2])
+ assert_equal(np.isna(a), [1,1,1])
+ assert_equal(a.dtype, np.dtype('bool'))
+ a = np.NA == np.array([0,1,2])
+ assert_equal(np.isna(a), [1,1,1])
+ assert_equal(a.dtype, np.dtype('bool'))
+ a = np.NA != np.array([0,1,2])
+ assert_equal(np.isna(a), [1,1,1])
+ assert_equal(a.dtype, np.dtype('bool'))
def test_na_operations():
# The minimum of the payload is taken