summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2014-05-14 21:37:16 +0300
committerPauli Virtanen <pav@iki.fi>2014-05-14 21:37:16 +0300
commitda40eba435d2b39c7d251f91d446ed411d71d3f5 (patch)
treeb3cb446379a6b8d959801583d0a3ab9d566b548e
parentafe32d77b3540f11ff564f851f1db5afd26857b3 (diff)
downloadnumpy-da40eba435d2b39c7d251f91d446ed411d71d3f5.tar.gz
BUG: core: allow scalar type richcompare to call ndarray-subclass methods
This makes the scalar richcmp behave exactly as "array(self) <op> other". The difference comes in when "other" is an ndarray subclass, in which case this change allows the subclass to override the comparison operations via usual Python binop dispatch rules.
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src6
-rw-r--r--numpy/core/tests/test_regression.py10
2 files changed, 15 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index a1c95995a..110bef248 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1082,7 +1082,11 @@ gentype_richcompare(PyObject *self, PyObject *other, int cmp_op)
if (arr == NULL) {
return NULL;
}
- ret = Py_TYPE(arr)->tp_richcompare(arr, other, cmp_op);
+ /*
+ * Call via PyObject_RichCompare to ensure that other.__eq__
+ * has a chance to run when necessary
+ */
+ ret = PyObject_RichCompare(arr, other, cmp_op);
Py_DECREF(arr);
return ret;
}
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index f3e34c109..9f40d7b54 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -1993,6 +1993,16 @@ class TestRegression(TestCase):
assert_(not op.eq(lhs, rhs))
assert_(op.ne(lhs, rhs))
+ def test_richcompare_scalar_and_subclass(self):
+ # gh-4709
+ class Foo(np.ndarray):
+ def __eq__(self, other):
+ return "OK"
+ x = np.array([1,2,3]).view(Foo)
+ assert_equal(10 == x, "OK")
+ assert_equal(np.int32(10) == x, "OK")
+ assert_equal(np.array([10]) == x, "OK")
+
if __name__ == "__main__":
run_module_suite()