diff options
-rw-r--r-- | numpy/ma/core.py | 2 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 8 |
2 files changed, 10 insertions, 0 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 3e12d22e4..fe3c03789 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -2534,6 +2534,8 @@ class MaskedIterator(object): if self.maskiter is not None: _mask = self.maskiter.__getitem__(indx) if isinstance(_mask, ndarray): + # set shape to match that of data; this is needed for matrices + _mask.shape = result.shape result._mask = _mask elif isinstance(_mask, np.void): return mvoid(result, mask=_mask, hardmask=self.ma._hardmask) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 65311313b..4a39103b2 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1305,6 +1305,7 @@ class TestMaskedArrayAttributes(TestCase): assert_equal(a.mask, nomask) def test_flat(self): + # Test that flat can return all types of items [#4585, #4615] # test simple access test = masked_array(np.matrix([[1, 2, 3]]), mask=[0, 0, 1]) assert_equal(test.flat[1], 2) @@ -1349,6 +1350,13 @@ class TestMaskedArrayAttributes(TestCase): if i >= x.shape[-1]: i = 0 j += 1 + # test that matrices keep the correct shape (#4615) + a = masked_array(np.matrix(np.eye(2)), mask=0) + b = a.flat + b01 = b[:2] + assert_equal(b01.data, array([[1., 0.]])) + assert_equal(b01.mask, array([[False, False]])) + #------------------------------------------------------------------------------ class TestFillingValues(TestCase): |