summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/number.c4
-rw-r--r--numpy/ma/core.py24
-rw-r--r--numpy/ma/tests/test_core.py18
3 files changed, 37 insertions, 9 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index a2160afd8..3819f510b 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -303,6 +303,10 @@ PyArray_GenericBinaryFunction(PyArrayObject *m1, PyObject *m2, PyObject *op)
* See also:
* - https://github.com/numpy/numpy/issues/3502
* - https://github.com/numpy/numpy/issues/3503
+ *
+ * NB: there's another copy of this code in
+ * numpy.ma.core.MaskedArray._delegate_binop
+ * which should possibly be updated when this is.
*/
double m1_prio = PyArray_GetPriority((PyObject *)m1,
NPY_SCALAR_PRIORITY);
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 5df928a6d..877d07e02 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -3679,6 +3679,16 @@ class MaskedArray(ndarray):
return _print_templates['short_std'] % parameters
return _print_templates['long_std'] % parameters
+ def _delegate_binop(self, other):
+ # This emulates the logic in
+ # multiarray/number.c:PyArray_GenericBinaryFunction
+ if (not isinstance(other, np.ndarray)
+ and not hasattr(other, "__numpy_ufunc__")):
+ other_priority = getattr(other, "__array_priority__", -1000000)
+ if self.__array_priority__ < other_priority:
+ return True
+ return False
+
def __eq__(self, other):
"Check whether other equals self elementwise"
if self is masked:
@@ -3747,6 +3757,8 @@ class MaskedArray(ndarray):
#
def __add__(self, other):
"Add other to self, and return a new masked array."
+ if self._delegate_binop(other):
+ return NotImplemented
return add(self, other)
#
def __radd__(self, other):
@@ -3755,6 +3767,8 @@ class MaskedArray(ndarray):
#
def __sub__(self, other):
"Subtract other to self, and return a new masked array."
+ if self._delegate_binop(other):
+ return NotImplemented
return subtract(self, other)
#
def __rsub__(self, other):
@@ -3763,6 +3777,8 @@ class MaskedArray(ndarray):
#
def __mul__(self, other):
"Multiply other by self, and return a new masked array."
+ if self._delegate_binop(other):
+ return NotImplemented
return multiply(self, other)
#
def __rmul__(self, other):
@@ -3771,10 +3787,14 @@ class MaskedArray(ndarray):
#
def __div__(self, other):
"Divide other into self, and return a new masked array."
+ if self._delegate_binop(other):
+ return NotImplemented
return divide(self, other)
#
def __truediv__(self, other):
"Divide other into self, and return a new masked array."
+ if self._delegate_binop(other):
+ return NotImplemented
return true_divide(self, other)
#
def __rtruediv__(self, other):
@@ -3783,6 +3803,8 @@ class MaskedArray(ndarray):
#
def __floordiv__(self, other):
"Divide other into self, and return a new masked array."
+ if self._delegate_binop(other):
+ return NotImplemented
return floor_divide(self, other)
#
def __rfloordiv__(self, other):
@@ -3791,6 +3813,8 @@ class MaskedArray(ndarray):
#
def __pow__(self, other):
"Raise self to the power other, masking the potential NaNs/Infs"
+ if self._delegate_binop(other):
+ return NotImplemented
return power(self, other)
#
def __rpow__(self, other):
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 6df235e9f..909c27c90 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -12,6 +12,7 @@ import warnings
import sys
import pickle
from functools import reduce
+import operator
from nose.tools import assert_raises
@@ -1734,16 +1735,15 @@ class TestUfuncs(TestCase):
self.assertTrue(not isinstance(test.mask, MaskedArray))
def test_treatment_of_NotImplemented(self):
- # Check any NotImplemented returned by umath.<ufunc> is passed on
+ # Check that NotImplemented is returned at appropriate places
+
a = masked_array([1., 2.], mask=[1, 0])
- # basic tests for _MaskedBinaryOperation
- assert_(a.__mul__('abc') is NotImplemented)
- assert_(multiply.outer(a, 'abc') is NotImplemented)
- # and for _DomainedBinaryOperation
- assert_(a.__div__('abc') is NotImplemented)
-
- # also check explicitly that rmul of another class can be accessed
- class MyClass(str):
+ self.assertRaises(TypeError, operator.mul, a, "abc")
+ self.assertRaises(TypeError, operator.truediv, a, "abc")
+
+ class MyClass(object):
+ __array_priority__ = a.__array_priority__ + 1
+
def __mul__(self, other):
return "My mul"