diff options
author | pierregm <pierregm@localhost> | 2009-01-08 20:02:29 +0000 |
---|---|---|
committer | pierregm <pierregm@localhost> | 2009-01-08 20:02:29 +0000 |
commit | a5da87c6254f183dc068d912dde9fd00341d926f (patch) | |
tree | 2f670af8704ef7d67abacbb6a0e0f7ac3630115d /numpy | |
parent | 99f428eeb52ce65e79defff4306aebc601ee1c42 (diff) | |
download | numpy-a5da87c6254f183dc068d912dde9fd00341d926f.tar.gz |
* Add __eq__ and __ne__ for support of flexible arrays.
* Fixed .filled for nested structures
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/ma/core.py | 89 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 44 |
2 files changed, 129 insertions, 4 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index cb3de4cf9..b0da8c200 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -857,6 +857,7 @@ mod = _DomainedBinaryOperation(umath.mod, _DomainSafeDivide(), 0, 1) #####-------------------------------------------------------------------------- #---- --- Mask creation functions --- #####-------------------------------------------------------------------------- + def _recursive_make_descr(datatype, newtype=bool_): "Private function allowing recursion in make_descr." # Do we have some name fields ? @@ -1134,6 +1135,7 @@ def masked_where(condition, a, copy=True): result._mask = cond return result + def masked_greater(x, value, copy=True): """ Return the array `x` masked where (x > value). @@ -1142,22 +1144,27 @@ def masked_greater(x, value, copy=True): """ return masked_where(greater(x, value), x, copy=copy) + def masked_greater_equal(x, value, copy=True): "Shortcut to masked_where, with condition = (x >= value)." return masked_where(greater_equal(x, value), x, copy=copy) + def masked_less(x, value, copy=True): "Shortcut to masked_where, with condition = (x < value)." return masked_where(less(x, value), x, copy=copy) + def masked_less_equal(x, value, copy=True): "Shortcut to masked_where, with condition = (x <= value)." return masked_where(less_equal(x, value), x, copy=copy) + def masked_not_equal(x, value, copy=True): "Shortcut to masked_where, with condition = (x != value)." return masked_where(not_equal(x, value), x, copy=copy) + def masked_equal(x, value, copy=True): """ Shortcut to masked_where, with condition = (x == value). For @@ -1171,6 +1178,7 @@ def masked_equal(x, value, copy=True): # return array(d, mask=m, copy=copy) return masked_where(equal(x, value), x, copy=copy) + def masked_inside(x, v1, v2, copy=True): """ Shortcut to masked_where, where ``condition`` is True for x inside @@ -1188,6 +1196,7 @@ def masked_inside(x, v1, v2, copy=True): condition = (xf >= v1) & (xf <= v2) return masked_where(condition, x, copy=copy) + def masked_outside(x, v1, v2, copy=True): """ Shortcut to ``masked_where``, where ``condition`` is True for x outside @@ -1205,7 +1214,7 @@ def masked_outside(x, v1, v2, copy=True): condition = (xf < v1) | (xf > v2) return masked_where(condition, x, copy=copy) -# + def masked_object(x, value, copy=True, shrink=True): """ Mask the array `x` where the data are exactly equal to value. @@ -1234,6 +1243,7 @@ def masked_object(x, value, copy=True, shrink=True): mask = mask_or(mask, make_mask(condition, shrink=shrink)) return masked_array(x, mask=mask, copy=copy, fill_value=value) + def masked_values(x, value, rtol=1.e-5, atol=1.e-8, copy=True, shrink=True): """ Mask the array x where the data are approximately equal in @@ -1271,6 +1281,7 @@ def masked_values(x, value, rtol=1.e-5, atol=1.e-8, copy=True, shrink=True): mask = mask_or(mask, make_mask(condition, shrink=shrink)) return masked_array(xnew, mask=mask, copy=copy, fill_value=value) + def masked_invalid(a, copy=True): """ Mask the array for invalid values (NaNs or infs). @@ -1292,6 +1303,7 @@ def masked_invalid(a, copy=True): #####-------------------------------------------------------------------------- #---- --- Printing options --- #####-------------------------------------------------------------------------- + class _MaskedPrintOption: """ Handle the string used to represent missing data in a masked array. @@ -1372,6 +1384,20 @@ masked_%(name)s(data = %(data)s, #---- --- MaskedArray class --- #####-------------------------------------------------------------------------- +def _recursive_filled(a, mask, fill_value): + """ + Recursively fill `a` with `fill_value`. + Private function + """ + names = a.dtype.names + for name in names: + current = a[name] + print "Name: %s : %s" % (name, current) + if current.dtype.names: + _recursive_filled(current, mask[name], fill_value[name]) + else: + np.putmask(current, mask[name], fill_value[name]) + #............................................................................... class _arraymethod(object): """ @@ -2013,6 +2039,7 @@ class MaskedArray(ndarray): try: return _mask.view((bool_, len(self.dtype))).all(axis) except ValueError: + # In case we have nested fields... return np.all([[f[n].all() for n in _mask.dtype.names] for f in _mask], axis=axis) @@ -2106,6 +2133,7 @@ class MaskedArray(ndarray): fill_value = property(fget=get_fill_value, fset=set_fill_value, doc="Filling value.") + def filled(self, fill_value=None): """Return a copy of self._data, where masked values are filled with fill_value. @@ -2140,9 +2168,10 @@ class MaskedArray(ndarray): # if m.dtype.names: result = self._data.copy() - for n in result.dtype.names: - field = result[n] - np.putmask(field, self._mask[n], fill_value[n]) + _recursive_filled(result, self._mask, fill_value) +# for n in result.dtype.names: +# field = result[n] +# np.putmask(field, self._mask[n], fill_value[n]) elif not m.any(): return self._data else: @@ -2287,6 +2316,58 @@ class MaskedArray(ndarray): return _print_templates['short'] % parameters return _print_templates['long'] % parameters #............................................ + def __eq__(self, other): + "Check whether other equals self elementwise" + omask = getattr(other, '_mask', nomask) + if omask is nomask: + check = ndarray.__eq__(self.filled(0), other).view(type(self)) + check._mask = self._mask + else: + odata = filled(other, 0) + check = ndarray.__eq__(self.filled(0), odata).view(type(self)) + if self._mask is nomask: + check._mask = omask + else: + mask = mask_or(self._mask, omask) + if mask.dtype.names: + if mask.size > 1: + axis = 1 + else: + axis = None + try: + mask = mask.view((bool_, len(self.dtype))).all(axis) + except ValueError: + mask = np.all([[f[n].all() for n in mask.dtype.names] + for f in mask], axis=axis) + check._mask = mask + return check + # + def __ne__(self, other): + "Check whether other doesn't equal self elementwise" + omask = getattr(other, '_mask', nomask) + if omask is nomask: + check = ndarray.__ne__(self.filled(0), other).view(type(self)) + check._mask = self._mask + else: + odata = filled(other, 0) + check = ndarray.__ne__(self.filled(0), odata).view(type(self)) + if self._mask is nomask: + check._mask = omask + else: + mask = mask_or(self._mask, omask) + if mask.dtype.names: + if mask.size > 1: + axis = 1 + else: + axis = None + try: + mask = mask.view((bool_, len(self.dtype))).all(axis) + except ValueError: + mask = np.all([[f[n].all() for n in mask.dtype.names] + for f in mask], axis=axis) + check._mask = mask + return check + # def __add__(self, other): "Add other to self, and return a new masked array." return add(self, other) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index dbb56319f..9ac4342e6 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -474,6 +474,16 @@ class TestMaskedArray(TestCase): np.array([(1, '1', 1.)], dtype=flexi.dtype)) + def test_filled_w_nested_dtype(self): + "Test filled w/ nested dtype" + ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])] + a = array([(1, (1, 1)), (2, (2, 2))], + mask=[(0, (1, 0)), (0, (0, 1))], dtype=ndtype) + test = a.filled(0) + control = np.array([(1, (0, 1)), (2, (2, 0))], dtype=ndtype) + assert_equal(test, control) + + def test_optinfo_propagation(self): "Checks that _optinfo dictionary isn't back-propagated" x = array([1,2,3,], dtype=float) @@ -884,6 +894,40 @@ class TestMaskedArrayArithmetic(TestCase): self.failUnless(output[0] is masked) + def test_eq_on_structured(self): + "Test the equality of structured arrays" + ndtype = [('A', int), ('B', int)] + a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype) + test = (a == a) + assert_equal(test, [True, True]) + assert_equal(test.mask, [False, False]) + b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype) + test = (a == b) + assert_equal(test, [False, True]) + assert_equal(test.mask, [True, False]) + b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype) + test = (a == b) + assert_equal(test, [True, False]) + assert_equal(test.mask, [False, False]) + + + def test_ne_on_structured(self): + "Test the equality of structured arrays" + ndtype = [('A', int), ('B', int)] + a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype) + test = (a != a) + assert_equal(test, [False, False]) + assert_equal(test.mask, [False, False]) + b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype) + test = (a != b) + assert_equal(test, [True, False]) + assert_equal(test.mask, [True, False]) + b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype) + test = (a != b) + assert_equal(test, [False, True]) + assert_equal(test.mask, [False, False]) + + #------------------------------------------------------------------------------ class TestMaskedArrayAttributes(TestCase): |