summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-05-05 22:39:13 -0400
committerCharles Harris <charlesr.harris@gmail.com>2015-05-05 22:39:13 -0400
commiteecb2e3c07f29c0ac991d364a846a2f8293a432a (patch)
tree0f87e1b9da37ce65f2c7aa4366f10bbbca82e558
parenta29c296f0bf306a8824705567cbb4271f383f0ea (diff)
parentd1153b4cba5510bd1961e3a7ae7f0c07c124a202 (diff)
downloadnumpy-eecb2e3c07f29c0ac991d364a846a2f8293a432a.tar.gz
Merge pull request #5748 from mhvk/numpyufunc-binop-fix
BUG Numpy ufunc binop fix for subclasses with __numpy_ufunc__ (closes #4815)
-rw-r--r--numpy/core/src/multiarray/arrayobject.c18
-rw-r--r--numpy/core/src/multiarray/number.c81
-rw-r--r--numpy/core/tests/test_multiarray.py47
3 files changed, 103 insertions, 43 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index 6e48ef381..1a5c30832 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -1297,7 +1297,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
switch (cmp_op) {
case Py_LT:
- if (needs_right_binop_forward(obj_self, other, "__gt__", 0)) {
+ if (needs_right_binop_forward(obj_self, other, "__gt__", 0) &&
+ Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) {
/* See discussion in number.c */
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
@@ -1306,7 +1307,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
n_ops.less);
break;
case Py_LE:
- if (needs_right_binop_forward(obj_self, other, "__ge__", 0)) {
+ if (needs_right_binop_forward(obj_self, other, "__ge__", 0) &&
+ Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
@@ -1322,7 +1324,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_False);
return Py_False;
}
- if (needs_right_binop_forward(obj_self, other, "__eq__", 0)) {
+ if (needs_right_binop_forward(obj_self, other, "__eq__", 0) &&
+ Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
@@ -1397,7 +1400,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_True);
return Py_True;
}
- if (needs_right_binop_forward(obj_self, other, "__ne__", 0)) {
+ if (needs_right_binop_forward(obj_self, other, "__ne__", 0) &&
+ Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
@@ -1459,7 +1463,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
}
break;
case Py_GT:
- if (needs_right_binop_forward(obj_self, other, "__lt__", 0)) {
+ if (needs_right_binop_forward(obj_self, other, "__lt__", 0) &&
+ Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
@@ -1467,7 +1472,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
n_ops.greater);
break;
case Py_GE:
- if (needs_right_binop_forward(obj_self, other, "__le__", 0)) {
+ if (needs_right_binop_forward(obj_self, other, "__le__", 0) &&
+ Py_TYPE(obj_self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index 9c6e43c69..168799f11 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -111,6 +111,13 @@ has_ufunc_attr(PyObject * obj) {
* [occurs if the other object is a strict subclass provided
* the operation is not in-place]
*
+ * An additional check is made in GIVE_UP_IF_HAS_RIGHT_BINOP macro below:
+ *
+ * (iv) other.__class__.__r*__ is not self.__class__.__r*__
+ *
+ * This is needed, because CPython does not call __rmul__ if
+ * the tp_number slots of the two objects are the same.
+ *
* This always prioritizes the __r*__ routines over __numpy_ufunc__, independent
* of whether the other object is an ndarray subclass or not.
*/
@@ -146,13 +153,19 @@ needs_right_binop_forward(PyObject *self, PyObject *other,
}
}
-#define GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, left_name, right_name, inplace) \
- do { \
- if (needs_right_binop_forward((PyObject *)m1, m2, right_name, \
- inplace)) { \
- Py_INCREF(Py_NotImplemented); \
- return Py_NotImplemented; \
- } \
+/* In pure-Python, SAME_SLOTS can be replaced by
+ getattr(m1, op_name) is getattr(m2, op_name) */
+#define SAME_SLOTS(m1, m2, slot_name) \
+ (Py_TYPE(m1)->tp_as_number != NULL && Py_TYPE(m2)->tp_as_number != NULL && \
+ Py_TYPE(m1)->tp_as_number->slot_name == Py_TYPE(m2)->tp_as_number->slot_name)
+
+#define GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, left_name, right_name, inplace, slot_name) \
+ do { \
+ if (needs_right_binop_forward((PyObject *)m1, m2, right_name, inplace) && \
+ (inplace || !SAME_SLOTS(m1, m2, slot_name))) { \
+ Py_INCREF(Py_NotImplemented); \
+ return Py_NotImplemented; \
+ } \
} while (0)
@@ -339,21 +352,21 @@ PyArray_GenericInplaceUnaryFunction(PyArrayObject *m1, PyObject *op)
static PyObject *
array_add(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__add__", "__radd__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__add__", "__radd__", 0, nb_add);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.add);
}
static PyObject *
array_subtract(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__sub__", "__rsub__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__sub__", "__rsub__", 0, nb_subtract);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.subtract);
}
static PyObject *
array_multiply(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__mul__", "__rmul__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__mul__", "__rmul__", 0, nb_multiply);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.multiply);
}
@@ -361,7 +374,7 @@ array_multiply(PyArrayObject *m1, PyObject *m2)
static PyObject *
array_divide(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__div__", "__rdiv__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__div__", "__rdiv__", 0, nb_divide);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.divide);
}
#endif
@@ -369,7 +382,7 @@ array_divide(PyArrayObject *m1, PyObject *m2)
static PyObject *
array_remainder(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__mod__", "__rmod__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__mod__", "__rmod__", 0, nb_remainder);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.remainder);
}
@@ -533,7 +546,7 @@ array_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo))
{
/* modulo is ignored! */
PyObject *value;
- GIVE_UP_IF_HAS_RIGHT_BINOP(a1, o2, "__pow__", "__rpow__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(a1, o2, "__pow__", "__rpow__", 0, nb_power);
value = fast_scalar_power(a1, o2, 0);
if (!value) {
value = PyArray_GenericBinaryFunction(a1, o2, n_ops.power);
@@ -563,56 +576,56 @@ array_invert(PyArrayObject *m1)
static PyObject *
array_left_shift(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__lshift__", "__rlshift__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__lshift__", "__rlshift__", 0, nb_lshift);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.left_shift);
}
static PyObject *
array_right_shift(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__rshift__", "__rrshift__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__rshift__", "__rrshift__", 0, nb_rshift);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.right_shift);
}
static PyObject *
array_bitwise_and(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__and__", "__rand__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__and__", "__rand__", 0, nb_and);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.bitwise_and);
}
static PyObject *
array_bitwise_or(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__or__", "__ror__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__or__", "__ror__", 0, nb_or);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.bitwise_or);
}
static PyObject *
array_bitwise_xor(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__xor__", "__rxor__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__xor__", "__rxor__", 0, nb_xor);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.bitwise_xor);
}
static PyObject *
array_inplace_add(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__iadd__", "__radd__", 1);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__iadd__", "__radd__", 1, nb_inplace_add);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.add);
}
static PyObject *
array_inplace_subtract(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__isub__", "__rsub__", 1);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__isub__", "__rsub__", 1, nb_inplace_subtract);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.subtract);
}
static PyObject *
array_inplace_multiply(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__imul__", "__rmul__", 1);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__imul__", "__rmul__", 1, nb_inplace_multiply);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.multiply);
}
@@ -620,7 +633,7 @@ array_inplace_multiply(PyArrayObject *m1, PyObject *m2)
static PyObject *
array_inplace_divide(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__idiv__", "__rdiv__", 1);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__idiv__", "__rdiv__", 1, nb_inplace_divide);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.divide);
}
#endif
@@ -628,7 +641,7 @@ array_inplace_divide(PyArrayObject *m1, PyObject *m2)
static PyObject *
array_inplace_remainder(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__imod__", "__rmod__", 1);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__imod__", "__rmod__", 1, nb_inplace_remainder);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.remainder);
}
@@ -637,7 +650,7 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo
{
/* modulo is ignored! */
PyObject *value;
- GIVE_UP_IF_HAS_RIGHT_BINOP(a1, o2, "__ipow__", "__rpow__", 1);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(a1, o2, "__ipow__", "__rpow__", 1, nb_inplace_power);
value = fast_scalar_power(a1, o2, 1);
if (!value) {
value = PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power);
@@ -648,56 +661,56 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo
static PyObject *
array_inplace_left_shift(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ilshift__", "__rlshift__", 1);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ilshift__", "__rlshift__", 1, nb_inplace_lshift);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.left_shift);
}
static PyObject *
array_inplace_right_shift(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__irshift__", "__rrshift__", 1);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__irshift__", "__rrshift__", 1, nb_inplace_rshift);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.right_shift);
}
static PyObject *
array_inplace_bitwise_and(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__iand__", "__rand__", 1);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__iand__", "__rand__", 1, nb_inplace_and);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_and);
}
static PyObject *
array_inplace_bitwise_or(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ior__", "__ror__", 1);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ior__", "__ror__", 1, nb_inplace_or);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_or);
}
static PyObject *
array_inplace_bitwise_xor(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ixor__", "__rxor__", 1);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ixor__", "__rxor__", 1, nb_inplace_xor);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_xor);
}
static PyObject *
array_floor_divide(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__floordiv__", "__rfloordiv__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__floordiv__", "__rfloordiv__", 0, nb_floor_divide);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.floor_divide);
}
static PyObject *
array_true_divide(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__truediv__", "__rtruediv__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__truediv__", "__rtruediv__", 0, nb_true_divide);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.true_divide);
}
static PyObject *
array_inplace_floor_divide(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ifloordiv__", "__rfloordiv__", 1);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ifloordiv__", "__rfloordiv__", 1, nb_inplace_floor_divide);
return PyArray_GenericInplaceBinaryFunction(m1, m2,
n_ops.floor_divide);
}
@@ -705,7 +718,7 @@ array_inplace_floor_divide(PyArrayObject *m1, PyObject *m2)
static PyObject *
array_inplace_true_divide(PyArrayObject *m1, PyObject *m2)
{
- GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__itruediv__", "__rtruediv__", 1);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__itruediv__", "__rtruediv__", 1, nb_inplace_true_divide);
return PyArray_GenericInplaceBinaryFunction(m1, m2,
n_ops.true_divide);
}
@@ -737,7 +750,7 @@ static PyObject *
array_divmod(PyArrayObject *op1, PyObject *op2)
{
PyObject *divp, *modp, *result;
- GIVE_UP_IF_HAS_RIGHT_BINOP(op1, op2, "__divmod__", "__rdivmod__", 0);
+ GIVE_UP_IF_HAS_RIGHT_BINOP(op1, op2, "__divmod__", "__rdivmod__", 0, nb_divmod);
divp = array_floor_divide(op1, op2);
if (divp == NULL) {
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index dfa7880f9..33b75d490 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2321,7 +2321,6 @@ class TestBinop(object):
def test_ufunc_override_rop_simple(self):
# Check parts of the binary op overriding behavior in an
# explicit test case that is easier to understand.
-
class SomeClass(object):
def __numpy_ufunc__(self, *a, **kw):
return "ufunc"
@@ -2329,6 +2328,8 @@ class TestBinop(object):
return 123
def __rmul__(self, other):
return 321
+ def __rsub__(self, other):
+ return "no subs for me"
def __gt__(self, other):
return "yep"
def __lt__(self, other):
@@ -2336,7 +2337,7 @@ class TestBinop(object):
class SomeClass2(SomeClass, ndarray):
def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw):
- if ufunc is np.multiply:
+ if ufunc is np.multiply or ufunc is np.bitwise_and:
return "ufunc"
else:
inputs = list(inputs)
@@ -2346,28 +2347,47 @@ class TestBinop(object):
if 'out' in kw:
return r
else:
- x = SomeClass2(r.shape, dtype=r.dtype)
+ x = self.__class__(r.shape, dtype=r.dtype)
x[...] = r
return x
+ class SomeClass3(SomeClass2):
+ def __rsub__(self, other):
+ return "sub for me"
+
arr = np.array([0])
obj = SomeClass()
obj2 = SomeClass2((1,), dtype=np.int_)
obj2[0] = 9
+ obj3 = SomeClass3((1,), dtype=np.int_)
+ obj3[0] = 4
+ # obj is first, so should get to define outcome.
assert_equal(obj * arr, 123)
+ # obj is second, but has __numpy_ufunc__ and defines __rmul__.
assert_equal(arr * obj, 321)
+ # obj is second, but has __numpy_ufunc__ and defines __rsub__.
+ assert_equal(arr - obj, "no subs for me")
+ # obj is second, but has __numpy_ufunc__ and defines __lt__.
assert_equal(arr > obj, "nope")
+ # obj is second, but has __numpy_ufunc__ and defines __gt__.
assert_equal(arr < obj, "yep")
+ # Called as a ufunc, obj.__numpy_ufunc__ is used.
assert_equal(np.multiply(arr, obj), "ufunc")
+ # obj is second, but has __numpy_ufunc__ and defines __rmul__.
arr *= obj
assert_equal(arr, 321)
+ # obj2 is an ndarray subclass, so CPython takes care of the same rules.
assert_equal(obj2 * arr, 123)
assert_equal(arr * obj2, 321)
+ assert_equal(arr - obj2, "no subs for me")
assert_equal(arr > obj2, "nope")
assert_equal(arr < obj2, "yep")
+ # Called as a ufunc, obj2.__numpy_ufunc__ is called.
assert_equal(np.multiply(arr, obj2), "ufunc")
+ # Also when the method is not overridden.
+ assert_equal(arr & obj2, "ufunc")
arr *= obj2
assert_equal(arr, 321)
@@ -2376,6 +2396,27 @@ class TestBinop(object):
assert_equal(obj2.sum(), 42)
assert_(isinstance(obj2, SomeClass2))
+ # Obj3 is subclass that defines __rsub__. CPython calls it.
+ assert_equal(arr - obj3, "sub for me")
+ assert_equal(obj2 - obj3, "sub for me")
+ # obj3 is a subclass that defines __rmul__. CPython calls it.
+ assert_equal(arr * obj3, 321)
+ # But not here, since obj3.__rmul__ is obj2.__rmul__.
+ assert_equal(obj2 * obj3, 123)
+ # And of course, here obj3.__mul__ should be called.
+ assert_equal(obj3 * obj2, 123)
+ # obj3 defines __numpy_ufunc__ but obj3.__radd__ is obj2.__radd__.
+ # (and both are just ndarray.__radd__); see #4815.
+ res = obj2 + obj3
+ assert_equal(res, 46)
+ assert_(isinstance(res, SomeClass2))
+ # Since obj3 is a subclass, it should have precedence, like CPython
+ # would give, even though obj2 has __numpy_ufunc__ and __radd__.
+ # See gh-4815 and gh-5747.
+ res = obj3 + obj2
+ assert_equal(res, 46)
+ assert_(isinstance(res, SomeClass3))
+
def test_ufunc_override_normalize_signature(self):
# gh-5674
class SomeClass(object):