diff options
author | Jaime Fernandez <jaime.frio@gmail.com> | 2015-05-16 22:34:39 -0700 |
---|---|---|
committer | Jaime Fernandez <jaime.frio@gmail.com> | 2015-05-17 09:45:02 -0700 |
commit | aa0532f362dc2fd398a78a838f0cd2c95cad19d8 (patch) | |
tree | 9d71f6b2c87f20d20470f41e51b23ede794b694f /numpy/ma/extras.py | |
parent | 0a02b82ed72f0268875bfc6c70d1e8a8dad6c644 (diff) | |
download | numpy-aa0532f362dc2fd398a78a838f0cd2c95cad19d8.tar.gz |
MANT: handling of axis in np.ma.compress_nd
Modified the handling that np.ma.compress_nd does of its 'axis'
argument, to more closely follow that of ufunc methods, including
error messages for wrong values.
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r-- | numpy/ma/extras.py | 30 |
1 files changed, 14 insertions, 16 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 64a9844cf..3c10af24c 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -718,7 +718,7 @@ def median(a, axis=None, out=None, overwrite_input=False): #.............................................................................. def compress_nd(x, axis=None): """Supress slices from multiple dimensions which contain masked values. - + Parameters ---------- x : array_like, MaskedArray @@ -731,29 +731,27 @@ def compress_nd(x, axis=None): - If axis is a tuple of ints, those are the axes to supress slices from. - If axis is an int, then that is the only axis to supress slices from. - If axis is None, all axis are selected. - + Returns ------- compress_array : ndarray - The compressed array. - """ + The compressed array. + """ x = asarray(x) m = getmask(x) # Set axis to tuple of ints - if isinstance(axis, tuple): - axis = tuple(ax % x.ndim for ax in axis) - elif isinstance(axis, int): - axis = (axis % x.ndim,) + if isinstance(axis, int): + axis = (axis,) elif axis is None: axis = tuple(range(x.ndim)) - else: - raise ValueError('Invalid type for axis argument') + elif not isinstance(axis, tuple): + raise ValueError('Invalid type for axis argument') # Check axis input - for ax in axis: - if not (0 <= ax < x.ndim): - raise ValueError('axis %d is out of range' % ax) - if not len(axis) == len(set(axis)): - raise ValueError('axis cannot have dupliate entries') + axis = [ax + x.ndim if ax < 0 else ax for ax in axis] + if not all(0 <= ax < x.ndim for ax in axis): + raise ValueError("'axis' entry is out of bounds") + if len(axis) != len(set(axis)): + raise ValueError("duplicate value in 'axis'") # Nothing is masked: return x if m is nomask or not m.any(): return x._data @@ -766,7 +764,7 @@ def compress_nd(x, axis=None): axes = tuple(list(range(ax)) + list(range(ax + 1, x.ndim))) data = data[(slice(None),)*ax + (~m.any(axis=axes),)] return data - + def compress_rowcols(x, axis=None): """ Suppress the rows and/or columns of a 2-D array that contain |