diff options
-rw-r--r-- | numpy/ma/core.py | 5 | ||||
-rw-r--r-- | numpy/ma/mrecords.py | 10 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 10 |
3 files changed, 17 insertions, 8 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index a31cbef1b..dac79a155 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -916,9 +916,8 @@ def masked_less_equal(x, value, copy=True): def masked_not_equal(x, value, copy=True): "Shortcut to masked_where, with condition = (x != value)." - return masked_where((x != value), x, copy=copy) + return masked_where(not_equal(x, value), x, copy=copy) -# def masked_equal(x, value, copy=True): """Shortcut to masked_where, with condition = (x == value). For floating point, consider `masked_values(x, value)` instead. @@ -929,7 +928,7 @@ def masked_equal(x, value, copy=True): # c = umath.equal(d, value) # m = mask_or(c, getmask(x)) # return array(d, mask=m, copy=copy) - return masked_where((x == value), x, copy=copy) + return masked_where(equal(x, value), x, copy=copy) def masked_inside(x, v1, v2, copy=True): """Shortcut to masked_where, where condition is True for x inside diff --git a/numpy/ma/mrecords.py b/numpy/ma/mrecords.py index 5e80695e7..e6735c015 100644 --- a/numpy/ma/mrecords.py +++ b/numpy/ma/mrecords.py @@ -158,10 +158,11 @@ class MaskedRecords(MaskedArray, object): _fieldmask = getattr(obj, '_fieldmask', None) if _fieldmask is None: objmask = getattr(obj, '_mask', nomask) + _dtype = ndarray.__getattribute__(self,'dtype') if objmask is nomask: - _mask = ma.make_mask_none(self.shape, dtype=self.dtype) + _mask = ma.make_mask_none(self.shape, dtype=_dtype) else: - mdescr = ma.make_mask_descr(self.dtype) + mdescr = ma.make_mask_descr(_dtype) _mask = narray([tuple([m]*len(mdescr)) for m in objmask], dtype=mdescr).view(recarray) else: @@ -232,7 +233,7 @@ class MaskedRecords(MaskedArray, object): self.__setmask__(val) return # Create a shortcut (so that we don't have to call getattr all the time) - _localdict = self.__dict__ + _localdict = object.__getattribute__(self, '__dict__') # Check whether we're creating a new field newattr = attr not in _localdict try: @@ -241,7 +242,8 @@ class MaskedRecords(MaskedArray, object): except: # Not a generic attribute: exit if it's not a valid field fielddict = ndarray.__getattribute__(self,'dtype').fields or {} - if attr not in fielddict: + optinfo = ndarray.__getattribute__(self,'_optinfo') or {} + if not (attr in fielddict or attr in optinfo): exctype, value = sys.exc_info()[:2] raise exctype, value else: diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 89abde60a..5b5667ef5 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1942,13 +1942,21 @@ class TestMaskedArrayFunctions(TestCase): xm.set_fill_value(1.e+20) self.info = (xm, ym) - # def test_masked_where_bool(self): x = [1,2] y = masked_where(False,x) assert_equal(y,[1,2]) assert_equal(y[1],2) + def test_masked_equal_wlist(self): + x = [1, 2, 3] + mx = masked_equal(x, 3) + assert_equal(mx, x) + assert_equal(mx._mask, [0,0,1]) + mx = masked_not_equal(x, 3) + assert_equal(mx, x) + assert_equal(mx._mask, [1,1,0]) + def test_masked_where_condition(self): "Tests masking functions." x = array([1.,2.,3.,4.,5.]) |