summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2013-10-19 14:24:29 -0700
committerCharles Harris <charlesr.harris@gmail.com>2013-10-19 14:24:29 -0700
commit55d3ef5b81b709cc9d8ad6b13f75ca0f05f65276 (patch)
tree65cae9f774f7f9207e4b12f45c81771e6a4680f7 /numpy/core
parent1eb5f03c4c27009cac5d7c19f36bfab718533072 (diff)
parentfd5d3088ed71bbc2fe5a774178be5e0ba04e4cd1 (diff)
downloadnumpy-55d3ef5b81b709cc9d8ad6b13f75ca0f05f65276.tar.gz
Merge pull request #3856 from pv/op-before-ufunc
BUG: core: ensure __r*__ has precedence over __numpy_ufunc__
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/arrayobject.c25
-rw-r--r--numpy/core/src/multiarray/number.c85
-rw-r--r--numpy/core/src/multiarray/number.h4
-rw-r--r--numpy/core/tests/test_multiarray.py183
4 files changed, 294 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index 4c48ba673..b91e84366 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -1273,10 +1273,19 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
switch (cmp_op) {
case Py_LT:
+ if (needs_right_binop_forward(self, other, "__gt__", 0)) {
+ /* See discussion in number.c */
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
result = PyArray_GenericBinaryFunction(self, other,
n_ops.less);
break;
case Py_LE:
+ if (needs_right_binop_forward(self, other, "__ge__", 0)) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
result = PyArray_GenericBinaryFunction(self, other,
n_ops.less_equal);
break;
@@ -1285,6 +1294,10 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_False);
return Py_False;
}
+ if (needs_right_binop_forward(self, other, "__eq__", 0)) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
result = PyArray_GenericBinaryFunction(self,
(PyObject *)other,
n_ops.equal);
@@ -1343,6 +1356,10 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_True);
return Py_True;
}
+ if (needs_right_binop_forward(self, other, "__ne__", 0)) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
result = PyArray_GenericBinaryFunction(self, (PyObject *)other,
n_ops.not_equal);
if (result && result != Py_NotImplemented)
@@ -1392,10 +1409,18 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
}
break;
case Py_GT:
+ if (needs_right_binop_forward(self, other, "__lt__", 0)) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
result = PyArray_GenericBinaryFunction(self, other,
n_ops.greater);
break;
case Py_GE:
+ if (needs_right_binop_forward(self, other, "__le__", 0)) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
result = PyArray_GenericBinaryFunction(self, other,
n_ops.greater_equal);
break;
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index c1d04c536..2ad65bbe8 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -85,6 +85,62 @@ PyArray_SetNumericOps(PyObject *dict)
(PyDict_SetItemString(dict, #op, n_ops.op)==-1)) \
goto fail;
+/*
+ * Check whether the operation needs to be forwarded to the right-hand binary
+ * operation.
+ *
+ * This is the case when all of the following conditions apply:
+ *
+ * (i) the other object defines __numpy_ufunc__
+ * (ii) the other object defines the right-hand operation __r*__
+ * (iii) Python hasn't already called the right-hand operation
+ * [occurs if the other object is a strict subclass provided
+ * the operation is not in-place]
+ *
+ * This always prioritizes the __r*__ routines over __numpy_ufunc__, independent
+ * of whether the other object is an ndarray subclass or not.
+ */
+
+NPY_NO_EXPORT int
+needs_right_binop_forward(PyObject *self, PyObject *other,
+ char *right_name, int inplace_op)
+{
+ if (other == NULL ||
+ self == NULL ||
+ Py_TYPE(self) == Py_TYPE(other) ||
+ PyArray_CheckExact(other) ||
+ PyArray_CheckAnyScalar(other)) {
+ /*
+ * Quick cases
+ */
+ return 0;
+ }
+ if (!inplace_op && PyType_IsSubtype(Py_TYPE(other), Py_TYPE(self)) ||
+ !PyArray_Check(self)) {
+ /*
+ * Bail out if Python would already have called the right-hand
+ * operation.
+ */
+ return 0;
+ }
+ if (PyObject_HasAttrString(other, "__numpy_ufunc__") &&
+ PyObject_HasAttrString(other, right_name)) {
+ return 1;
+ }
+ else {
+ return 0;
+ }
+}
+
+#define GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, left_name, right_name, inplace) \
+ do { \
+ if (needs_right_binop_forward(m1, m2, right_name, inplace)) { \
+ Py_INCREF(Py_NotImplemented); \
+ return Py_NotImplemented; \
+ } \
+ } while (0)
+
+
/*NUMPY_API
Get dictionary showing number functions that all arrays will use
*/
@@ -210,7 +266,7 @@ PyArray_GenericBinaryFunction(PyArrayObject *m1, PyObject *m2, PyObject *op)
return Py_NotImplemented;
}
- if (!PyArray_Check(m2)) {
+ if (!PyArray_Check(m2) && !PyObject_HasAttrString(m2, "__numpy_ufunc__")) {
/*
* Catch priority inversion and punt, but only if it's guaranteed
* that we were called through m1 and the other guy is not an array
@@ -268,18 +324,21 @@ PyArray_GenericInplaceUnaryFunction(PyArrayObject *m1, PyObject *op)
static PyObject *
array_add(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__add__", "__radd__", 0);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.add);
}
static PyObject *
array_subtract(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__sub__", "__rsub__", 0);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.subtract);
}
static PyObject *
array_multiply(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__mul__", "__rmul__", 0);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.multiply);
}
@@ -287,6 +346,7 @@ array_multiply(PyArrayObject *m1, PyObject *m2)
static PyObject *
array_divide(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__div__", "__rdiv__", 0);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.divide);
}
#endif
@@ -294,6 +354,7 @@ array_divide(PyArrayObject *m1, PyObject *m2)
static PyObject *
array_remainder(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__mod__", "__rmod__", 0);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.remainder);
}
@@ -448,6 +509,7 @@ array_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo))
{
/* modulo is ignored! */
PyObject *value;
+ GIVE_UP_IF_HAS_RIGHT_BINOP(a1, o2, "__pow__", "__rpow__", 0);
value = fast_scalar_power(a1, o2, 0);
if (!value) {
value = PyArray_GenericBinaryFunction(a1, o2, n_ops.power);
@@ -477,48 +539,56 @@ array_invert(PyArrayObject *m1)
static PyObject *
array_left_shift(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__lshift__", "__rlshift__", 0);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.left_shift);
}
static PyObject *
array_right_shift(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__rshift__", "__rrshift__", 0);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.right_shift);
}
static PyObject *
array_bitwise_and(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__and__", "__rand__", 0);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.bitwise_and);
}
static PyObject *
array_bitwise_or(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__or__", "__ror__", 0);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.bitwise_or);
}
static PyObject *
array_bitwise_xor(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__xor__", "__rxor__", 0);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.bitwise_xor);
}
static PyObject *
array_inplace_add(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__iadd__", "__radd__", 1);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.add);
}
static PyObject *
array_inplace_subtract(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__isub__", "__rsub__", 1);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.subtract);
}
static PyObject *
array_inplace_multiply(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__imul__", "__rmul__", 1);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.multiply);
}
@@ -526,6 +596,7 @@ array_inplace_multiply(PyArrayObject *m1, PyObject *m2)
static PyObject *
array_inplace_divide(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__idiv__", "__rdiv__", 1);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.divide);
}
#endif
@@ -533,6 +604,7 @@ array_inplace_divide(PyArrayObject *m1, PyObject *m2)
static PyObject *
array_inplace_remainder(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__imod__", "__rmod__", 1);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.remainder);
}
@@ -541,6 +613,7 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo
{
/* modulo is ignored! */
PyObject *value;
+ GIVE_UP_IF_HAS_RIGHT_BINOP(a1, o2, "__ipow__", "__rpow__", 1);
value = fast_scalar_power(a1, o2, 1);
if (!value) {
value = PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power);
@@ -551,48 +624,56 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo
static PyObject *
array_inplace_left_shift(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ilshift__", "__rlshift__", 1);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.left_shift);
}
static PyObject *
array_inplace_right_shift(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__irshift__", "__rrshift__", 1);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.right_shift);
}
static PyObject *
array_inplace_bitwise_and(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__iand__", "__rand__", 1);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_and);
}
static PyObject *
array_inplace_bitwise_or(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ior__", "__ror__", 1);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_or);
}
static PyObject *
array_inplace_bitwise_xor(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ixor__", "__rxor__", 1);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_xor);
}
static PyObject *
array_floor_divide(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__floordiv__", "__rfloordiv__", 0);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.floor_divide);
}
static PyObject *
array_true_divide(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__truediv__", "__rtruediv__", 0);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.true_divide);
}
static PyObject *
array_inplace_floor_divide(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ifloordiv__", "__rfloordiv__", 1);
return PyArray_GenericInplaceBinaryFunction(m1, m2,
n_ops.floor_divide);
}
@@ -600,6 +681,7 @@ array_inplace_floor_divide(PyArrayObject *m1, PyObject *m2)
static PyObject *
array_inplace_true_divide(PyArrayObject *m1, PyObject *m2)
{
+ GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__itruediv__", "__rtruediv__", 1);
return PyArray_GenericInplaceBinaryFunction(m1, m2,
n_ops.true_divide);
}
@@ -631,6 +713,7 @@ static PyObject *
array_divmod(PyArrayObject *op1, PyObject *op2)
{
PyObject *divp, *modp, *result;
+ GIVE_UP_IF_HAS_RIGHT_BINOP(op1, op2, "__divmod__", "__rdivmod__", 0);
divp = array_floor_divide(op1, op2);
if (divp == NULL) {
diff --git a/numpy/core/src/multiarray/number.h b/numpy/core/src/multiarray/number.h
index 0018b7348..63ea40696 100644
--- a/numpy/core/src/multiarray/number.h
+++ b/numpy/core/src/multiarray/number.h
@@ -69,4 +69,8 @@ NPY_NO_EXPORT PyObject *
PyArray_GenericAccumulateFunction(PyArrayObject *m1, PyObject *op, int axis,
int rtype, PyArrayObject *out);
+NPY_NO_EXPORT int
+needs_right_binop_forward(PyObject *self, PyObject *other, char *right_name,
+ int is_inplace);
+
#endif
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index c41c56b9c..cb02c3666 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4,12 +4,12 @@ import tempfile
import sys
import os
import warnings
+import operator
if sys.version_info[0] >= 3:
import builtins
else:
import __builtin__ as builtins
-
import numpy as np
from nose import SkipTest
from numpy.core import *
@@ -1460,7 +1460,7 @@ class TestMethods(TestCase):
class A(object):
def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs):
return "A"
-
+
class B(object):
def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs):
return NotImplemented
@@ -1547,6 +1547,185 @@ class TestMethods(TestCase):
assert_equal(a.ravel(order='K'), [2, 3, 0, 1])
assert_(a.ravel(order='K').flags.owndata)
+
+class TestBinop(object):
+ def test_ufunc_override_rop_precedence(self):
+ # Check that __rmul__ and other right-hand operations have
+ # precedence over __numpy_ufunc__
+
+ ops = {
+ '__add__': ('__radd__', np.add, True),
+ '__sub__': ('__rsub__', np.subtract, True),
+ '__mul__': ('__rmul__', np.multiply, True),
+ '__truediv__': ('__rtruediv__', np.true_divide, True),
+ '__floordiv__': ('__rfloordiv__', np.floor_divide, True),
+ '__mod__': ('__rmod__', np.remainder, True),
+ '__divmod__': ('__rdivmod__', None, False),
+ '__pow__': ('__rpow__', np.power, True),
+ '__lshift__': ('__rlshift__', np.left_shift, True),
+ '__rshift__': ('__rrshift__', np.right_shift, True),
+ '__and__': ('__rand__', np.bitwise_and, True),
+ '__xor__': ('__rxor__', np.bitwise_xor, True),
+ '__or__': ('__ror__', np.bitwise_or, True),
+ '__ge__': ('__le__', np.less_equal, False),
+ '__gt__': ('__lt__', np.less, False),
+ '__le__': ('__ge__', np.greater_equal, False),
+ '__lt__': ('__gt__', np.greater, False),
+ '__eq__': ('__eq__', np.equal, False),
+ '__ne__': ('__ne__', np.not_equal, False),
+ }
+
+ class OtherNdarraySubclass(ndarray):
+ pass
+
+ class OtherNdarraySubclassWithOverride(ndarray):
+ def __numpy_ufunc__(self, *a, **kw):
+ raise AssertionError(("__numpy_ufunc__ %r %r shouldn't have "
+ "been called!") % (a, kw))
+
+ def check(op_name, ndsubclass):
+ rop_name, np_op, has_iop = ops[op_name]
+
+ if has_iop:
+ iop_name = '__i' + op_name[2:]
+ iop = getattr(operator, iop_name)
+
+ if op_name == "__divmod__":
+ op = divmod
+ else:
+ op = getattr(operator, op_name)
+
+ # Dummy class
+ def __init__(self, *a, **kw):
+ pass
+
+ def __numpy_ufunc__(self, *a, **kw):
+ raise AssertionError(("__numpy_ufunc__ %r %r shouldn't have "
+ "been called!") % (a, kw))
+
+ def __op__(self, *other):
+ return "op"
+
+ def __rop__(self, *other):
+ return "rop"
+
+ if ndsubclass:
+ bases = (ndarray,)
+ else:
+ bases = (object,)
+
+ dct = {'__init__': __init__,
+ '__numpy_ufunc__': __numpy_ufunc__,
+ op_name: __op__}
+ if op_name != rop_name:
+ dct[rop_name] = __rop__
+
+ cls = type("Rop" + rop_name, bases, dct)
+
+ # Check behavior against both bare ndarray objects and a
+ # ndarray subclasses with and without their own override
+ obj = cls((1,))
+
+ arr_objs = [np.array([1]),
+ np.array([2]).view(OtherNdarraySubclass),
+ np.array([3]).view(OtherNdarraySubclassWithOverride),
+ ]
+
+ for arr in arr_objs:
+ err_msg = "%r %r" % (op_name, arr,)
+
+ # Check that ndarray op gives up if it sees a non-subclass
+ if not isinstance(obj, arr.__class__):
+ assert_equal(getattr(arr, op_name)(obj),
+ NotImplemented, err_msg=err_msg)
+
+ # Check that the Python binops have priority
+ assert_equal(op(obj, arr), "op", err_msg=err_msg)
+ if op_name == rop_name:
+ assert_equal(op(arr, obj), "op", err_msg=err_msg)
+ else:
+ assert_equal(op(arr, obj), "rop", err_msg=err_msg)
+
+ # Check that Python binops have priority also for in-place ops
+ if has_iop:
+ assert_equal(getattr(arr, iop_name)(obj),
+ NotImplemented, err_msg=err_msg)
+ if op_name != "__pow__":
+ # inplace pow requires the other object to be
+ # integer-like?
+ assert_equal(iop(arr, obj), "rop", err_msg=err_msg)
+
+ # Check that ufunc call __numpy_ufunc__ normally
+ if np_op is not None:
+ assert_raises(AssertionError, np_op, arr, obj,
+ err_msg=err_msg)
+ assert_raises(AssertionError, np_op, obj, arr,
+ err_msg=err_msg)
+
+ # Check all binary operations
+ for op_name in sorted(ops.keys()):
+ yield check, op_name, True
+ yield check, op_name, False
+
+ def test_ufunc_override_rop_simple(self):
+ # Check parts of the binary op overriding behavior in an
+ # explicit test case that is easier to understand.
+
+ class SomeClass(object):
+ def __numpy_ufunc__(self, *a, **kw):
+ return "ufunc"
+ def __mul__(self, other):
+ return 123
+ def __rmul__(self, other):
+ return 321
+ def __gt__(self, other):
+ return "yep"
+ def __lt__(self, other):
+ return "nope"
+
+ class SomeClass2(SomeClass, ndarray):
+ def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw):
+ if ufunc is np.multiply:
+ return "ufunc"
+ else:
+ inputs = list(inputs)
+ inputs[i] = np.asarray(self)
+ func = getattr(ufunc, method)
+ r = func(*inputs, **kw)
+ if 'out' in kw:
+ return r
+ else:
+ x = SomeClass2(r.shape, dtype=r.dtype)
+ x[...] = r
+ return x
+
+ arr = np.array([0])
+ obj = SomeClass()
+ obj2 = SomeClass2((1,), dtype=np.int_)
+ obj2[0] = 9
+
+ assert_equal(obj * arr, 123)
+ assert_equal(arr * obj, 321)
+ assert_equal(arr > obj, "nope")
+ assert_equal(arr < obj, "yep")
+ assert_equal(np.multiply(arr, obj), "ufunc")
+ arr *= obj
+ assert_equal(arr, 321)
+
+ assert_equal(obj2 * arr, 123)
+ assert_equal(arr * obj2, 321)
+ assert_equal(arr > obj2, "nope")
+ assert_equal(arr < obj2, "yep")
+ assert_equal(np.multiply(arr, obj2), "ufunc")
+ arr *= obj2
+ assert_equal(arr, 321)
+
+ obj2 += 33
+ assert_equal(obj2[0], 42)
+ assert_equal(obj2.sum(), 42)
+ assert_(isinstance(obj2, SomeClass2))
+
+
class TestSubscripting(TestCase):
def test_test_zero_rank(self):
x = array([1, 2, 3])