summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/tests/test_subclassing.py59
1 files changed, 59 insertions, 0 deletions
diff --git a/numpy/ma/tests/test_subclassing.py b/numpy/ma/tests/test_subclassing.py
index 3491cef7f..91b4241b5 100644
--- a/numpy/ma/tests/test_subclassing.py
+++ b/numpy/ma/tests/test_subclassing.py
@@ -7,6 +7,7 @@
"""
import numpy as np
+from numpy.lib.mixins import NDArrayOperatorsMixin
from numpy.testing import assert_, assert_raises
from numpy.ma.testutils import assert_equal
from numpy.ma.core import (
@@ -147,6 +148,33 @@ class ComplicatedSubArray(SubArray):
return obj
+class WrappedArray(NDArrayOperatorsMixin):
+ """
+ Wrapping a MaskedArray rather than subclassing to test that
+ ufunc deferrals are commutative.
+ See: https://github.com/numpy/numpy/issues/15200)
+ """
+ __array_priority__ = 20
+
+ def __init__(self, array, **attrs):
+ self._array = array
+ self.attrs = attrs
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(\n{self._array}\n{self.attrs}\n)"
+
+ def __array__(self):
+ return np.asarray(self._array)
+
+ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
+ if method == '__call__':
+ inputs = [arg._array if isinstance(arg, self.__class__) else arg
+ for arg in inputs]
+ return self.__class__(ufunc(*inputs, **kwargs), **self.attrs)
+ else:
+ return NotImplemented
+
+
class TestSubclassing:
# Test suite for masked subclasses of ndarray.
@@ -384,3 +412,34 @@ def test_array_no_inheritance():
# Test that the mask is False and not shared when keep_mask=False
assert_(not new_array.mask)
assert_(not new_array.sharedmask)
+
+
+class TestClassWrapping:
+ # Test suite for classes that wrap MaskedArrays
+
+ def setup(self):
+ m = np.ma.masked_array([1, 3, 5], mask=[False, True, False])
+ wm = WrappedArray(m)
+ self.data = (m, wm)
+
+ def test_masked_unary_operations(self):
+ # Tests masked_unary_operation
+ (m, wm) = self.data
+ with np.errstate(divide='ignore'):
+ assert_(isinstance(np.log(wm), WrappedArray))
+
+ def test_masked_binary_operations(self):
+ # Tests masked_binary_operation
+ (m, wm) = self.data
+ # Result should be a WrappedArray
+ assert_(isinstance(np.add(wm, wm), WrappedArray))
+ assert_(isinstance(np.add(m, wm), WrappedArray))
+ assert_(isinstance(np.add(wm, m), WrappedArray))
+ # add and '+' should call the same ufunc
+ assert_equal(np.add(m, wm), m + wm)
+ assert_(isinstance(np.hypot(m, wm), WrappedArray))
+ assert_(isinstance(np.hypot(wm, m), WrappedArray))
+ # Test domained binary operations
+ assert_(isinstance(np.divide(wm, m), WrappedArray))
+ assert_(isinstance(np.divide(m, wm), WrappedArray))
+ assert_equal(np.divide(wm, m) * m, np.divide(m, m) * wm)