diff options
author | Allan Haldane <ealloc@gmail.com> | 2017-09-28 23:35:35 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-09-28 23:35:35 +0200 |
commit | 767744c86e1e4a27bf77f90a85f17e564c9b94ea (patch) | |
tree | 24c460c6daa8dbdcc21e099c1a534debdb4737e9 | |
parent | 71d88205a2c3753e24b98813452d144adba56441 (diff) | |
parent | dd86461e045557d87a487d1edab9617f0ed4b58f (diff) | |
download | numpy-767744c86e1e4a27bf77f90a85f17e564c9b94ea.tar.gz |
Merge pull request #9785 from eric-wieser/fix-masked_where
BUG: Fix size-checking in masked_where, and structured shrink_mask
-rw-r--r-- | numpy/ma/core.py | 35 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 12 |
2 files changed, 32 insertions, 15 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 67a813bf7..b6e2edf5a 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -1537,6 +1537,16 @@ def is_mask(m): return False +def _shrink_mask(m): + """ + Shrink a mask to nomask if possible + """ + if not m.dtype.names and not m.any(): + return nomask + else: + return m + + def make_mask(m, copy=False, shrink=True, dtype=MaskType): """ Create a boolean mask from an array. @@ -1621,10 +1631,9 @@ def make_mask(m, copy=False, shrink=True, dtype=MaskType): # Fill the mask in case there are missing data; turn it into an ndarray. result = np.array(filled(m, True), copy=copy, dtype=dtype, subok=True) # Bas les masques ! - if shrink and (not result.dtype.names) and (not result.any()): - return nomask - else: - return result + if shrink: + result = _shrink_mask(result) + return result def make_mask_none(newshape, dtype=None): @@ -1911,7 +1920,7 @@ def masked_where(condition, a, copy=True): """ # Make sure that condition is a valid standard-type mask. - cond = make_mask(condition) + cond = make_mask(condition, shrink=False) a = np.array(a, copy=copy, subok=True) (cshape, ashape) = (cond.shape, a.shape) @@ -1925,7 +1934,7 @@ def masked_where(condition, a, copy=True): cls = MaskedArray result = a.view(cls) # Assign to *.mask so that structured masks are handled correctly. - result.mask = cond + result.mask = _shrink_mask(cond) return result @@ -3544,9 +3553,7 @@ class MaskedArray(ndarray): False """ - m = self._mask - if m.ndim and not m.any(): - self._mask = nomask + self._mask = _shrink_mask(self._mask) return self baseclass = property(fget=lambda self: self._baseclass, @@ -6659,12 +6666,11 @@ def concatenate(arrays, axis=0): return data # OK, so we have to concatenate the masks dm = np.concatenate([getmaskarray(a) for a in arrays], axis) + dm = dm.reshape(d.shape) + # If we decide to keep a '_shrinkmask' option, we want to check that # all of them are True, and then check for dm.any() - if not dm.dtype.fields and not dm.any(): - data._mask = nomask - else: - data._mask = dm.reshape(d.shape) + data._mask = _shrink_mask(dm) return data @@ -7082,8 +7088,7 @@ def where(condition, x=_NoValue, y=_NoValue): mask = np.where(cm, np.ones((), dtype=mask.dtype), mask) # collapse the mask, for backwards compatibility - if mask.dtype == np.bool_ and not mask.any(): - mask = nomask + mask = _shrink_mask(mask) return masked_array(data, mask=mask) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 4dd18182c..8d4284140 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1642,6 +1642,12 @@ class TestMaskedArrayAttributes(object): assert_equal(a, b) assert_equal(a.mask, nomask) + # Mask cannot be shrunk on structured types, so is a no-op + a = np.ma.array([(1, 2.0)], [('a', int), ('b', float)]) + b = a.copy() + a.shrink_mask() + assert_equal(a.mask, b.mask) + def test_flat(self): # Test that flat can return all types of items [#4585, #4615] # test simple access @@ -3838,6 +3844,12 @@ class TestMaskedArrayFunctions(object): assert_equal(am["A"], np.ma.masked_array(np.zeros(10), np.ones(10))) + def test_masked_where_mismatch(self): + # gh-4520 + x = np.arange(10) + y = np.arange(5) + assert_raises(IndexError, np.ma.masked_where, y > 6, x) + def test_masked_otherfunctions(self): assert_equal(masked_inside(list(range(5)), 1, 3), [0, 199, 199, 199, 4]) |