summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2009-07-22 18:28:25 +0000
committerpierregm <pierregm@localhost>2009-07-22 18:28:25 +0000
commit09636b1ffd3b3ad55c6b631e609c156563b20d8f (patch)
tree6d67787734e5a9482793a08dbc85f60db732cb36 /numpy
parentd8be0c2ce965d6e34b307cc4295b40fe3aa07859 (diff)
downloadnumpy-09636b1ffd3b3ad55c6b631e609c156563b20d8f.tar.gz
* Fixed __eq__ and __ne__ on masked singleton
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py4
-rw-r--r--numpy/ma/tests/test_core.py10
2 files changed, 14 insertions, 0 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 4b4fdc2b9..5f6a7e260 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -3171,6 +3171,8 @@ class MaskedArray(ndarray):
#............................................
def __eq__(self, other):
"Check whether other equals self elementwise"
+ if self is masked:
+ return masked
omask = getattr(other, '_mask', nomask)
if omask is nomask:
check = ndarray.__eq__(self.filled(0), other).view(type(self))
@@ -3197,6 +3199,8 @@ class MaskedArray(ndarray):
#
def __ne__(self, other):
"Check whether other doesn't equal self elementwise"
+ if self is masked:
+ return masked
omask = getattr(other, '_mask', nomask)
if omask is nomask:
check = ndarray.__ne__(self.filled(0), other).view(type(self))
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index b6222a8cb..ee2af1110 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -668,6 +668,15 @@ class TestMaskedArrayArithmetic(TestCase):
self.failUnless(minimum(xm, xm).mask)
+ def test_masked_singleton_equality(self):
+ "Tests (in)equality on masked snigleton"
+ a = array([1, 2, 3], mask=[1, 1, 0])
+ assert((a[0] == 0) is masked)
+ assert((a[0] != 0) is masked)
+ assert_equal((a[-1] == 0), False)
+ assert_equal((a[-1] != 0), True)
+
+
def test_arithmetic_with_masked_singleton(self):
"Checks that there's no collapsing to masked"
x = masked_array([1,2])
@@ -1067,6 +1076,7 @@ class TestMaskedArrayArithmetic(TestCase):
assert_equal(test.mask, control.mask)
assert_equal(a.mask, [0, 0, 0, 0, 1])
+
#------------------------------------------------------------------------------
class TestMaskedArrayAttributes(TestCase):