diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2013-10-19 14:24:29 -0700 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2013-10-19 14:24:29 -0700 |
commit | 55d3ef5b81b709cc9d8ad6b13f75ca0f05f65276 (patch) | |
tree | 65cae9f774f7f9207e4b12f45c81771e6a4680f7 /numpy/core | |
parent | 1eb5f03c4c27009cac5d7c19f36bfab718533072 (diff) | |
parent | fd5d3088ed71bbc2fe5a774178be5e0ba04e4cd1 (diff) | |
download | numpy-55d3ef5b81b709cc9d8ad6b13f75ca0f05f65276.tar.gz |
Merge pull request #3856 from pv/op-before-ufunc
BUG: core: ensure __r*__ has precedence over __numpy_ufunc__
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/arrayobject.c | 25 | ||||
-rw-r--r-- | numpy/core/src/multiarray/number.c | 85 | ||||
-rw-r--r-- | numpy/core/src/multiarray/number.h | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 183 |
4 files changed, 294 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c index 4c48ba673..b91e84366 100644 --- a/numpy/core/src/multiarray/arrayobject.c +++ b/numpy/core/src/multiarray/arrayobject.c @@ -1273,10 +1273,19 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) switch (cmp_op) { case Py_LT: + if (needs_right_binop_forward(self, other, "__gt__", 0)) { + /* See discussion in number.c */ + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } result = PyArray_GenericBinaryFunction(self, other, n_ops.less); break; case Py_LE: + if (needs_right_binop_forward(self, other, "__ge__", 0)) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } result = PyArray_GenericBinaryFunction(self, other, n_ops.less_equal); break; @@ -1285,6 +1294,10 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_False); return Py_False; } + if (needs_right_binop_forward(self, other, "__eq__", 0)) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } result = PyArray_GenericBinaryFunction(self, (PyObject *)other, n_ops.equal); @@ -1343,6 +1356,10 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_True); return Py_True; } + if (needs_right_binop_forward(self, other, "__ne__", 0)) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } result = PyArray_GenericBinaryFunction(self, (PyObject *)other, n_ops.not_equal); if (result && result != Py_NotImplemented) @@ -1392,10 +1409,18 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op) } break; case Py_GT: + if (needs_right_binop_forward(self, other, "__lt__", 0)) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } result = PyArray_GenericBinaryFunction(self, other, n_ops.greater); break; case Py_GE: + if (needs_right_binop_forward(self, other, "__le__", 0)) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } result = PyArray_GenericBinaryFunction(self, other, n_ops.greater_equal); break; diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index c1d04c536..2ad65bbe8 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -85,6 +85,62 @@ PyArray_SetNumericOps(PyObject *dict) (PyDict_SetItemString(dict, #op, n_ops.op)==-1)) \ goto fail; +/* + * Check whether the operation needs to be forwarded to the right-hand binary + * operation. + * + * This is the case when all of the following conditions apply: + * + * (i) the other object defines __numpy_ufunc__ + * (ii) the other object defines the right-hand operation __r*__ + * (iii) Python hasn't already called the right-hand operation + * [occurs if the other object is a strict subclass provided + * the operation is not in-place] + * + * This always prioritizes the __r*__ routines over __numpy_ufunc__, independent + * of whether the other object is an ndarray subclass or not. + */ + +NPY_NO_EXPORT int +needs_right_binop_forward(PyObject *self, PyObject *other, + char *right_name, int inplace_op) +{ + if (other == NULL || + self == NULL || + Py_TYPE(self) == Py_TYPE(other) || + PyArray_CheckExact(other) || + PyArray_CheckAnyScalar(other)) { + /* + * Quick cases + */ + return 0; + } + if (!inplace_op && PyType_IsSubtype(Py_TYPE(other), Py_TYPE(self)) || + !PyArray_Check(self)) { + /* + * Bail out if Python would already have called the right-hand + * operation. + */ + return 0; + } + if (PyObject_HasAttrString(other, "__numpy_ufunc__") && + PyObject_HasAttrString(other, right_name)) { + return 1; + } + else { + return 0; + } +} + +#define GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, left_name, right_name, inplace) \ + do { \ + if (needs_right_binop_forward(m1, m2, right_name, inplace)) { \ + Py_INCREF(Py_NotImplemented); \ + return Py_NotImplemented; \ + } \ + } while (0) + + /*NUMPY_API Get dictionary showing number functions that all arrays will use */ @@ -210,7 +266,7 @@ PyArray_GenericBinaryFunction(PyArrayObject *m1, PyObject *m2, PyObject *op) return Py_NotImplemented; } - if (!PyArray_Check(m2)) { + if (!PyArray_Check(m2) && !PyObject_HasAttrString(m2, "__numpy_ufunc__")) { /* * Catch priority inversion and punt, but only if it's guaranteed * that we were called through m1 and the other guy is not an array @@ -268,18 +324,21 @@ PyArray_GenericInplaceUnaryFunction(PyArrayObject *m1, PyObject *op) static PyObject * array_add(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__add__", "__radd__", 0); return PyArray_GenericBinaryFunction(m1, m2, n_ops.add); } static PyObject * array_subtract(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__sub__", "__rsub__", 0); return PyArray_GenericBinaryFunction(m1, m2, n_ops.subtract); } static PyObject * array_multiply(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__mul__", "__rmul__", 0); return PyArray_GenericBinaryFunction(m1, m2, n_ops.multiply); } @@ -287,6 +346,7 @@ array_multiply(PyArrayObject *m1, PyObject *m2) static PyObject * array_divide(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__div__", "__rdiv__", 0); return PyArray_GenericBinaryFunction(m1, m2, n_ops.divide); } #endif @@ -294,6 +354,7 @@ array_divide(PyArrayObject *m1, PyObject *m2) static PyObject * array_remainder(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__mod__", "__rmod__", 0); return PyArray_GenericBinaryFunction(m1, m2, n_ops.remainder); } @@ -448,6 +509,7 @@ array_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo)) { /* modulo is ignored! */ PyObject *value; + GIVE_UP_IF_HAS_RIGHT_BINOP(a1, o2, "__pow__", "__rpow__", 0); value = fast_scalar_power(a1, o2, 0); if (!value) { value = PyArray_GenericBinaryFunction(a1, o2, n_ops.power); @@ -477,48 +539,56 @@ array_invert(PyArrayObject *m1) static PyObject * array_left_shift(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__lshift__", "__rlshift__", 0); return PyArray_GenericBinaryFunction(m1, m2, n_ops.left_shift); } static PyObject * array_right_shift(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__rshift__", "__rrshift__", 0); return PyArray_GenericBinaryFunction(m1, m2, n_ops.right_shift); } static PyObject * array_bitwise_and(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__and__", "__rand__", 0); return PyArray_GenericBinaryFunction(m1, m2, n_ops.bitwise_and); } static PyObject * array_bitwise_or(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__or__", "__ror__", 0); return PyArray_GenericBinaryFunction(m1, m2, n_ops.bitwise_or); } static PyObject * array_bitwise_xor(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__xor__", "__rxor__", 0); return PyArray_GenericBinaryFunction(m1, m2, n_ops.bitwise_xor); } static PyObject * array_inplace_add(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__iadd__", "__radd__", 1); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.add); } static PyObject * array_inplace_subtract(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__isub__", "__rsub__", 1); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.subtract); } static PyObject * array_inplace_multiply(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__imul__", "__rmul__", 1); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.multiply); } @@ -526,6 +596,7 @@ array_inplace_multiply(PyArrayObject *m1, PyObject *m2) static PyObject * array_inplace_divide(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__idiv__", "__rdiv__", 1); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.divide); } #endif @@ -533,6 +604,7 @@ array_inplace_divide(PyArrayObject *m1, PyObject *m2) static PyObject * array_inplace_remainder(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__imod__", "__rmod__", 1); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.remainder); } @@ -541,6 +613,7 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo { /* modulo is ignored! */ PyObject *value; + GIVE_UP_IF_HAS_RIGHT_BINOP(a1, o2, "__ipow__", "__rpow__", 1); value = fast_scalar_power(a1, o2, 1); if (!value) { value = PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power); @@ -551,48 +624,56 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo static PyObject * array_inplace_left_shift(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ilshift__", "__rlshift__", 1); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.left_shift); } static PyObject * array_inplace_right_shift(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__irshift__", "__rrshift__", 1); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.right_shift); } static PyObject * array_inplace_bitwise_and(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__iand__", "__rand__", 1); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_and); } static PyObject * array_inplace_bitwise_or(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ior__", "__ror__", 1); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_or); } static PyObject * array_inplace_bitwise_xor(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ixor__", "__rxor__", 1); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_xor); } static PyObject * array_floor_divide(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__floordiv__", "__rfloordiv__", 0); return PyArray_GenericBinaryFunction(m1, m2, n_ops.floor_divide); } static PyObject * array_true_divide(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__truediv__", "__rtruediv__", 0); return PyArray_GenericBinaryFunction(m1, m2, n_ops.true_divide); } static PyObject * array_inplace_floor_divide(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__ifloordiv__", "__rfloordiv__", 1); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.floor_divide); } @@ -600,6 +681,7 @@ array_inplace_floor_divide(PyArrayObject *m1, PyObject *m2) static PyObject * array_inplace_true_divide(PyArrayObject *m1, PyObject *m2) { + GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__itruediv__", "__rtruediv__", 1); return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.true_divide); } @@ -631,6 +713,7 @@ static PyObject * array_divmod(PyArrayObject *op1, PyObject *op2) { PyObject *divp, *modp, *result; + GIVE_UP_IF_HAS_RIGHT_BINOP(op1, op2, "__divmod__", "__rdivmod__", 0); divp = array_floor_divide(op1, op2); if (divp == NULL) { diff --git a/numpy/core/src/multiarray/number.h b/numpy/core/src/multiarray/number.h index 0018b7348..63ea40696 100644 --- a/numpy/core/src/multiarray/number.h +++ b/numpy/core/src/multiarray/number.h @@ -69,4 +69,8 @@ NPY_NO_EXPORT PyObject * PyArray_GenericAccumulateFunction(PyArrayObject *m1, PyObject *op, int axis, int rtype, PyArrayObject *out); +NPY_NO_EXPORT int +needs_right_binop_forward(PyObject *self, PyObject *other, char *right_name, + int is_inplace); + #endif diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index c41c56b9c..cb02c3666 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -4,12 +4,12 @@ import tempfile import sys import os import warnings +import operator if sys.version_info[0] >= 3: import builtins else: import __builtin__ as builtins - import numpy as np from nose import SkipTest from numpy.core import * @@ -1460,7 +1460,7 @@ class TestMethods(TestCase): class A(object): def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs): return "A" - + class B(object): def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs): return NotImplemented @@ -1547,6 +1547,185 @@ class TestMethods(TestCase): assert_equal(a.ravel(order='K'), [2, 3, 0, 1]) assert_(a.ravel(order='K').flags.owndata) + +class TestBinop(object): + def test_ufunc_override_rop_precedence(self): + # Check that __rmul__ and other right-hand operations have + # precedence over __numpy_ufunc__ + + ops = { + '__add__': ('__radd__', np.add, True), + '__sub__': ('__rsub__', np.subtract, True), + '__mul__': ('__rmul__', np.multiply, True), + '__truediv__': ('__rtruediv__', np.true_divide, True), + '__floordiv__': ('__rfloordiv__', np.floor_divide, True), + '__mod__': ('__rmod__', np.remainder, True), + '__divmod__': ('__rdivmod__', None, False), + '__pow__': ('__rpow__', np.power, True), + '__lshift__': ('__rlshift__', np.left_shift, True), + '__rshift__': ('__rrshift__', np.right_shift, True), + '__and__': ('__rand__', np.bitwise_and, True), + '__xor__': ('__rxor__', np.bitwise_xor, True), + '__or__': ('__ror__', np.bitwise_or, True), + '__ge__': ('__le__', np.less_equal, False), + '__gt__': ('__lt__', np.less, False), + '__le__': ('__ge__', np.greater_equal, False), + '__lt__': ('__gt__', np.greater, False), + '__eq__': ('__eq__', np.equal, False), + '__ne__': ('__ne__', np.not_equal, False), + } + + class OtherNdarraySubclass(ndarray): + pass + + class OtherNdarraySubclassWithOverride(ndarray): + def __numpy_ufunc__(self, *a, **kw): + raise AssertionError(("__numpy_ufunc__ %r %r shouldn't have " + "been called!") % (a, kw)) + + def check(op_name, ndsubclass): + rop_name, np_op, has_iop = ops[op_name] + + if has_iop: + iop_name = '__i' + op_name[2:] + iop = getattr(operator, iop_name) + + if op_name == "__divmod__": + op = divmod + else: + op = getattr(operator, op_name) + + # Dummy class + def __init__(self, *a, **kw): + pass + + def __numpy_ufunc__(self, *a, **kw): + raise AssertionError(("__numpy_ufunc__ %r %r shouldn't have " + "been called!") % (a, kw)) + + def __op__(self, *other): + return "op" + + def __rop__(self, *other): + return "rop" + + if ndsubclass: + bases = (ndarray,) + else: + bases = (object,) + + dct = {'__init__': __init__, + '__numpy_ufunc__': __numpy_ufunc__, + op_name: __op__} + if op_name != rop_name: + dct[rop_name] = __rop__ + + cls = type("Rop" + rop_name, bases, dct) + + # Check behavior against both bare ndarray objects and a + # ndarray subclasses with and without their own override + obj = cls((1,)) + + arr_objs = [np.array([1]), + np.array([2]).view(OtherNdarraySubclass), + np.array([3]).view(OtherNdarraySubclassWithOverride), + ] + + for arr in arr_objs: + err_msg = "%r %r" % (op_name, arr,) + + # Check that ndarray op gives up if it sees a non-subclass + if not isinstance(obj, arr.__class__): + assert_equal(getattr(arr, op_name)(obj), + NotImplemented, err_msg=err_msg) + + # Check that the Python binops have priority + assert_equal(op(obj, arr), "op", err_msg=err_msg) + if op_name == rop_name: + assert_equal(op(arr, obj), "op", err_msg=err_msg) + else: + assert_equal(op(arr, obj), "rop", err_msg=err_msg) + + # Check that Python binops have priority also for in-place ops + if has_iop: + assert_equal(getattr(arr, iop_name)(obj), + NotImplemented, err_msg=err_msg) + if op_name != "__pow__": + # inplace pow requires the other object to be + # integer-like? + assert_equal(iop(arr, obj), "rop", err_msg=err_msg) + + # Check that ufunc call __numpy_ufunc__ normally + if np_op is not None: + assert_raises(AssertionError, np_op, arr, obj, + err_msg=err_msg) + assert_raises(AssertionError, np_op, obj, arr, + err_msg=err_msg) + + # Check all binary operations + for op_name in sorted(ops.keys()): + yield check, op_name, True + yield check, op_name, False + + def test_ufunc_override_rop_simple(self): + # Check parts of the binary op overriding behavior in an + # explicit test case that is easier to understand. + + class SomeClass(object): + def __numpy_ufunc__(self, *a, **kw): + return "ufunc" + def __mul__(self, other): + return 123 + def __rmul__(self, other): + return 321 + def __gt__(self, other): + return "yep" + def __lt__(self, other): + return "nope" + + class SomeClass2(SomeClass, ndarray): + def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw): + if ufunc is np.multiply: + return "ufunc" + else: + inputs = list(inputs) + inputs[i] = np.asarray(self) + func = getattr(ufunc, method) + r = func(*inputs, **kw) + if 'out' in kw: + return r + else: + x = SomeClass2(r.shape, dtype=r.dtype) + x[...] = r + return x + + arr = np.array([0]) + obj = SomeClass() + obj2 = SomeClass2((1,), dtype=np.int_) + obj2[0] = 9 + + assert_equal(obj * arr, 123) + assert_equal(arr * obj, 321) + assert_equal(arr > obj, "nope") + assert_equal(arr < obj, "yep") + assert_equal(np.multiply(arr, obj), "ufunc") + arr *= obj + assert_equal(arr, 321) + + assert_equal(obj2 * arr, 123) + assert_equal(arr * obj2, 321) + assert_equal(arr > obj2, "nope") + assert_equal(arr < obj2, "yep") + assert_equal(np.multiply(arr, obj2), "ufunc") + arr *= obj2 + assert_equal(arr, 321) + + obj2 += 33 + assert_equal(obj2[0], 42) + assert_equal(obj2.sum(), 42) + assert_(isinstance(obj2, SomeClass2)) + + class TestSubscripting(TestCase): def test_test_zero_rank(self): x = array([1, 2, 3]) |