diff options
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r-- | numpy/ma/extras.py | 30 |
1 files changed, 14 insertions, 16 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 64a9844cf..3c10af24c 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -718,7 +718,7 @@ def median(a, axis=None, out=None, overwrite_input=False): #.............................................................................. def compress_nd(x, axis=None): """Supress slices from multiple dimensions which contain masked values. - + Parameters ---------- x : array_like, MaskedArray @@ -731,29 +731,27 @@ def compress_nd(x, axis=None): - If axis is a tuple of ints, those are the axes to supress slices from. - If axis is an int, then that is the only axis to supress slices from. - If axis is None, all axis are selected. - + Returns ------- compress_array : ndarray - The compressed array. - """ + The compressed array. + """ x = asarray(x) m = getmask(x) # Set axis to tuple of ints - if isinstance(axis, tuple): - axis = tuple(ax % x.ndim for ax in axis) - elif isinstance(axis, int): - axis = (axis % x.ndim,) + if isinstance(axis, int): + axis = (axis,) elif axis is None: axis = tuple(range(x.ndim)) - else: - raise ValueError('Invalid type for axis argument') + elif not isinstance(axis, tuple): + raise ValueError('Invalid type for axis argument') # Check axis input - for ax in axis: - if not (0 <= ax < x.ndim): - raise ValueError('axis %d is out of range' % ax) - if not len(axis) == len(set(axis)): - raise ValueError('axis cannot have dupliate entries') + axis = [ax + x.ndim if ax < 0 else ax for ax in axis] + if not all(0 <= ax < x.ndim for ax in axis): + raise ValueError("'axis' entry is out of bounds") + if len(axis) != len(set(axis)): + raise ValueError("duplicate value in 'axis'") # Nothing is masked: return x if m is nomask or not m.any(): return x._data @@ -766,7 +764,7 @@ def compress_nd(x, axis=None): axes = tuple(list(range(ax)) + list(range(ax + 1, x.ndim))) data = data[(slice(None),)*ax + (~m.any(axis=axes),)] return data - + def compress_rowcols(x, axis=None): """ Suppress the rows and/or columns of a 2-D array that contain |