summaryrefslogtreecommitdiff
path: root/numpy/ma/core.py
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2008-12-01 17:56:58 +0000
committerpierregm <pierregm@localhost>2008-12-01 17:56:58 +0000
commit77f95a1398ddad2d031aaf8e4db344ffc93d1d73 (patch)
treeb83ef742a334f473c6c6eca9ed1fd9dd63a5b1e0 /numpy/ma/core.py
parent7933d532fea5a0e75e6e415e6ab9d6bc7c585089 (diff)
downloadnumpy-77f95a1398ddad2d031aaf8e4db344ffc93d1d73.tar.gz
* added flatten_mask to collapse masks w/ (nested) flexible types.
* fixed __getitem__ on arrays w/ nested dtype
Diffstat (limited to 'numpy/ma/core.py')
-rw-r--r--numpy/ma/core.py71
1 files changed, 66 insertions, 5 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 590421594..c8bc425f4 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -32,8 +32,8 @@ __all__ = ['MAError', 'MaskType', 'MaskedArray',
'count', 'cumprod', 'cumsum',
'default_fill_value', 'diag', 'diagonal', 'divide', 'dump', 'dumps',
'empty', 'empty_like', 'equal', 'exp', 'expand_dims',
- 'fabs', 'fmod', 'filled', 'floor', 'floor_divide','fix_invalid',
- 'frombuffer', 'fromfunction',
+ 'fabs', 'flatten_mask', 'fmod', 'filled', 'floor', 'floor_divide',
+ 'fix_invalid', 'frombuffer', 'fromfunction',
'getdata','getmask', 'getmaskarray', 'greater', 'greater_equal',
'harden_mask', 'hypot',
'identity', 'ids', 'indices', 'inner', 'innerproduct',
@@ -967,6 +967,61 @@ def mask_or (m1, m2, copy=False, shrink=True):
return make_mask(umath.logical_or(m1, m2), copy=copy, shrink=shrink)
+def flatten_mask(mask):
+ """
+ Returns a completely flattened version of the mask, where nested fields
+ are collapsed.
+
+ Parameters
+ ----------
+ mask : array_like
+ Array of booleans
+
+ Returns
+ -------
+ flattened_mask : ndarray
+ Boolean array.
+
+ Examples
+ --------
+ >>> mask = np.array([0, 0, 1], dtype=np.bool)
+ >>> flatten_mask(mask)
+ array([False, False, True], dtype=bool)
+ >>> mask = np.array([(0, 0), (0, 1)], dtype=[('a', bool), ('b', bool)])
+ >>> flatten_mask(mask)
+ array([False, False, False, True], dtype=bool)
+ >>> mdtype = [('a', bool), ('b', [('ba', bool), ('bb', bool)])]
+ >>> mask = np.array([(0, (0, 0)), (0, (0, 1))], dtype=mdtype)
+ >>> flatten_mask(mask)
+ array([False, False, False, False, False, True], dtype=bool)
+
+ """
+ #
+ def _flatmask(mask):
+ "Flatten the mask and returns a (maybe nested) sequence of booleans."
+ mnames = mask.dtype.names
+ if mnames:
+ return [flatten_mask(mask[name]) for name in mnames]
+ else:
+ return mask
+ #
+ def _flatsequence(sequence):
+ "Generates a flattened version of the sequence."
+ try:
+ for element in sequence:
+ if hasattr(element, '__iter__'):
+ for f in _flatsequence(element):
+ yield f
+ else:
+ yield element
+ except TypeError:
+ yield sequence
+ #
+ mask = np.asarray(mask)
+ flattened = _flatsequence(_flatmask(mask))
+ return np.array([_ for _ in flattened], dtype=bool)
+
+
#####--------------------------------------------------------------------------
#--- --- Masking functions ---
#####--------------------------------------------------------------------------
@@ -1350,7 +1405,7 @@ class MaskedArray(ndarray):
# Process data............
_data = np.array(data, dtype=dtype, copy=copy, subok=True, ndmin=ndmin)
_baseclass = getattr(data, '_baseclass', type(_data))
- # Check that we'ew not erasing the mask..........
+ # Check that we're not erasing the mask..........
if isinstance(data, MaskedArray) and (data.shape != _data.shape):
copy = True
# Careful, cls might not always be MaskedArray...
@@ -1382,7 +1437,13 @@ class MaskedArray(ndarray):
_data._mask = np.zeros(_data.shape, dtype=mdtype)
# Check whether we missed something
elif isinstance(data, (tuple,list)):
- mask = np.array([getmaskarray(m) for m in data], dtype=mdtype)
+ try:
+ # If data is a sequence of masked array
+ mask = np.array([getmaskarray(m) for m in data],
+ dtype=mdtype)
+ except ValueError:
+ # If data is nested
+ mask = nomask
# Force shrinking of the mask if needed (and possible)
if (mdtype == MaskType) and mask.any():
_data._mask = mask
@@ -1624,7 +1685,7 @@ class MaskedArray(ndarray):
# A record ................
if isinstance(dout, np.void):
mask = _mask[indx]
- if mask.view((bool, len(mask.dtype))).any():
+ if flatten_mask(mask).any():
dout = masked_array(dout, mask=mask)
else:
return dout