summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/ma/core.py5
-rw-r--r--numpy/ma/mrecords.py10
-rw-r--r--numpy/ma/tests/test_core.py10
3 files changed, 17 insertions, 8 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index a31cbef1b..dac79a155 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -916,9 +916,8 @@ def masked_less_equal(x, value, copy=True):
def masked_not_equal(x, value, copy=True):
"Shortcut to masked_where, with condition = (x != value)."
- return masked_where((x != value), x, copy=copy)
+ return masked_where(not_equal(x, value), x, copy=copy)
-#
def masked_equal(x, value, copy=True):
"""Shortcut to masked_where, with condition = (x == value). For
floating point, consider `masked_values(x, value)` instead.
@@ -929,7 +928,7 @@ def masked_equal(x, value, copy=True):
# c = umath.equal(d, value)
# m = mask_or(c, getmask(x))
# return array(d, mask=m, copy=copy)
- return masked_where((x == value), x, copy=copy)
+ return masked_where(equal(x, value), x, copy=copy)
def masked_inside(x, v1, v2, copy=True):
"""Shortcut to masked_where, where condition is True for x inside
diff --git a/numpy/ma/mrecords.py b/numpy/ma/mrecords.py
index 5e80695e7..e6735c015 100644
--- a/numpy/ma/mrecords.py
+++ b/numpy/ma/mrecords.py
@@ -158,10 +158,11 @@ class MaskedRecords(MaskedArray, object):
_fieldmask = getattr(obj, '_fieldmask', None)
if _fieldmask is None:
objmask = getattr(obj, '_mask', nomask)
+ _dtype = ndarray.__getattribute__(self,'dtype')
if objmask is nomask:
- _mask = ma.make_mask_none(self.shape, dtype=self.dtype)
+ _mask = ma.make_mask_none(self.shape, dtype=_dtype)
else:
- mdescr = ma.make_mask_descr(self.dtype)
+ mdescr = ma.make_mask_descr(_dtype)
_mask = narray([tuple([m]*len(mdescr)) for m in objmask],
dtype=mdescr).view(recarray)
else:
@@ -232,7 +233,7 @@ class MaskedRecords(MaskedArray, object):
self.__setmask__(val)
return
# Create a shortcut (so that we don't have to call getattr all the time)
- _localdict = self.__dict__
+ _localdict = object.__getattribute__(self, '__dict__')
# Check whether we're creating a new field
newattr = attr not in _localdict
try:
@@ -241,7 +242,8 @@ class MaskedRecords(MaskedArray, object):
except:
# Not a generic attribute: exit if it's not a valid field
fielddict = ndarray.__getattribute__(self,'dtype').fields or {}
- if attr not in fielddict:
+ optinfo = ndarray.__getattribute__(self,'_optinfo') or {}
+ if not (attr in fielddict or attr in optinfo):
exctype, value = sys.exc_info()[:2]
raise exctype, value
else:
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 89abde60a..5b5667ef5 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1942,13 +1942,21 @@ class TestMaskedArrayFunctions(TestCase):
xm.set_fill_value(1.e+20)
self.info = (xm, ym)
- #
def test_masked_where_bool(self):
x = [1,2]
y = masked_where(False,x)
assert_equal(y,[1,2])
assert_equal(y[1],2)
+ def test_masked_equal_wlist(self):
+ x = [1, 2, 3]
+ mx = masked_equal(x, 3)
+ assert_equal(mx, x)
+ assert_equal(mx._mask, [0,0,1])
+ mx = masked_not_equal(x, 3)
+ assert_equal(mx, x)
+ assert_equal(mx._mask, [1,1,0])
+
def test_masked_where_condition(self):
"Tests masking functions."
x = array([1.,2.,3.,4.,5.])