diff options
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 6 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 10 |
2 files changed, 15 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index a1c95995a..110bef248 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1082,7 +1082,11 @@ gentype_richcompare(PyObject *self, PyObject *other, int cmp_op) if (arr == NULL) { return NULL; } - ret = Py_TYPE(arr)->tp_richcompare(arr, other, cmp_op); + /* + * Call via PyObject_RichCompare to ensure that other.__eq__ + * has a chance to run when necessary + */ + ret = PyObject_RichCompare(arr, other, cmp_op); Py_DECREF(arr); return ret; } diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index f3e34c109..9f40d7b54 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -1993,6 +1993,16 @@ class TestRegression(TestCase): assert_(not op.eq(lhs, rhs)) assert_(op.ne(lhs, rhs)) + def test_richcompare_scalar_and_subclass(self): + # gh-4709 + class Foo(np.ndarray): + def __eq__(self, other): + return "OK" + x = np.array([1,2,3]).view(Foo) + assert_equal(10 == x, "OK") + assert_equal(np.int32(10) == x, "OK") + assert_equal(np.array([10]) == x, "OK") + if __name__ == "__main__": run_module_suite() |