diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/ma/core.py | 3 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 9 |
2 files changed, 11 insertions, 1 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 51e9f0f28..b52dad9ac 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -1818,7 +1818,8 @@ def masked_where(condition, a, copy=True): else: cls = MaskedArray result = a.view(cls) - result._mask = cond + # Assign to *.mask so that structured masks are handled correctly. + result.mask = cond return result diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index f0d5d6788..ea266669e 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -3242,6 +3242,15 @@ class TestMaskedArrayFunctions(TestCase): test = masked_equal(a, 1) assert_equal(test.mask, [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]) + def test_masked_where_structured(self): + # test that masked_where on a structured array sets a structured + # mask (see issue #2972) + a = np.zeros(10, dtype=[("A", "<f2"), ("B", "<f4")]) + am = np.ma.masked_where(a["A"]<5, a) + assert_equal(am.mask.dtype.names, am.dtype.names) + assert_equal(am["A"], + np.ma.masked_array(np.zeros(10), np.ones(10))) + def test_masked_otherfunctions(self): assert_equal(masked_inside(list(range(5)), 1, 3), [0, 199, 199, 199, 4]) |