From aa0532f362dc2fd398a78a838f0cd2c95cad19d8 Mon Sep 17 00:00:00 2001 From: Jaime Fernandez Date: Sat, 16 May 2015 22:34:39 -0700 Subject: MANT: handling of axis in np.ma.compress_nd Modified the handling that np.ma.compress_nd does of its 'axis' argument, to more closely follow that of ufunc methods, including error messages for wrong values. --- numpy/ma/extras.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) (limited to 'numpy/ma/extras.py') diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 64a9844cf..3c10af24c 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -718,7 +718,7 @@ def median(a, axis=None, out=None, overwrite_input=False): #.............................................................................. def compress_nd(x, axis=None): """Supress slices from multiple dimensions which contain masked values. - + Parameters ---------- x : array_like, MaskedArray @@ -731,29 +731,27 @@ def compress_nd(x, axis=None): - If axis is a tuple of ints, those are the axes to supress slices from. - If axis is an int, then that is the only axis to supress slices from. - If axis is None, all axis are selected. - + Returns ------- compress_array : ndarray - The compressed array. - """ + The compressed array. + """ x = asarray(x) m = getmask(x) # Set axis to tuple of ints - if isinstance(axis, tuple): - axis = tuple(ax % x.ndim for ax in axis) - elif isinstance(axis, int): - axis = (axis % x.ndim,) + if isinstance(axis, int): + axis = (axis,) elif axis is None: axis = tuple(range(x.ndim)) - else: - raise ValueError('Invalid type for axis argument') + elif not isinstance(axis, tuple): + raise ValueError('Invalid type for axis argument') # Check axis input - for ax in axis: - if not (0 <= ax < x.ndim): - raise ValueError('axis %d is out of range' % ax) - if not len(axis) == len(set(axis)): - raise ValueError('axis cannot have dupliate entries') + axis = [ax + x.ndim if ax < 0 else ax for ax in axis] + if not all(0 <= ax < x.ndim for ax in axis): + raise ValueError("'axis' entry is out of bounds") + if len(axis) != len(set(axis)): + raise ValueError("duplicate value in 'axis'") # Nothing is masked: return x if m is nomask or not m.any(): return x._data @@ -766,7 +764,7 @@ def compress_nd(x, axis=None): axes = tuple(list(range(ax)) + list(range(ax + 1, x.ndim))) data = data[(slice(None),)*ax + (~m.any(axis=axes),)] return data - + def compress_rowcols(x, axis=None): """ Suppress the rows and/or columns of a 2-D array that contain -- cgit v1.2.1