summaryrefslogtreecommitdiff
path: root/numpy/ma/tests
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma/tests')
-rw-r--r--numpy/ma/tests/test_core.py43
-rw-r--r--numpy/ma/tests/test_subclassing.py93
2 files changed, 131 insertions, 5 deletions
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index f8a28164e..6df235e9f 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -275,6 +275,49 @@ class TestMaskedArray(TestCase):
assert_equal(s1, s2)
assert_(x1[1:1].shape == (0,))
+ def test_matrix_indexing(self):
+ # Tests conversions and indexing
+ x1 = np.matrix([[1, 2, 3], [4, 3, 2]])
+ x2 = array(x1, mask=[[1, 0, 0], [0, 1, 0]])
+ x3 = array(x1, mask=[[0, 1, 0], [1, 0, 0]])
+ x4 = array(x1)
+ # test conversion to strings
+ junk, garbage = str(x2), repr(x2)
+ # assert_equal(np.sort(x1), sort(x2, endwith=False))
+ # tests of indexing
+ assert_(type(x2[1, 0]) is type(x1[1, 0]))
+ assert_(x1[1, 0] == x2[1, 0])
+ assert_(x2[1, 1] is masked)
+ assert_equal(x1[0, 2], x2[0, 2])
+ assert_equal(x1[0, 1:], x2[0, 1:])
+ assert_equal(x1[:, 2], x2[:, 2])
+ assert_equal(x1[:], x2[:])
+ assert_equal(x1[1:], x3[1:])
+ x1[0, 2] = 9
+ x2[0, 2] = 9
+ assert_equal(x1, x2)
+ x1[0, 1:] = 99
+ x2[0, 1:] = 99
+ assert_equal(x1, x2)
+ x2[0, 1] = masked
+ assert_equal(x1, x2)
+ x2[0, 1:] = masked
+ assert_equal(x1, x2)
+ x2[0, :] = x1[0, :]
+ x2[0, 1] = masked
+ assert_(allequal(getmask(x2), np.array([[0, 1, 0], [0, 1, 0]])))
+ x3[1, :] = masked_array([1, 2, 3], [1, 1, 0])
+ assert_(allequal(getmask(x3)[1], array([1, 1, 0])))
+ assert_(allequal(getmask(x3[1]), array([1, 1, 0])))
+ x4[1, :] = masked_array([1, 2, 3], [1, 1, 0])
+ assert_(allequal(getmask(x4[1]), array([1, 1, 0])))
+ assert_(allequal(x4[1], array([1, 2, 3])))
+ x1 = np.matrix(np.arange(5) * 1.0)
+ x2 = masked_values(x1, 3.0)
+ assert_equal(x1, x2)
+ assert_(allequal(array([0, 0, 0, 1, 0], MaskType), x2.mask))
+ assert_equal(3.0, x2.fill_value)
+
def test_copy(self):
# Tests of some subtle points of copying and sizing.
n = [0, 0, 1, 0, 0]
diff --git a/numpy/ma/tests/test_subclassing.py b/numpy/ma/tests/test_subclassing.py
index ade5c59da..07fc8fdd6 100644
--- a/numpy/ma/tests/test_subclassing.py
+++ b/numpy/ma/tests/test_subclassing.py
@@ -84,20 +84,71 @@ mmatrix = MMatrix
# also a subclass that overrides __str__, __repr__ and __setitem__, disallowing
# setting to non-class values (and thus np.ma.core.masked_print_option)
+class CSAIterator(object):
+ """
+ Flat iterator object that uses its own setter/getter
+ (works around ndarray.flat not propagating subclass setters/getters
+ see https://github.com/numpy/numpy/issues/4564)
+ roughly following MaskedIterator
+ """
+ def __init__(self, a):
+ self._original = a
+ self._dataiter = a.view(np.ndarray).flat
+
+ def __iter__(self):
+ return self
+
+ def __getitem__(self, indx):
+ out = self._dataiter.__getitem__(indx)
+ if not isinstance(out, np.ndarray):
+ out = out.__array__()
+ out = out.view(type(self._original))
+ return out
+
+ def __setitem__(self, index, value):
+ self._dataiter[index] = self._original._validate_input(value)
+
+ def __next__(self):
+ return next(self._dataiter).__array__().view(type(self._original))
+
+ next = __next__
+
+
class ComplicatedSubArray(SubArray):
+
def __str__(self):
- return 'myprefix {0} mypostfix'.format(
- super(ComplicatedSubArray, self).__str__())
+ return 'myprefix {0} mypostfix'.format(self.view(SubArray))
def __repr__(self):
# Return a repr that does not start with 'name('
return '<{0} {1}>'.format(self.__class__.__name__, self)
- def __setitem__(self, item, value):
- # this ensures direct assignment to masked_print_option will fail
+ def _validate_input(self, value):
if not isinstance(value, ComplicatedSubArray):
raise ValueError("Can only set to MySubArray values")
- super(ComplicatedSubArray, self).__setitem__(item, value)
+ return value
+
+ def __setitem__(self, item, value):
+ # validation ensures direct assignment with ndarray or
+ # masked_print_option will fail
+ super(ComplicatedSubArray, self).__setitem__(
+ item, self._validate_input(value))
+
+ def __getitem__(self, item):
+ # ensure getter returns our own class also for scalars
+ value = super(ComplicatedSubArray, self).__getitem__(item)
+ if not isinstance(value, np.ndarray): # scalar
+ value = value.__array__().view(ComplicatedSubArray)
+ return value
+
+ @property
+ def flat(self):
+ return CSAIterator(self)
+
+ @flat.setter
+ def flat(self, value):
+ y = self.ravel()
+ y[:] = value
class TestSubclassing(TestCase):
@@ -205,6 +256,38 @@ class TestSubclassing(TestCase):
assert_equal(mxsub.info, xsub.info)
assert_equal(mxsub._mask, m)
+ def test_subclass_items(self):
+ """test that getter and setter go via baseclass"""
+ x = np.arange(5)
+ xcsub = ComplicatedSubArray(x)
+ mxcsub = masked_array(xcsub, mask=[True, False, True, False, False])
+ # getter should return a ComplicatedSubArray, even for single item
+ # first check we wrote ComplicatedSubArray correctly
+ self.assertTrue(isinstance(xcsub[1], ComplicatedSubArray))
+ self.assertTrue(isinstance(xcsub[1:4], ComplicatedSubArray))
+ # now that it propagates inside the MaskedArray
+ self.assertTrue(isinstance(mxcsub[1], ComplicatedSubArray))
+ self.assertTrue(mxcsub[0] is masked)
+ self.assertTrue(isinstance(mxcsub[1:4].data, ComplicatedSubArray))
+ # also for flattened version (which goes via MaskedIterator)
+ self.assertTrue(isinstance(mxcsub.flat[1].data, ComplicatedSubArray))
+ self.assertTrue(mxcsub[0] is masked)
+ self.assertTrue(isinstance(mxcsub.flat[1:4].base, ComplicatedSubArray))
+
+ # setter should only work with ComplicatedSubArray input
+ # first check we wrote ComplicatedSubArray correctly
+ assert_raises(ValueError, xcsub.__setitem__, 1, x[4])
+ # now that it propagates inside the MaskedArray
+ assert_raises(ValueError, mxcsub.__setitem__, 1, x[4])
+ assert_raises(ValueError, mxcsub.__setitem__, slice(1, 4), x[1:4])
+ mxcsub[1] = xcsub[4]
+ mxcsub[1:4] = xcsub[1:4]
+ # also for flattened version (which goes via MaskedIterator)
+ assert_raises(ValueError, mxcsub.flat.__setitem__, 1, x[4])
+ assert_raises(ValueError, mxcsub.flat.__setitem__, slice(1, 4), x[1:4])
+ mxcsub.flat[1] = xcsub[4]
+ mxcsub.flat[1:4] = xcsub[1:4]
+
def test_subclass_repr(self):
"""test that repr uses the name of the subclass
and 'array' for np.ndarray"""