diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2015-05-05 22:39:13 -0400 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2015-05-05 22:39:13 -0400 |
commit | eecb2e3c07f29c0ac991d364a846a2f8293a432a (patch) | |
tree | 0f87e1b9da37ce65f2c7aa4366f10bbbca82e558 | |
parent | a29c296f0bf306a8824705567cbb4271f383f0ea (diff) | |
parent | d1153b4cba5510bd1961e3a7ae7f0c07c124a202 (diff) | |
download | numpy-eecb2e3c07f29c0ac991d364a846a2f8293a432a.tar.gz |
Merge pull request #5748 from mhvk/numpyufunc-binop-fix
BUG Numpy ufunc binop fix for subclasses with __numpy_ufunc__ (closes #4815)
-rw-r--r-- | numpy/core/src/multiarray/arrayobject.c | 18 | ||||
-rw-r--r-- | numpy/core/src/multiarray/number.c | 81 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 47 |
3 files changed, 103 insertions, 43 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c index 6e48ef381..1a5c30832 100644 --- a/numpy/core/src/multiarray/arrayobject.c +++ b/numpy/core/src/multiarray/arrayobject.c @@ -1297,7 +1297,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) switch (cmp_op) { case Py_LT: - if (needs_right_binop_forward(obj_self, other, "__gt__", 0)) { + if (needs_right_binop_forward(obj_self, other, "__gt__", 0) && + Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { /* See discussion in number.c */ Py_INCREF(Py_NotImplemented); return Py_NotImplemented; @@ -1306,7 +1307,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) n_ops.less); break; case Py_LE: - if (needs_right_binop_forward(obj_self, other, "__ge__", 0)) { + if (needs_right_binop_forward(obj_self, other, "__ge__", 0) && + Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -1322,7 +1324,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_False); return Py_False; } - if (needs_right_binop_forward(obj_self, other, "__eq__", 0)) { + if (needs_right_binop_forward(obj_self, other, "__eq__", 0) && + Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -1397,7 +1400,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_True); return Py_True; } - if (needs_right_binop_forward(obj_self, other, "__ne__", 0)) { + if (needs_right_binop_forward(obj_self, other, "__ne__", 0) && + Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -1459,7 +1463,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) } break; case Py_GT: - if (needs_right_binop_forward(obj_self, other, "__lt__", 0)) { + if (needs_right_binop_forward(obj_self, other, "__lt__", 0) && + Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } @@ -1467,7 +1472,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) n_ops.greater); break; case Py_GE: - if (needs_right_binop_forward(obj_self, other, "__le__", 0)) { + if (needs_right_binop_forward(obj_self, other, "__le__", 0) && + Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index 9c6e43c69..168799f11 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -111,6 +111,13 @@ has_ufunc_attr(PyObject * obj) { * [occurs if the other object is a strict subclass provided * the operation is not in-place] * + * An additional check is made in GIVE_UP_IF_HAS_RIGHT_BINOP macro below: + * + * (iv) other.__class__.__r*__ is not self.__class__.__r*__ + * + * This is needed, because CPython does not call __rmul__ if + * the tp_number slots of the two objects are the same. + * * This always prioritizes the __r*__ routines over __numpy_ufunc__, independent * of whether the other object is an ndarray subclass or not. */ @@ -146,13 +153,19 @@ needs_right_binop_forward(PyObject *self, PyObject *other, } } -#define GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, left_name, right_name, inplace) \ - do { \ - if (needs_right_binop_forward((PyObject *)m1, m2, right_name, \ - inplace)) { \ - Py_INCREF(Py_NotImplemented); \ - return Py_NotImplemented; \ - } \ +/* In pure-Python, SAME_SLOTS can be replaced by + getattr(m1, op_name) is getattr(m2, op_name) */ +#define SAME_SLOTS(m1, m2, slot_name) \ + (Py_TYPE(m1)->tp_as_number != NULL && Py_TYPE(m2)->tp_as_number != NULL && \ + Py_TYPE(m1)->tp_as_number->slot_name == Py_TYPE(m2)->tp_as_number->slot_name) + +#define GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, left_name, right_name, inplace, slot_name) \ + do { \ + if (needs_right_binop_forward((PyObject *)m1, m2, right_name, inplace) && \ + (inplace || !SAME_SLOTS(m1, m2, slot_name))) { \ + Py_INCREF(Py_NotImplemented); \ + return Py_NotImplemented; \ + } \ } while (0) @@ -339,21 +352,21 @@ PyArray_GenericInplaceUnaryFunction(PyArrayObject *m1, PyObject *op) static PyObject * array_add(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__add__", "__radd__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__add__", "__radd__", 0, nb_add); return PyArray_GenericBinaryFunction(m1, m2, n_ops.add); } static PyObject * array_subtract(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__sub__", "__rsub__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__sub__", "__rsub__", 0, nb_subtract); return PyArray_GenericBinaryFunction(m1, m2, n_ops.subtract); } static PyObject * array_multiply(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__mul__", "__rmul__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__mul__", "__rmul__", 0, nb_multiply); return PyArray_GenericBinaryFunction(m1, m2, n_ops.multiply); } @@ -361,7 +374,7 @@ array_multiply(PyArrayObject *m1, PyObject *m2) static PyObject * array_divide(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__div__", "__rdiv__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__div__", "__rdiv__", 0, nb_divide); return PyArray_GenericBinaryFunction(m1, m2, n_ops.divide); } #endif @@ -369,7 +382,7 @@ array_divide(PyArrayObject *m1, PyObject *m2) static PyObject * array_remainder(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__mod__", "__rmod__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__mod__", "__rmod__", 0, nb_remainder); return PyArray_GenericBinaryFunction(m1, m2, n_ops.remainder); } @@ -533,7 +546,7 @@ array_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo)) { /* modulo is ignored! */ PyObject *value; - GIVE_UP_IF_HAS_RIGHT_BINOP(a1, o2, "__pow__", "__rpow__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(a1, o2, "__pow__", "__rpow__", 0, nb_power); value = fast_scalar_power(a1, o2, 0); if (!value) { value = PyArray_GenericBinaryFunction(a1, o2, n_ops.power); @@ -563,56 +576,56 @@ array_invert(PyArrayObject *m1) static PyObject * array_left_shift(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__lshift__", "__rlshift__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__lshift__", "__rlshift__", 0, nb_lshift); return PyArray_GenericBinaryFunction(m1, m2, n_ops.left_shift); } static PyObject * array_right_shift(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__rshift__", "__rrshift__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__rshift__", "__rrshift__", 0, nb_rshift); return PyArray_GenericBinaryFunction(m1, m2, n_ops.right_shift); } static PyObject * array_bitwise_and(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__and__", "__rand__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__and__", "__rand__", 0, nb_and); return PyArray_GenericBinaryFunction(m1, m2, n_ops.bitwise_and); } static PyObject * array_bitwise_or(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__or__", "__ror__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__or__", "__ror__", 0, nb_or); return PyArray_GenericBinaryFunction(m1, m2, n_ops.bitwise_or); } static PyObject * array_bitwise_xor(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__xor__", "__rxor__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__xor__", "__rxor__", 0, nb_xor); return PyArray_GenericBinaryFunction(m1, m2, n_ops.bitwise_xor); } static PyObject * array_inplace_add(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__iadd__", "__radd__", 1); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__iadd__", "__radd__", 1, nb_inplace_add); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.add); } static PyObject * array_inplace_subtract(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__isub__", "__rsub__", 1); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__isub__", "__rsub__", 1, nb_inplace_subtract); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.subtract); } static PyObject * array_inplace_multiply(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__imul__", "__rmul__", 1); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__imul__", "__rmul__", 1, nb_inplace_multiply); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.multiply); } @@ -620,7 +633,7 @@ array_inplace_multiply(PyArrayObject *m1, PyObject *m2) static PyObject * array_inplace_divide(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__idiv__", "__rdiv__", 1); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__idiv__", "__rdiv__", 1, nb_inplace_divide); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.divide); } #endif @@ -628,7 +641,7 @@ array_inplace_divide(PyArrayObject *m1, PyObject *m2) static PyObject * array_inplace_remainder(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__imod__", "__rmod__", 1); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__imod__", "__rmod__", 1, nb_inplace_remainder); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.remainder); } @@ -637,7 +650,7 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo { /* modulo is ignored! */ PyObject *value; - GIVE_UP_IF_HAS_RIGHT_BINOP(a1, o2, "__ipow__", "__rpow__", 1); + GIVE_UP_IF_HAS_RIGHT_BINOP(a1, o2, "__ipow__", "__rpow__", 1, nb_inplace_power); value = fast_scalar_power(a1, o2, 1); if (!value) { value = PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power); @@ -648,56 +661,56 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo static PyObject * array_inplace_left_shift(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ilshift__", "__rlshift__", 1); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ilshift__", "__rlshift__", 1, nb_inplace_lshift); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.left_shift); } static PyObject * array_inplace_right_shift(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__irshift__", "__rrshift__", 1); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__irshift__", "__rrshift__", 1, nb_inplace_rshift); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.right_shift); } static PyObject * array_inplace_bitwise_and(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__iand__", "__rand__", 1); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__iand__", "__rand__", 1, nb_inplace_and); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_and); } static PyObject * array_inplace_bitwise_or(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ior__", "__ror__", 1); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ior__", "__ror__", 1, nb_inplace_or); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_or); } static PyObject * array_inplace_bitwise_xor(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ixor__", "__rxor__", 1); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ixor__", "__rxor__", 1, nb_inplace_xor); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_xor); } static PyObject * array_floor_divide(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__floordiv__", "__rfloordiv__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__floordiv__", "__rfloordiv__", 0, nb_floor_divide); return PyArray_GenericBinaryFunction(m1, m2, n_ops.floor_divide); } static PyObject * array_true_divide(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__truediv__", "__rtruediv__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__truediv__", "__rtruediv__", 0, nb_true_divide); return PyArray_GenericBinaryFunction(m1, m2, n_ops.true_divide); } static PyObject * array_inplace_floor_divide(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ifloordiv__", "__rfloordiv__", 1); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ifloordiv__", "__rfloordiv__", 1, nb_inplace_floor_divide); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.floor_divide); } @@ -705,7 +718,7 @@ array_inplace_floor_divide(PyArrayObject *m1, PyObject *m2) static PyObject * array_inplace_true_divide(PyArrayObject *m1, PyObject *m2) { - GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__itruediv__", "__rtruediv__", 1); + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__itruediv__", "__rtruediv__", 1, nb_inplace_true_divide); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.true_divide); } @@ -737,7 +750,7 @@ static PyObject * array_divmod(PyArrayObject *op1, PyObject *op2) { PyObject *divp, *modp, *result; - GIVE_UP_IF_HAS_RIGHT_BINOP(op1, op2, "__divmod__", "__rdivmod__", 0); + GIVE_UP_IF_HAS_RIGHT_BINOP(op1, op2, "__divmod__", "__rdivmod__", 0, nb_divmod); divp = array_floor_divide(op1, op2); if (divp == NULL) { diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index dfa7880f9..33b75d490 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2321,7 +2321,6 @@ class TestBinop(object): def test_ufunc_override_rop_simple(self): # Check parts of the binary op overriding behavior in an # explicit test case that is easier to understand. - class SomeClass(object): def __numpy_ufunc__(self, *a, **kw): return "ufunc" @@ -2329,6 +2328,8 @@ class TestBinop(object): return 123 def __rmul__(self, other): return 321 + def __rsub__(self, other): + return "no subs for me" def __gt__(self, other): return "yep" def __lt__(self, other): @@ -2336,7 +2337,7 @@ class TestBinop(object): class SomeClass2(SomeClass, ndarray): def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw): - if ufunc is np.multiply: + if ufunc is np.multiply or ufunc is np.bitwise_and: return "ufunc" else: inputs = list(inputs) @@ -2346,28 +2347,47 @@ class TestBinop(object): if 'out' in kw: return r else: - x = SomeClass2(r.shape, dtype=r.dtype) + x = self.__class__(r.shape, dtype=r.dtype) x[...] = r return x + class SomeClass3(SomeClass2): + def __rsub__(self, other): + return "sub for me" + arr = np.array([0]) obj = SomeClass() obj2 = SomeClass2((1,), dtype=np.int_) obj2[0] = 9 + obj3 = SomeClass3((1,), dtype=np.int_) + obj3[0] = 4 + # obj is first, so should get to define outcome. assert_equal(obj * arr, 123) + # obj is second, but has __numpy_ufunc__ and defines __rmul__. assert_equal(arr * obj, 321) + # obj is second, but has __numpy_ufunc__ and defines __rsub__. + assert_equal(arr - obj, "no subs for me") + # obj is second, but has __numpy_ufunc__ and defines __lt__. assert_equal(arr > obj, "nope") + # obj is second, but has __numpy_ufunc__ and defines __gt__. assert_equal(arr < obj, "yep") + # Called as a ufunc, obj.__numpy_ufunc__ is used. assert_equal(np.multiply(arr, obj), "ufunc") + # obj is second, but has __numpy_ufunc__ and defines __rmul__. arr *= obj assert_equal(arr, 321) + # obj2 is an ndarray subclass, so CPython takes care of the same rules. assert_equal(obj2 * arr, 123) assert_equal(arr * obj2, 321) + assert_equal(arr - obj2, "no subs for me") assert_equal(arr > obj2, "nope") assert_equal(arr < obj2, "yep") + # Called as a ufunc, obj2.__numpy_ufunc__ is called. assert_equal(np.multiply(arr, obj2), "ufunc") + # Also when the method is not overridden. + assert_equal(arr & obj2, "ufunc") arr *= obj2 assert_equal(arr, 321) @@ -2376,6 +2396,27 @@ class TestBinop(object): assert_equal(obj2.sum(), 42) assert_(isinstance(obj2, SomeClass2)) + # Obj3 is subclass that defines __rsub__. CPython calls it. + assert_equal(arr - obj3, "sub for me") + assert_equal(obj2 - obj3, "sub for me") + # obj3 is a subclass that defines __rmul__. CPython calls it. + assert_equal(arr * obj3, 321) + # But not here, since obj3.__rmul__ is obj2.__rmul__. + assert_equal(obj2 * obj3, 123) + # And of course, here obj3.__mul__ should be called. + assert_equal(obj3 * obj2, 123) + # obj3 defines __numpy_ufunc__ but obj3.__radd__ is obj2.__radd__. + # (and both are just ndarray.__radd__); see #4815. + res = obj2 + obj3 + assert_equal(res, 46) + assert_(isinstance(res, SomeClass2)) + # Since obj3 is a subclass, it should have precedence, like CPython + # would give, even though obj2 has __numpy_ufunc__ and __radd__. + # See gh-4815 and gh-5747. + res = obj3 + obj2 + assert_equal(res, 46) + assert_(isinstance(res, SomeClass3)) + def test_ufunc_override_normalize_signature(self): # gh-5674 class SomeClass(object): |