diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-09-28 01:09:28 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-09-28 01:09:28 -0700 |
commit | dd86461e045557d87a487d1edab9617f0ed4b58f (patch) | |
tree | e57248236def7f2a9a4f84c44e04f88bc49d4290 /numpy/ma | |
parent | 5289102a3bed26744fc694a819f4b25be85b3170 (diff) | |
download | numpy-dd86461e045557d87a487d1edab9617f0ed4b58f.tar.gz |
BUG: Don't ignore mismatching shapes just because the mask is all zeros
Fixes #4520
Diffstat (limited to 'numpy/ma')
-rw-r--r-- | numpy/ma/core.py | 4 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 6 |
2 files changed, 8 insertions, 2 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index be605f0a7..b6e2edf5a 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -1920,7 +1920,7 @@ def masked_where(condition, a, copy=True): """ # Make sure that condition is a valid standard-type mask. - cond = make_mask(condition) + cond = make_mask(condition, shrink=False) a = np.array(a, copy=copy, subok=True) (cshape, ashape) = (cond.shape, a.shape) @@ -1934,7 +1934,7 @@ def masked_where(condition, a, copy=True): cls = MaskedArray result = a.view(cls) # Assign to *.mask so that structured masks are handled correctly. - result.mask = cond + result.mask = _shrink_mask(cond) return result diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 654601cee..8d4284140 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -3844,6 +3844,12 @@ class TestMaskedArrayFunctions(object): assert_equal(am["A"], np.ma.masked_array(np.zeros(10), np.ones(10))) + def test_masked_where_mismatch(self): + # gh-4520 + x = np.arange(10) + y = np.arange(5) + assert_raises(IndexError, np.ma.masked_where, y > 6, x) + def test_masked_otherfunctions(self): assert_equal(masked_inside(list(range(5)), 1, 3), [0, 199, 199, 199, 4]) |