summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2008-08-12 21:12:14 +0000
committerpierregm <pierregm@localhost>2008-08-12 21:12:14 +0000
commit31b18be11e48894738aa58620ac2d2307c8b4932 (patch)
tree6fd9381d71e7df2c0403dc437d587e4f732ef0a7
parent96a6fcc6218e2487662706ad646194c8c22b8cc6 (diff)
downloadnumpy-31b18be11e48894738aa58620ac2d2307c8b4932.tar.gz
* masked_where : force a consistency check on the shapes of the inputs
-rw-r--r--numpy/ma/core.py5
-rw-r--r--numpy/ma/tests/test_core.py11
2 files changed, 16 insertions, 0 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 0b5372fa3..a31cbef1b 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -884,6 +884,11 @@ def masked_where(condition, a, copy=True):
"""
cond = make_mask(condition)
a = np.array(a, copy=copy, subok=True)
+
+ (cshape, ashape) = (cond.shape, a.shape)
+ if cshape and cshape != ashape:
+ raise IndexError("Inconsistant shape between the condition and the input"\
+ " (got %s and %s)" % (cshape, ashape))
if hasattr(a, '_mask'):
cond = mask_or(cond, a._mask)
cls = type(a)
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 155d0bf73..89abde60a 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1969,6 +1969,17 @@ class TestMaskedArrayFunctions(TestCase):
ctest = masked_where(btest,atest)
assert_equal(atest,ctest)
+ def test_masked_where_shape_constraint(self):
+ a = arange(10)
+ try:
+ test = masked_equal(1, a)
+ except IndexError:
+ pass
+ else:
+ raise AssertionError("Should have failed...")
+ test = masked_equal(a,1)
+ assert(test.mask, [0,1,0,0,0,0,0,0,0,0])
+
def test_masked_otherfunctions(self):
assert_equal(masked_inside(range(5), 1, 3), [0, 199, 199, 199, 4])