diff options
Diffstat (limited to 'numpy/ma/tests')
-rw-r--r-- | numpy/ma/tests/test_core.py | 35 | ||||
-rw-r--r-- | numpy/ma/tests/test_subclassing.py | 40 |
2 files changed, 65 insertions, 10 deletions
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 909c27c90..2bf782724 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -685,6 +685,18 @@ class TestMaskedArray(TestCase): finally: masked_print_option.set_display(ini_display) + def test_object_with_array(self): + mx1 = masked_array([1.], mask=[True]) + mx2 = masked_array([1., 2.]) + mx = masked_array([mx1, mx2], mask=[False, True]) + assert mx[0] is mx1 + assert mx[1] is not mx2 + assert np.all(mx[1].data == mx2.data) + assert np.all(mx[1].mask) + # check that we return a view. + mx[1].data[0] = 0. + assert mx2[0] == 0. + #------------------------------------------------------------------------------ class TestMaskedArrayArithmetic(TestCase): @@ -1754,6 +1766,29 @@ class TestUfuncs(TestCase): assert_(me * a == "My mul") assert_(a * me == "My rmul") + # and that __array_priority__ is respected + class MyClass2(object): + __array_priority__ = 100 + + def __mul__(self, other): + return "Me2mul" + + def __rmul__(self, other): + return "Me2rmul" + + def __rdiv__(self, other): + return "Me2rdiv" + + __rtruediv__ = __rdiv__ + + me_too = MyClass2() + assert_(a.__mul__(me_too) is NotImplemented) + assert_(all(multiply.outer(a, me_too) == "Me2rmul")) + assert_(a.__truediv__(me_too) is NotImplemented) + assert_(me_too * a == "Me2mul") + assert_(a * me_too == "Me2rmul") + assert_(a / me_too == "Me2rdiv") + #------------------------------------------------------------------------------ class TestMaskedArrayInPlaceArithmetics(TestCase): diff --git a/numpy/ma/tests/test_subclassing.py b/numpy/ma/tests/test_subclassing.py index 07fc8fdd6..2e98348e6 100644 --- a/numpy/ma/tests/test_subclassing.py +++ b/numpy/ma/tests/test_subclassing.py @@ -24,18 +24,27 @@ class SubArray(np.ndarray): # in the dictionary `info`. def __new__(cls,arr,info={}): x = np.asanyarray(arr).view(cls) - x.info = info + x.info = info.copy() return x def __array_finalize__(self, obj): - self.info = getattr(obj, 'info', {}) + if callable(getattr(super(SubArray, self), + '__array_finalize__', None)): + super(SubArray, self).__array_finalize__(obj) + self.info = getattr(obj, 'info', {}).copy() return def __add__(self, other): - result = np.ndarray.__add__(self, other) - result.info.update({'added':result.info.pop('added', 0)+1}) + result = super(SubArray, self).__add__(other) + result.info['added'] = result.info.get('added', 0) + 1 return result + def __iadd__(self, other): + result = super(SubArray, self).__iadd__(other) + result.info['iadded'] = result.info.get('iadded', 0) + 1 + return result + + subarray = SubArray @@ -47,11 +56,6 @@ class MSubArray(SubArray, MaskedArray): _data.info = subarr.info return _data - def __array_finalize__(self, obj): - MaskedArray.__array_finalize__(self, obj) - SubArray.__array_finalize__(self, obj) - return - def _get_series(self): _view = self.view(MaskedArray) _view._sharedmask = False @@ -82,8 +86,11 @@ class MMatrix(MaskedArray, np.matrix,): mmatrix = MMatrix -# also a subclass that overrides __str__, __repr__ and __setitem__, disallowing +# Also a subclass that overrides __str__, __repr__ and __setitem__, disallowing # setting to non-class values (and thus np.ma.core.masked_print_option) +# and overrides __array_wrap__, updating the info dict, to check that this +# doesn't get destroyed by MaskedArray._update_from. But this one also needs +# its own iterator... class CSAIterator(object): """ Flat iterator object that uses its own setter/getter @@ -150,6 +157,13 @@ class ComplicatedSubArray(SubArray): y = self.ravel() y[:] = value + def __array_wrap__(self, obj, context=None): + obj = super(ComplicatedSubArray, self).__array_wrap__(obj, context) + if context is not None and context[0] is np.multiply: + obj.info['multiplied'] = obj.info.get('multiplied', 0) + 1 + + return obj + class TestSubclassing(TestCase): # Test suite for masked subclasses of ndarray. @@ -218,6 +232,12 @@ class TestSubclassing(TestCase): self.assertTrue(isinstance(z, MSubArray)) self.assertTrue(isinstance(z._data, SubArray)) self.assertTrue(z._data.info['added'] > 0) + # Test that inplace methods from data get used (gh-4617) + ym += 1 + self.assertTrue(isinstance(ym, MaskedArray)) + self.assertTrue(isinstance(ym, MSubArray)) + self.assertTrue(isinstance(ym._data, SubArray)) + self.assertTrue(ym._data.info['iadded'] > 0) # ym._set_mask([1, 0, 0, 0, 1]) assert_equal(ym._mask, [1, 0, 0, 0, 1]) |