summaryrefslogtreecommitdiff
path: root/numpy/ma/extras.py
diff options
context:
space:
mode:
authorJaime Fernandez <jaime.frio@gmail.com>2015-05-16 22:34:39 -0700
committerJaime Fernandez <jaime.frio@gmail.com>2015-05-17 09:45:02 -0700
commitaa0532f362dc2fd398a78a838f0cd2c95cad19d8 (patch)
tree9d71f6b2c87f20d20470f41e51b23ede794b694f /numpy/ma/extras.py
parent0a02b82ed72f0268875bfc6c70d1e8a8dad6c644 (diff)
downloadnumpy-aa0532f362dc2fd398a78a838f0cd2c95cad19d8.tar.gz
MANT: handling of axis in np.ma.compress_nd
Modified the handling that np.ma.compress_nd does of its 'axis' argument, to more closely follow that of ufunc methods, including error messages for wrong values.
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r--numpy/ma/extras.py30
1 files changed, 14 insertions, 16 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 64a9844cf..3c10af24c 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -718,7 +718,7 @@ def median(a, axis=None, out=None, overwrite_input=False):
#..............................................................................
def compress_nd(x, axis=None):
"""Supress slices from multiple dimensions which contain masked values.
-
+
Parameters
----------
x : array_like, MaskedArray
@@ -731,29 +731,27 @@ def compress_nd(x, axis=None):
- If axis is a tuple of ints, those are the axes to supress slices from.
- If axis is an int, then that is the only axis to supress slices from.
- If axis is None, all axis are selected.
-
+
Returns
-------
compress_array : ndarray
- The compressed array.
- """
+ The compressed array.
+ """
x = asarray(x)
m = getmask(x)
# Set axis to tuple of ints
- if isinstance(axis, tuple):
- axis = tuple(ax % x.ndim for ax in axis)
- elif isinstance(axis, int):
- axis = (axis % x.ndim,)
+ if isinstance(axis, int):
+ axis = (axis,)
elif axis is None:
axis = tuple(range(x.ndim))
- else:
- raise ValueError('Invalid type for axis argument')
+ elif not isinstance(axis, tuple):
+ raise ValueError('Invalid type for axis argument')
# Check axis input
- for ax in axis:
- if not (0 <= ax < x.ndim):
- raise ValueError('axis %d is out of range' % ax)
- if not len(axis) == len(set(axis)):
- raise ValueError('axis cannot have dupliate entries')
+ axis = [ax + x.ndim if ax < 0 else ax for ax in axis]
+ if not all(0 <= ax < x.ndim for ax in axis):
+ raise ValueError("'axis' entry is out of bounds")
+ if len(axis) != len(set(axis)):
+ raise ValueError("duplicate value in 'axis'")
# Nothing is masked: return x
if m is nomask or not m.any():
return x._data
@@ -766,7 +764,7 @@ def compress_nd(x, axis=None):
axes = tuple(list(range(ax)) + list(range(ax + 1, x.ndim)))
data = data[(slice(None),)*ax + (~m.any(axis=axes),)]
return data
-
+
def compress_rowcols(x, axis=None):
"""
Suppress the rows and/or columns of a 2-D array that contain