diff options
author | Pauli Virtanen <pav@iki.fi> | 2014-05-14 21:37:16 +0300 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2014-05-14 21:37:16 +0300 |
commit | da40eba435d2b39c7d251f91d446ed411d71d3f5 (patch) | |
tree | b3cb446379a6b8d959801583d0a3ab9d566b548e | |
parent | afe32d77b3540f11ff564f851f1db5afd26857b3 (diff) | |
download | numpy-da40eba435d2b39c7d251f91d446ed411d71d3f5.tar.gz |
BUG: core: allow scalar type richcompare to call ndarray-subclass methods
This makes the scalar richcmp behave exactly as
"array(self) <op> other". The difference comes in when "other" is an ndarray
subclass, in which case this change allows the subclass to override the
comparison operations via usual Python binop dispatch rules.
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 6 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 10 |
2 files changed, 15 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index a1c95995a..110bef248 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1082,7 +1082,11 @@ gentype_richcompare(PyObject *self, PyObject *other, int cmp_op) if (arr == NULL) { return NULL; } - ret = Py_TYPE(arr)->tp_richcompare(arr, other, cmp_op); + /* + * Call via PyObject_RichCompare to ensure that other.__eq__ + * has a chance to run when necessary + */ + ret = PyObject_RichCompare(arr, other, cmp_op); Py_DECREF(arr); return ret; } diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index f3e34c109..9f40d7b54 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -1993,6 +1993,16 @@ class TestRegression(TestCase): assert_(not op.eq(lhs, rhs)) assert_(op.ne(lhs, rhs)) + def test_richcompare_scalar_and_subclass(self): + # gh-4709 + class Foo(np.ndarray): + def __eq__(self, other): + return "OK" + x = np.array([1,2,3]).view(Foo) + assert_equal(10 == x, "OK") + assert_equal(np.int32(10) == x, "OK") + assert_equal(np.array([10]) == x, "OK") + if __name__ == "__main__": run_module_suite() |