From 31b18be11e48894738aa58620ac2d2307c8b4932 Mon Sep 17 00:00:00 2001 From: pierregm Date: Tue, 12 Aug 2008 21:12:14 +0000 Subject: * masked_where : force a consistency check on the shapes of the inputs --- numpy/ma/core.py | 5 +++++ numpy/ma/tests/test_core.py | 11 +++++++++++ 2 files changed, 16 insertions(+) (limited to 'numpy') diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 0b5372fa3..a31cbef1b 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -884,6 +884,11 @@ def masked_where(condition, a, copy=True): """ cond = make_mask(condition) a = np.array(a, copy=copy, subok=True) + + (cshape, ashape) = (cond.shape, a.shape) + if cshape and cshape != ashape: + raise IndexError("Inconsistant shape between the condition and the input"\ + " (got %s and %s)" % (cshape, ashape)) if hasattr(a, '_mask'): cond = mask_or(cond, a._mask) cls = type(a) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 155d0bf73..89abde60a 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1969,6 +1969,17 @@ class TestMaskedArrayFunctions(TestCase): ctest = masked_where(btest,atest) assert_equal(atest,ctest) + def test_masked_where_shape_constraint(self): + a = arange(10) + try: + test = masked_equal(1, a) + except IndexError: + pass + else: + raise AssertionError("Should have failed...") + test = masked_equal(a,1) + assert(test.mask, [0,1,0,0,0,0,0,0,0,0]) + def test_masked_otherfunctions(self): assert_equal(masked_inside(range(5), 1, 3), [0, 199, 199, 199, 4]) -- cgit v1.2.1