summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2013-07-10 16:49:59 +0200
committerSebastian Berg <sebastian@sipsolutions.net>2014-05-04 18:14:11 +0200
commit84831ca7b7926bf1c73e1702201e7591c55588a3 (patch)
treeb9ed73a6c93178ddd7826ae82bd550e9891560dc /numpy/core
parent19feb8849392a11fad6474f8bab5cd28e50952f1 (diff)
downloadnumpy-84831ca7b7926bf1c73e1702201e7591c55588a3.tar.gz
FIX: Make object comparison ufuncs not include identity check
The issue here is that PyObject_RichCompareBool does an identity check before doing any other checks. This is good, and is the way for example list comparisons are handled. However in numpy comparisons are elementwise, so that the identitycheck should not be expected. The example for this is as follows: obj_array = np.arange(3) a = np.array([obj_array, 0], dtype=object) np.equal(a, a) If an identity check is done, this returns [True, True]. But obj_array == obj_array itself cannot be cast to a bool. While this is slightly slower, not doing the identity check seems more logical in the light of elementwise operations. Closes gh-2117.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/umath/loops.c.src11
-rw-r--r--numpy/core/tests/test_umath.py11
2 files changed, 21 insertions, 1 deletions
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index ee7e7652d..c1e2dc1c5 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -2568,9 +2568,18 @@ OBJECT_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUS
BINARY_LOOP {
PyObject *in1 = *(PyObject **)ip1;
PyObject *in2 = *(PyObject **)ip2;
- int ret = PyObject_RichCompareBool(
+ /*
+ * Do not use RichCompareBool because it includes an identity check.
+ * But this is wrong for elementwise behaviour, since i.e. it says
+ * that NaN can be equal to NaN, and an array is equal to itself.
+ */
+ PyObject *ret_obj = PyObject_RichCompare(
in1 ? in1 : Py_None,
in2 ? in2 : Py_None, Py_@OP@);
+ if (ret_obj == NULL) {
+ return;
+ }
+ int ret = PyObject_IsTrue(ret_obj);
if (ret == -1) {
return;
}
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 3646fd2a9..b7b2f6308 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -1534,5 +1534,16 @@ def test_complex_nan_comparisons():
assert_equal(x == y, False, err_msg="%r == %r" % (x, y))
+def test_object_array_comparison():
+ obj_array = np.arange(3)
+ a = np.array([obj_array, 1])
+ # Should raise an error because (obj_array == obj_array) is not a bool.
+ # At this time, a == a would return False because of the error.
+ assert_raises(ValueError, np.equal, a, a)
+
+ # Check the same for NaN, when it is the *same* NaN.
+ a = np.array([np.nan])
+ assert_equal(a == a, [False])
+
if __name__ == "__main__":
run_module_suite()