summaryrefslogtreecommitdiff
path: root/numpy/ma
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2010-06-29 17:57:15 +0000
committerpierregm <pierregm@localhost>2010-06-29 17:57:15 +0000
commit0a95b13f36932bafb8bbb0d7c32b0895f50f562e (patch)
treeccdda1d12c0ce774d23ec2f9b5cc8ef6074f52f1 /numpy/ma
parent9b86617974fee246df0ebba859102897f531f659 (diff)
downloadnumpy-0a95b13f36932bafb8bbb0d7c32b0895f50f562e.tar.gz
Fixed __eq__/__ne__ for scalars
Diffstat (limited to 'numpy/ma')
-rw-r--r--numpy/ma/core.py18
-rw-r--r--numpy/ma/tests/test_core.py15
2 files changed, 29 insertions, 4 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index be07a12e3..de7485638 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -3584,8 +3584,13 @@ class MaskedArray(ndarray):
return masked
omask = getattr(other, '_mask', nomask)
if omask is nomask:
- check = ndarray.__eq__(self.filled(0), other).view(type(self))
- check._mask = self._mask
+ check = ndarray.__eq__(self.filled(0), other)
+ try:
+ check = check.view(type(self))
+ check._mask = self._mask
+ except AttributeError:
+ # Dang, we have a bool instead of an array: return the bool
+ return check
else:
odata = filled(other, 0)
check = ndarray.__eq__(self.filled(0), odata).view(type(self))
@@ -3612,8 +3617,13 @@ class MaskedArray(ndarray):
return masked
omask = getattr(other, '_mask', nomask)
if omask is nomask:
- check = ndarray.__ne__(self.filled(0), other).view(type(self))
- check._mask = self._mask
+ check = ndarray.__ne__(self.filled(0), other)
+ try:
+ check = check.view(type(self))
+ check._mask = self._mask
+ except AttributeError:
+ # In case check is a boolean (or a numpy.bool)
+ return check
else:
odata = filled(other, 0)
check = ndarray.__ne__(self.filled(0), odata).view(type(self))
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index f95d621db..908d7adc6 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1149,6 +1149,21 @@ class TestMaskedArrayArithmetic(TestCase):
assert_equal(test.mask, [False, False])
+ def test_eq_w_None(self):
+ a = array([1, 2], mask=False)
+ assert_equal(a == None, False)
+ assert_equal(a != None, True)
+ a = masked
+ assert_equal(a == None, masked)
+
+ def test_eq_w_scalar(self):
+ a = array(1)
+ assert_equal(a == 1, True)
+ assert_equal(a == 0, False)
+ assert_equal(a != 1, False)
+ assert_equal(a != 0, True)
+
+
def test_numpyarithmetics(self):
"Check that the mask is not back-propagated when using numpy functions"
a = masked_array([-1, 0, 1, 2, 3], mask=[0, 0, 0, 0, 1])