summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2009-01-08 20:02:29 +0000
committerpierregm <pierregm@localhost>2009-01-08 20:02:29 +0000
commita5da87c6254f183dc068d912dde9fd00341d926f (patch)
tree2f670af8704ef7d67abacbb6a0e0f7ac3630115d /numpy
parent99f428eeb52ce65e79defff4306aebc601ee1c42 (diff)
downloadnumpy-a5da87c6254f183dc068d912dde9fd00341d926f.tar.gz
* Add __eq__ and __ne__ for support of flexible arrays.
* Fixed .filled for nested structures
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py89
-rw-r--r--numpy/ma/tests/test_core.py44
2 files changed, 129 insertions, 4 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index cb3de4cf9..b0da8c200 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -857,6 +857,7 @@ mod = _DomainedBinaryOperation(umath.mod, _DomainSafeDivide(), 0, 1)
#####--------------------------------------------------------------------------
#---- --- Mask creation functions ---
#####--------------------------------------------------------------------------
+
def _recursive_make_descr(datatype, newtype=bool_):
"Private function allowing recursion in make_descr."
# Do we have some name fields ?
@@ -1134,6 +1135,7 @@ def masked_where(condition, a, copy=True):
result._mask = cond
return result
+
def masked_greater(x, value, copy=True):
"""
Return the array `x` masked where (x > value).
@@ -1142,22 +1144,27 @@ def masked_greater(x, value, copy=True):
"""
return masked_where(greater(x, value), x, copy=copy)
+
def masked_greater_equal(x, value, copy=True):
"Shortcut to masked_where, with condition = (x >= value)."
return masked_where(greater_equal(x, value), x, copy=copy)
+
def masked_less(x, value, copy=True):
"Shortcut to masked_where, with condition = (x < value)."
return masked_where(less(x, value), x, copy=copy)
+
def masked_less_equal(x, value, copy=True):
"Shortcut to masked_where, with condition = (x <= value)."
return masked_where(less_equal(x, value), x, copy=copy)
+
def masked_not_equal(x, value, copy=True):
"Shortcut to masked_where, with condition = (x != value)."
return masked_where(not_equal(x, value), x, copy=copy)
+
def masked_equal(x, value, copy=True):
"""
Shortcut to masked_where, with condition = (x == value). For
@@ -1171,6 +1178,7 @@ def masked_equal(x, value, copy=True):
# return array(d, mask=m, copy=copy)
return masked_where(equal(x, value), x, copy=copy)
+
def masked_inside(x, v1, v2, copy=True):
"""
Shortcut to masked_where, where ``condition`` is True for x inside
@@ -1188,6 +1196,7 @@ def masked_inside(x, v1, v2, copy=True):
condition = (xf >= v1) & (xf <= v2)
return masked_where(condition, x, copy=copy)
+
def masked_outside(x, v1, v2, copy=True):
"""
Shortcut to ``masked_where``, where ``condition`` is True for x outside
@@ -1205,7 +1214,7 @@ def masked_outside(x, v1, v2, copy=True):
condition = (xf < v1) | (xf > v2)
return masked_where(condition, x, copy=copy)
-#
+
def masked_object(x, value, copy=True, shrink=True):
"""
Mask the array `x` where the data are exactly equal to value.
@@ -1234,6 +1243,7 @@ def masked_object(x, value, copy=True, shrink=True):
mask = mask_or(mask, make_mask(condition, shrink=shrink))
return masked_array(x, mask=mask, copy=copy, fill_value=value)
+
def masked_values(x, value, rtol=1.e-5, atol=1.e-8, copy=True, shrink=True):
"""
Mask the array x where the data are approximately equal in
@@ -1271,6 +1281,7 @@ def masked_values(x, value, rtol=1.e-5, atol=1.e-8, copy=True, shrink=True):
mask = mask_or(mask, make_mask(condition, shrink=shrink))
return masked_array(xnew, mask=mask, copy=copy, fill_value=value)
+
def masked_invalid(a, copy=True):
"""
Mask the array for invalid values (NaNs or infs).
@@ -1292,6 +1303,7 @@ def masked_invalid(a, copy=True):
#####--------------------------------------------------------------------------
#---- --- Printing options ---
#####--------------------------------------------------------------------------
+
class _MaskedPrintOption:
"""
Handle the string used to represent missing data in a masked array.
@@ -1372,6 +1384,20 @@ masked_%(name)s(data = %(data)s,
#---- --- MaskedArray class ---
#####--------------------------------------------------------------------------
+def _recursive_filled(a, mask, fill_value):
+ """
+ Recursively fill `a` with `fill_value`.
+ Private function
+ """
+ names = a.dtype.names
+ for name in names:
+ current = a[name]
+ print "Name: %s : %s" % (name, current)
+ if current.dtype.names:
+ _recursive_filled(current, mask[name], fill_value[name])
+ else:
+ np.putmask(current, mask[name], fill_value[name])
+
#...............................................................................
class _arraymethod(object):
"""
@@ -2013,6 +2039,7 @@ class MaskedArray(ndarray):
try:
return _mask.view((bool_, len(self.dtype))).all(axis)
except ValueError:
+ # In case we have nested fields...
return np.all([[f[n].all() for n in _mask.dtype.names]
for f in _mask], axis=axis)
@@ -2106,6 +2133,7 @@ class MaskedArray(ndarray):
fill_value = property(fget=get_fill_value, fset=set_fill_value,
doc="Filling value.")
+
def filled(self, fill_value=None):
"""Return a copy of self._data, where masked values are filled
with fill_value.
@@ -2140,9 +2168,10 @@ class MaskedArray(ndarray):
#
if m.dtype.names:
result = self._data.copy()
- for n in result.dtype.names:
- field = result[n]
- np.putmask(field, self._mask[n], fill_value[n])
+ _recursive_filled(result, self._mask, fill_value)
+# for n in result.dtype.names:
+# field = result[n]
+# np.putmask(field, self._mask[n], fill_value[n])
elif not m.any():
return self._data
else:
@@ -2287,6 +2316,58 @@ class MaskedArray(ndarray):
return _print_templates['short'] % parameters
return _print_templates['long'] % parameters
#............................................
+ def __eq__(self, other):
+ "Check whether other equals self elementwise"
+ omask = getattr(other, '_mask', nomask)
+ if omask is nomask:
+ check = ndarray.__eq__(self.filled(0), other).view(type(self))
+ check._mask = self._mask
+ else:
+ odata = filled(other, 0)
+ check = ndarray.__eq__(self.filled(0), odata).view(type(self))
+ if self._mask is nomask:
+ check._mask = omask
+ else:
+ mask = mask_or(self._mask, omask)
+ if mask.dtype.names:
+ if mask.size > 1:
+ axis = 1
+ else:
+ axis = None
+ try:
+ mask = mask.view((bool_, len(self.dtype))).all(axis)
+ except ValueError:
+ mask = np.all([[f[n].all() for n in mask.dtype.names]
+ for f in mask], axis=axis)
+ check._mask = mask
+ return check
+ #
+ def __ne__(self, other):
+ "Check whether other doesn't equal self elementwise"
+ omask = getattr(other, '_mask', nomask)
+ if omask is nomask:
+ check = ndarray.__ne__(self.filled(0), other).view(type(self))
+ check._mask = self._mask
+ else:
+ odata = filled(other, 0)
+ check = ndarray.__ne__(self.filled(0), odata).view(type(self))
+ if self._mask is nomask:
+ check._mask = omask
+ else:
+ mask = mask_or(self._mask, omask)
+ if mask.dtype.names:
+ if mask.size > 1:
+ axis = 1
+ else:
+ axis = None
+ try:
+ mask = mask.view((bool_, len(self.dtype))).all(axis)
+ except ValueError:
+ mask = np.all([[f[n].all() for n in mask.dtype.names]
+ for f in mask], axis=axis)
+ check._mask = mask
+ return check
+ #
def __add__(self, other):
"Add other to self, and return a new masked array."
return add(self, other)
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index dbb56319f..9ac4342e6 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -474,6 +474,16 @@ class TestMaskedArray(TestCase):
np.array([(1, '1', 1.)], dtype=flexi.dtype))
+ def test_filled_w_nested_dtype(self):
+ "Test filled w/ nested dtype"
+ ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])]
+ a = array([(1, (1, 1)), (2, (2, 2))],
+ mask=[(0, (1, 0)), (0, (0, 1))], dtype=ndtype)
+ test = a.filled(0)
+ control = np.array([(1, (0, 1)), (2, (2, 0))], dtype=ndtype)
+ assert_equal(test, control)
+
+
def test_optinfo_propagation(self):
"Checks that _optinfo dictionary isn't back-propagated"
x = array([1,2,3,], dtype=float)
@@ -884,6 +894,40 @@ class TestMaskedArrayArithmetic(TestCase):
self.failUnless(output[0] is masked)
+ def test_eq_on_structured(self):
+ "Test the equality of structured arrays"
+ ndtype = [('A', int), ('B', int)]
+ a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype)
+ test = (a == a)
+ assert_equal(test, [True, True])
+ assert_equal(test.mask, [False, False])
+ b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype)
+ test = (a == b)
+ assert_equal(test, [False, True])
+ assert_equal(test.mask, [True, False])
+ b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
+ test = (a == b)
+ assert_equal(test, [True, False])
+ assert_equal(test.mask, [False, False])
+
+
+ def test_ne_on_structured(self):
+ "Test the equality of structured arrays"
+ ndtype = [('A', int), ('B', int)]
+ a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype)
+ test = (a != a)
+ assert_equal(test, [False, False])
+ assert_equal(test.mask, [False, False])
+ b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype)
+ test = (a != b)
+ assert_equal(test, [True, False])
+ assert_equal(test.mask, [True, False])
+ b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
+ test = (a != b)
+ assert_equal(test, [False, True])
+ assert_equal(test.mask, [False, False])
+
+
#------------------------------------------------------------------------------
class TestMaskedArrayAttributes(TestCase):