diff options
author | pierregm <pierregm@localhost> | 2009-07-22 18:28:25 +0000 |
---|---|---|
committer | pierregm <pierregm@localhost> | 2009-07-22 18:28:25 +0000 |
commit | 09636b1ffd3b3ad55c6b631e609c156563b20d8f (patch) | |
tree | 6d67787734e5a9482793a08dbc85f60db732cb36 /numpy | |
parent | d8be0c2ce965d6e34b307cc4295b40fe3aa07859 (diff) | |
download | numpy-09636b1ffd3b3ad55c6b631e609c156563b20d8f.tar.gz |
* Fixed __eq__ and __ne__ on masked singleton
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/ma/core.py | 4 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 10 |
2 files changed, 14 insertions, 0 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 4b4fdc2b9..5f6a7e260 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -3171,6 +3171,8 @@ class MaskedArray(ndarray): #............................................ def __eq__(self, other): "Check whether other equals self elementwise" + if self is masked: + return masked omask = getattr(other, '_mask', nomask) if omask is nomask: check = ndarray.__eq__(self.filled(0), other).view(type(self)) @@ -3197,6 +3199,8 @@ class MaskedArray(ndarray): # def __ne__(self, other): "Check whether other doesn't equal self elementwise" + if self is masked: + return masked omask = getattr(other, '_mask', nomask) if omask is nomask: check = ndarray.__ne__(self.filled(0), other).view(type(self)) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index b6222a8cb..ee2af1110 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -668,6 +668,15 @@ class TestMaskedArrayArithmetic(TestCase): self.failUnless(minimum(xm, xm).mask) + def test_masked_singleton_equality(self): + "Tests (in)equality on masked snigleton" + a = array([1, 2, 3], mask=[1, 1, 0]) + assert((a[0] == 0) is masked) + assert((a[0] != 0) is masked) + assert_equal((a[-1] == 0), False) + assert_equal((a[-1] != 0), True) + + def test_arithmetic_with_masked_singleton(self): "Checks that there's no collapsing to masked" x = masked_array([1,2]) @@ -1067,6 +1076,7 @@ class TestMaskedArrayArithmetic(TestCase): assert_equal(test.mask, control.mask) assert_equal(a.mask, [0, 0, 0, 0, 1]) + #------------------------------------------------------------------------------ class TestMaskedArrayAttributes(TestCase): |