summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAllan Haldane <ealloc@gmail.com>2017-09-28 23:35:35 +0200
committerGitHub <noreply@github.com>2017-09-28 23:35:35 +0200
commit767744c86e1e4a27bf77f90a85f17e564c9b94ea (patch)
tree24c460c6daa8dbdcc21e099c1a534debdb4737e9
parent71d88205a2c3753e24b98813452d144adba56441 (diff)
parentdd86461e045557d87a487d1edab9617f0ed4b58f (diff)
downloadnumpy-767744c86e1e4a27bf77f90a85f17e564c9b94ea.tar.gz
Merge pull request #9785 from eric-wieser/fix-masked_where
BUG: Fix size-checking in masked_where, and structured shrink_mask
-rw-r--r--numpy/ma/core.py35
-rw-r--r--numpy/ma/tests/test_core.py12
2 files changed, 32 insertions, 15 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 67a813bf7..b6e2edf5a 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -1537,6 +1537,16 @@ def is_mask(m):
return False
+def _shrink_mask(m):
+ """
+ Shrink a mask to nomask if possible
+ """
+ if not m.dtype.names and not m.any():
+ return nomask
+ else:
+ return m
+
+
def make_mask(m, copy=False, shrink=True, dtype=MaskType):
"""
Create a boolean mask from an array.
@@ -1621,10 +1631,9 @@ def make_mask(m, copy=False, shrink=True, dtype=MaskType):
# Fill the mask in case there are missing data; turn it into an ndarray.
result = np.array(filled(m, True), copy=copy, dtype=dtype, subok=True)
# Bas les masques !
- if shrink and (not result.dtype.names) and (not result.any()):
- return nomask
- else:
- return result
+ if shrink:
+ result = _shrink_mask(result)
+ return result
def make_mask_none(newshape, dtype=None):
@@ -1911,7 +1920,7 @@ def masked_where(condition, a, copy=True):
"""
# Make sure that condition is a valid standard-type mask.
- cond = make_mask(condition)
+ cond = make_mask(condition, shrink=False)
a = np.array(a, copy=copy, subok=True)
(cshape, ashape) = (cond.shape, a.shape)
@@ -1925,7 +1934,7 @@ def masked_where(condition, a, copy=True):
cls = MaskedArray
result = a.view(cls)
# Assign to *.mask so that structured masks are handled correctly.
- result.mask = cond
+ result.mask = _shrink_mask(cond)
return result
@@ -3544,9 +3553,7 @@ class MaskedArray(ndarray):
False
"""
- m = self._mask
- if m.ndim and not m.any():
- self._mask = nomask
+ self._mask = _shrink_mask(self._mask)
return self
baseclass = property(fget=lambda self: self._baseclass,
@@ -6659,12 +6666,11 @@ def concatenate(arrays, axis=0):
return data
# OK, so we have to concatenate the masks
dm = np.concatenate([getmaskarray(a) for a in arrays], axis)
+ dm = dm.reshape(d.shape)
+
# If we decide to keep a '_shrinkmask' option, we want to check that
# all of them are True, and then check for dm.any()
- if not dm.dtype.fields and not dm.any():
- data._mask = nomask
- else:
- data._mask = dm.reshape(d.shape)
+ data._mask = _shrink_mask(dm)
return data
@@ -7082,8 +7088,7 @@ def where(condition, x=_NoValue, y=_NoValue):
mask = np.where(cm, np.ones((), dtype=mask.dtype), mask)
# collapse the mask, for backwards compatibility
- if mask.dtype == np.bool_ and not mask.any():
- mask = nomask
+ mask = _shrink_mask(mask)
return masked_array(data, mask=mask)
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 4dd18182c..8d4284140 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1642,6 +1642,12 @@ class TestMaskedArrayAttributes(object):
assert_equal(a, b)
assert_equal(a.mask, nomask)
+ # Mask cannot be shrunk on structured types, so is a no-op
+ a = np.ma.array([(1, 2.0)], [('a', int), ('b', float)])
+ b = a.copy()
+ a.shrink_mask()
+ assert_equal(a.mask, b.mask)
+
def test_flat(self):
# Test that flat can return all types of items [#4585, #4615]
# test simple access
@@ -3838,6 +3844,12 @@ class TestMaskedArrayFunctions(object):
assert_equal(am["A"],
np.ma.masked_array(np.zeros(10), np.ones(10)))
+ def test_masked_where_mismatch(self):
+ # gh-4520
+ x = np.arange(10)
+ y = np.arange(5)
+ assert_raises(IndexError, np.ma.masked_where, y > 6, x)
+
def test_masked_otherfunctions(self):
assert_equal(masked_inside(list(range(5)), 1, 3),
[0, 199, 199, 199, 4])