diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2013-07-10 16:49:59 +0200 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2014-05-04 18:14:11 +0200 |
commit | 84831ca7b7926bf1c73e1702201e7591c55588a3 (patch) | |
tree | b9ed73a6c93178ddd7826ae82bd550e9891560dc /numpy/core | |
parent | 19feb8849392a11fad6474f8bab5cd28e50952f1 (diff) | |
download | numpy-84831ca7b7926bf1c73e1702201e7591c55588a3.tar.gz |
FIX: Make object comparison ufuncs not include identity check
The issue here is that PyObject_RichCompareBool does an identity
check before doing any other checks. This is good, and is the
way for example list comparisons are handled. However in numpy
comparisons are elementwise, so that the identitycheck should
not be expected. The example for this is as follows:
obj_array = np.arange(3)
a = np.array([obj_array, 0], dtype=object)
np.equal(a, a)
If an identity check is done, this returns [True, True]. But
obj_array == obj_array itself cannot be cast to a bool.
While this is slightly slower, not doing the identity check seems
more logical in the light of elementwise operations.
Closes gh-2117.
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/umath/loops.c.src | 11 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 11 |
2 files changed, 21 insertions, 1 deletions
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src index ee7e7652d..c1e2dc1c5 100644 --- a/numpy/core/src/umath/loops.c.src +++ b/numpy/core/src/umath/loops.c.src @@ -2568,9 +2568,18 @@ OBJECT_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUS BINARY_LOOP { PyObject *in1 = *(PyObject **)ip1; PyObject *in2 = *(PyObject **)ip2; - int ret = PyObject_RichCompareBool( + /* + * Do not use RichCompareBool because it includes an identity check. + * But this is wrong for elementwise behaviour, since i.e. it says + * that NaN can be equal to NaN, and an array is equal to itself. + */ + PyObject *ret_obj = PyObject_RichCompare( in1 ? in1 : Py_None, in2 ? in2 : Py_None, Py_@OP@); + if (ret_obj == NULL) { + return; + } + int ret = PyObject_IsTrue(ret_obj); if (ret == -1) { return; } diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index 3646fd2a9..b7b2f6308 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -1534,5 +1534,16 @@ def test_complex_nan_comparisons(): assert_equal(x == y, False, err_msg="%r == %r" % (x, y)) +def test_object_array_comparison(): + obj_array = np.arange(3) + a = np.array([obj_array, 1]) + # Should raise an error because (obj_array == obj_array) is not a bool. + # At this time, a == a would return False because of the error. + assert_raises(ValueError, np.equal, a, a) + + # Check the same for NaN, when it is the *same* NaN. + a = np.array([np.nan]) + assert_equal(a == a, [False]) + if __name__ == "__main__": run_module_suite() |