diff options
author | pierregm <pierregm@localhost> | 2008-12-01 17:56:58 +0000 |
---|---|---|
committer | pierregm <pierregm@localhost> | 2008-12-01 17:56:58 +0000 |
commit | 77f95a1398ddad2d031aaf8e4db344ffc93d1d73 (patch) | |
tree | b83ef742a334f473c6c6eca9ed1fd9dd63a5b1e0 /numpy/ma/core.py | |
parent | 7933d532fea5a0e75e6e415e6ab9d6bc7c585089 (diff) | |
download | numpy-77f95a1398ddad2d031aaf8e4db344ffc93d1d73.tar.gz |
* added flatten_mask to collapse masks w/ (nested) flexible types.
* fixed __getitem__ on arrays w/ nested dtype
Diffstat (limited to 'numpy/ma/core.py')
-rw-r--r-- | numpy/ma/core.py | 71 |
1 files changed, 66 insertions, 5 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 590421594..c8bc425f4 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -32,8 +32,8 @@ __all__ = ['MAError', 'MaskType', 'MaskedArray', 'count', 'cumprod', 'cumsum', 'default_fill_value', 'diag', 'diagonal', 'divide', 'dump', 'dumps', 'empty', 'empty_like', 'equal', 'exp', 'expand_dims', - 'fabs', 'fmod', 'filled', 'floor', 'floor_divide','fix_invalid', - 'frombuffer', 'fromfunction', + 'fabs', 'flatten_mask', 'fmod', 'filled', 'floor', 'floor_divide', + 'fix_invalid', 'frombuffer', 'fromfunction', 'getdata','getmask', 'getmaskarray', 'greater', 'greater_equal', 'harden_mask', 'hypot', 'identity', 'ids', 'indices', 'inner', 'innerproduct', @@ -967,6 +967,61 @@ def mask_or (m1, m2, copy=False, shrink=True): return make_mask(umath.logical_or(m1, m2), copy=copy, shrink=shrink) +def flatten_mask(mask): + """ + Returns a completely flattened version of the mask, where nested fields + are collapsed. + + Parameters + ---------- + mask : array_like + Array of booleans + + Returns + ------- + flattened_mask : ndarray + Boolean array. + + Examples + -------- + >>> mask = np.array([0, 0, 1], dtype=np.bool) + >>> flatten_mask(mask) + array([False, False, True], dtype=bool) + >>> mask = np.array([(0, 0), (0, 1)], dtype=[('a', bool), ('b', bool)]) + >>> flatten_mask(mask) + array([False, False, False, True], dtype=bool) + >>> mdtype = [('a', bool), ('b', [('ba', bool), ('bb', bool)])] + >>> mask = np.array([(0, (0, 0)), (0, (0, 1))], dtype=mdtype) + >>> flatten_mask(mask) + array([False, False, False, False, False, True], dtype=bool) + + """ + # + def _flatmask(mask): + "Flatten the mask and returns a (maybe nested) sequence of booleans." + mnames = mask.dtype.names + if mnames: + return [flatten_mask(mask[name]) for name in mnames] + else: + return mask + # + def _flatsequence(sequence): + "Generates a flattened version of the sequence." + try: + for element in sequence: + if hasattr(element, '__iter__'): + for f in _flatsequence(element): + yield f + else: + yield element + except TypeError: + yield sequence + # + mask = np.asarray(mask) + flattened = _flatsequence(_flatmask(mask)) + return np.array([_ for _ in flattened], dtype=bool) + + #####-------------------------------------------------------------------------- #--- --- Masking functions --- #####-------------------------------------------------------------------------- @@ -1350,7 +1405,7 @@ class MaskedArray(ndarray): # Process data............ _data = np.array(data, dtype=dtype, copy=copy, subok=True, ndmin=ndmin) _baseclass = getattr(data, '_baseclass', type(_data)) - # Check that we'ew not erasing the mask.......... + # Check that we're not erasing the mask.......... if isinstance(data, MaskedArray) and (data.shape != _data.shape): copy = True # Careful, cls might not always be MaskedArray... @@ -1382,7 +1437,13 @@ class MaskedArray(ndarray): _data._mask = np.zeros(_data.shape, dtype=mdtype) # Check whether we missed something elif isinstance(data, (tuple,list)): - mask = np.array([getmaskarray(m) for m in data], dtype=mdtype) + try: + # If data is a sequence of masked array + mask = np.array([getmaskarray(m) for m in data], + dtype=mdtype) + except ValueError: + # If data is nested + mask = nomask # Force shrinking of the mask if needed (and possible) if (mdtype == MaskType) and mask.any(): _data._mask = mask @@ -1624,7 +1685,7 @@ class MaskedArray(ndarray): # A record ................ if isinstance(dout, np.void): mask = _mask[indx] - if mask.view((bool, len(mask.dtype))).any(): + if flatten_mask(mask).any(): dout = masked_array(dout, mask=mask) else: return dout |