summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src6
-rw-r--r--numpy/core/tests/test_regression.py10
2 files changed, 15 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index a1c95995a..110bef248 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1082,7 +1082,11 @@ gentype_richcompare(PyObject *self, PyObject *other, int cmp_op)
if (arr == NULL) {
return NULL;
}
- ret = Py_TYPE(arr)->tp_richcompare(arr, other, cmp_op);
+ /*
+ * Call via PyObject_RichCompare to ensure that other.__eq__
+ * has a chance to run when necessary
+ */
+ ret = PyObject_RichCompare(arr, other, cmp_op);
Py_DECREF(arr);
return ret;
}
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index f3e34c109..9f40d7b54 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -1993,6 +1993,16 @@ class TestRegression(TestCase):
assert_(not op.eq(lhs, rhs))
assert_(op.ne(lhs, rhs))
+ def test_richcompare_scalar_and_subclass(self):
+ # gh-4709
+ class Foo(np.ndarray):
+ def __eq__(self, other):
+ return "OK"
+ x = np.array([1,2,3]).view(Foo)
+ assert_equal(10 == x, "OK")
+ assert_equal(np.int32(10) == x, "OK")
+ assert_equal(np.array([10]) == x, "OK")
+
if __name__ == "__main__":
run_module_suite()