diff options
-rw-r--r-- | numpy/core/src/multiarray/number.c | 4 | ||||
-rw-r--r-- | numpy/ma/core.py | 24 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 18 |
3 files changed, 37 insertions, 9 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index a2160afd8..3819f510b 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -303,6 +303,10 @@ PyArray_GenericBinaryFunction(PyArrayObject *m1, PyObject *m2, PyObject *op) * See also: * - https://github.com/numpy/numpy/issues/3502 * - https://github.com/numpy/numpy/issues/3503 + * + * NB: there's another copy of this code in + * numpy.ma.core.MaskedArray._delegate_binop + * which should possibly be updated when this is. */ double m1_prio = PyArray_GetPriority((PyObject *)m1, NPY_SCALAR_PRIORITY); diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 5df928a6d..877d07e02 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -3679,6 +3679,16 @@ class MaskedArray(ndarray): return _print_templates['short_std'] % parameters return _print_templates['long_std'] % parameters + def _delegate_binop(self, other): + # This emulates the logic in + # multiarray/number.c:PyArray_GenericBinaryFunction + if (not isinstance(other, np.ndarray) + and not hasattr(other, "__numpy_ufunc__")): + other_priority = getattr(other, "__array_priority__", -1000000) + if self.__array_priority__ < other_priority: + return True + return False + def __eq__(self, other): "Check whether other equals self elementwise" if self is masked: @@ -3747,6 +3757,8 @@ class MaskedArray(ndarray): # def __add__(self, other): "Add other to self, and return a new masked array." + if self._delegate_binop(other): + return NotImplemented return add(self, other) # def __radd__(self, other): @@ -3755,6 +3767,8 @@ class MaskedArray(ndarray): # def __sub__(self, other): "Subtract other to self, and return a new masked array." + if self._delegate_binop(other): + return NotImplemented return subtract(self, other) # def __rsub__(self, other): @@ -3763,6 +3777,8 @@ class MaskedArray(ndarray): # def __mul__(self, other): "Multiply other by self, and return a new masked array." + if self._delegate_binop(other): + return NotImplemented return multiply(self, other) # def __rmul__(self, other): @@ -3771,10 +3787,14 @@ class MaskedArray(ndarray): # def __div__(self, other): "Divide other into self, and return a new masked array." + if self._delegate_binop(other): + return NotImplemented return divide(self, other) # def __truediv__(self, other): "Divide other into self, and return a new masked array." + if self._delegate_binop(other): + return NotImplemented return true_divide(self, other) # def __rtruediv__(self, other): @@ -3783,6 +3803,8 @@ class MaskedArray(ndarray): # def __floordiv__(self, other): "Divide other into self, and return a new masked array." + if self._delegate_binop(other): + return NotImplemented return floor_divide(self, other) # def __rfloordiv__(self, other): @@ -3791,6 +3813,8 @@ class MaskedArray(ndarray): # def __pow__(self, other): "Raise self to the power other, masking the potential NaNs/Infs" + if self._delegate_binop(other): + return NotImplemented return power(self, other) # def __rpow__(self, other): diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 6df235e9f..909c27c90 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -12,6 +12,7 @@ import warnings import sys import pickle from functools import reduce +import operator from nose.tools import assert_raises @@ -1734,16 +1735,15 @@ class TestUfuncs(TestCase): self.assertTrue(not isinstance(test.mask, MaskedArray)) def test_treatment_of_NotImplemented(self): - # Check any NotImplemented returned by umath.<ufunc> is passed on + # Check that NotImplemented is returned at appropriate places + a = masked_array([1., 2.], mask=[1, 0]) - # basic tests for _MaskedBinaryOperation - assert_(a.__mul__('abc') is NotImplemented) - assert_(multiply.outer(a, 'abc') is NotImplemented) - # and for _DomainedBinaryOperation - assert_(a.__div__('abc') is NotImplemented) - - # also check explicitly that rmul of another class can be accessed - class MyClass(str): + self.assertRaises(TypeError, operator.mul, a, "abc") + self.assertRaises(TypeError, operator.truediv, a, "abc") + + class MyClass(object): + __array_priority__ = a.__array_priority__ + 1 + def __mul__(self, other): return "My mul" |