summaryrefslogtreecommitdiff
path: root/numpy/ma/extras.py
diff options
context:
space:
mode:
authorDaniel da Silva <mail@danieldasilva.org>2015-04-03 21:29:24 -0400
committerDaniel da Silva <mail@danieldasilva.org>2015-05-03 22:18:18 -0400
commit883d052e3eb9b45a4bb87e7e84f487e0c9e5c882 (patch)
tree201f09008b172571f7e478f3fb9eeadca01def20 /numpy/ma/extras.py
parent147c60f83f401037ff29593826d2c5729a73c2c5 (diff)
downloadnumpy-883d052e3eb9b45a4bb87e7e84f487e0c9e5c882.tar.gz
ENH: Introduce np.ma.compress_nd(), generalizes np.ma.compress_rowcols()
Provides a way to supress slices along an abitrary tuple of dimensions.
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r--numpy/ma/extras.py85
1 files changed, 63 insertions, 22 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index d389099ae..51064e831 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -18,8 +18,8 @@ __date__ = '$Date: 2007-10-29 17:18:13 +0200 (Mon, 29 Oct 2007) $'
__all__ = ['apply_along_axis', 'apply_over_axes', 'atleast_1d', 'atleast_2d',
'atleast_3d', 'average',
'clump_masked', 'clump_unmasked', 'column_stack', 'compress_cols',
- 'compress_rowcols', 'compress_rows', 'count_masked', 'corrcoef',
- 'cov',
+ 'compress_nd', 'compress_rowcols', 'compress_rows', 'count_masked',
+ 'corrcoef', 'cov',
'diagflat', 'dot', 'dstack',
'ediff1d',
'flatnotmasked_contiguous', 'flatnotmasked_edges',
@@ -716,6 +716,57 @@ def median(a, axis=None, out=None, overwrite_input=False):
#..............................................................................
+def compress_nd(x, axis=None):
+ """Supress slices from multiple dimensions which contain masked values.
+
+ Parameters
+ ----------
+ x : array_like, MaskedArray
+ The array to operate on. If not a MaskedArray instance (or if no array
+ elements are masked, `x` is interpreted as a MaskedArray with `mask`
+ set to `nomask`.
+ axis : tuple of ints or int, optional
+ Which dimensions to supress slices from can be configured with this
+ parameter.
+ - If axis is a tuple of ints, those are the axes to supress slices from.
+ - If axis is an int, then that is the only axis to supress slices from.
+ - If axis is None, all axis are selected.
+
+ Returns
+ -------
+ compress_array : ndarray
+ The compressed array.
+ """
+ x = asarray(x)
+ m = getmask(x)
+ # Set axis to tuple of ints
+ if isinstance(axis, tuple):
+ axis = tuple(ax % x.ndim for ax in axis)
+ elif isinstance(axis, int):
+ axis = (axis % x.ndim,)
+ elif axis is None:
+ axis = tuple(range(x.ndim))
+ else:
+ raise ValueError('Invalid type for axis argument')
+ # Check axis input
+ for ax in axis:
+ if not (0 <= ax < x.ndim):
+ raise ValueError('axis %d is out of range' % ax)
+ if not len(axis) == len(set(axis)):
+ raise ValueError('axis cannot have dupliate entries')
+ # Nothing is masked: return x
+ if m is nomask or not m.any():
+ return x._data
+ # All is masked: return empty
+ if m.all():
+ return nxarray([])
+ # Filter elements through boolean indexing
+ data = x._data
+ for ax in axis:
+ axes = tuple(list(range(ax)) + list(range(ax + 1, x.ndim)))
+ data = data[(slice(None),)*ax + (~m.any(axis=axes),)]
+ return data
+
def compress_rowcols(x, axis=None):
"""
Suppress the rows and/or columns of a 2-D array that contain
@@ -767,26 +818,10 @@ def compress_rowcols(x, axis=None):
[7, 8]])
"""
- x = asarray(x)
- if x.ndim != 2:
- raise NotImplementedError("compress2d works for 2D arrays only.")
- m = getmask(x)
- # Nothing is masked: return x
- if m is nomask or not m.any():
- return x._data
- # All is masked: return empty
- if m.all():
- return nxarray([])
- # Builds a list of rows/columns indices
- (idxr, idxc) = (list(range(len(x))), list(range(x.shape[1])))
- masked = m.nonzero()
- if not axis:
- for i in np.unique(masked[0]):
- idxr.remove(i)
- if axis in [None, 1, -1]:
- for j in np.unique(masked[1]):
- idxc.remove(j)
- return x._data[idxr][:, idxc]
+ if asarray(x).ndim != 2:
+ raise NotImplementedError("compress_rowcols works for 2D arrays only.")
+ return compress_nd(x, axis=axis)
+
def compress_rows(a):
"""
@@ -800,6 +835,9 @@ def compress_rows(a):
extras.compress_rowcols
"""
+ a = asarray(a)
+ if a.ndim != 2:
+ raise NotImplementedError("compress_rows works for 2D arrays only.")
return compress_rowcols(a, 0)
def compress_cols(a):
@@ -814,6 +852,9 @@ def compress_cols(a):
extras.compress_rowcols
"""
+ a = asarray(a)
+ if a.ndim != 2:
+ raise NotImplementedError("compress_cols works for 2D arrays only.")
return compress_rowcols(a, 1)
def mask_rowcols(a, axis=None):