summaryrefslogtreecommitdiff
path: root/numpy/ma
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-09-28 01:09:28 -0700
committerEric Wieser <wieser.eric@gmail.com>2017-09-28 01:09:28 -0700
commitdd86461e045557d87a487d1edab9617f0ed4b58f (patch)
treee57248236def7f2a9a4f84c44e04f88bc49d4290 /numpy/ma
parent5289102a3bed26744fc694a819f4b25be85b3170 (diff)
downloadnumpy-dd86461e045557d87a487d1edab9617f0ed4b58f.tar.gz
BUG: Don't ignore mismatching shapes just because the mask is all zeros
Fixes #4520
Diffstat (limited to 'numpy/ma')
-rw-r--r--numpy/ma/core.py4
-rw-r--r--numpy/ma/tests/test_core.py6
2 files changed, 8 insertions, 2 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index be605f0a7..b6e2edf5a 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -1920,7 +1920,7 @@ def masked_where(condition, a, copy=True):
"""
# Make sure that condition is a valid standard-type mask.
- cond = make_mask(condition)
+ cond = make_mask(condition, shrink=False)
a = np.array(a, copy=copy, subok=True)
(cshape, ashape) = (cond.shape, a.shape)
@@ -1934,7 +1934,7 @@ def masked_where(condition, a, copy=True):
cls = MaskedArray
result = a.view(cls)
# Assign to *.mask so that structured masks are handled correctly.
- result.mask = cond
+ result.mask = _shrink_mask(cond)
return result
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 654601cee..8d4284140 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -3844,6 +3844,12 @@ class TestMaskedArrayFunctions(object):
assert_equal(am["A"],
np.ma.masked_array(np.zeros(10), np.ones(10)))
+ def test_masked_where_mismatch(self):
+ # gh-4520
+ x = np.arange(10)
+ y = np.arange(5)
+ assert_raises(IndexError, np.ma.masked_where, y > 6, x)
+
def test_masked_otherfunctions(self):
assert_equal(masked_inside(list(range(5)), 1, 3),
[0, 199, 199, 199, 4])