diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/umath/loops.c.src | 11 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 11 |
2 files changed, 21 insertions, 1 deletions
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src index ee7e7652d..c1e2dc1c5 100644 --- a/numpy/core/src/umath/loops.c.src +++ b/numpy/core/src/umath/loops.c.src @@ -2568,9 +2568,18 @@ OBJECT_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUS BINARY_LOOP { PyObject *in1 = *(PyObject **)ip1; PyObject *in2 = *(PyObject **)ip2; - int ret = PyObject_RichCompareBool( + /* + * Do not use RichCompareBool because it includes an identity check. + * But this is wrong for elementwise behaviour, since i.e. it says + * that NaN can be equal to NaN, and an array is equal to itself. + */ + PyObject *ret_obj = PyObject_RichCompare( in1 ? in1 : Py_None, in2 ? in2 : Py_None, Py_@OP@); + if (ret_obj == NULL) { + return; + } + int ret = PyObject_IsTrue(ret_obj); if (ret == -1) { return; } diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index 3646fd2a9..b7b2f6308 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -1534,5 +1534,16 @@ def test_complex_nan_comparisons(): assert_equal(x == y, False, err_msg="%r == %r" % (x, y)) +def test_object_array_comparison(): + obj_array = np.arange(3) + a = np.array([obj_array, 1]) + # Should raise an error because (obj_array == obj_array) is not a bool. + # At this time, a == a would return False because of the error. + assert_raises(ValueError, np.equal, a, a) + + # Check the same for NaN, when it is the *same* NaN. + a = np.array([np.nan]) + assert_equal(a == a, [False]) + if __name__ == "__main__": run_module_suite() |