summaryrefslogtreecommitdiff
path: root/numpy/ma/extras.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-03-26 10:52:54 +0100
committerEric Wieser <wieser.eric@gmail.com>2017-03-28 20:44:47 +0100
commit539d4f7ef561ac86ea4f3b81bf1eb9b3ac03b67f (patch)
tree16866bdcdcb067dcd9ad3d39d0297b67ccb1b1d2 /numpy/ma/extras.py
parentb6850e9c00055666fe9bd5f984750743ffcb85d2 (diff)
downloadnumpy-539d4f7ef561ac86ea4f3b81bf1eb9b3ac03b67f.tar.gz
MAINT: Reuse _validate_axis in ma
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r--numpy/ma/extras.py16
1 files changed, 5 insertions, 11 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 697565251..8b37df902 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -37,6 +37,7 @@ import numpy as np
from numpy import ndarray, array as nxarray
import numpy.core.umath as umath
from numpy.core.multiarray import normalize_axis_index
+from numpy.core.numeric import _validate_axis
from numpy.lib.function_base import _ureduce
from numpy.lib.index_tricks import AxisConcatenator
@@ -816,18 +817,11 @@ def compress_nd(x, axis=None):
x = asarray(x)
m = getmask(x)
# Set axis to tuple of ints
- if isinstance(axis, int):
- axis = (axis,)
- elif axis is None:
+ if axis is None:
axis = tuple(range(x.ndim))
- elif not isinstance(axis, tuple):
- raise ValueError('Invalid type for axis argument')
- # Check axis input
- axis = [ax + x.ndim if ax < 0 else ax for ax in axis]
- if not all(0 <= ax < x.ndim for ax in axis):
- raise ValueError("'axis' entry is out of bounds")
- if len(axis) != len(set(axis)):
- raise ValueError("duplicate value in 'axis'")
+ else:
+ axis = _validate_axis(axis, x.ndim)
+
# Nothing is masked: return x
if m is nomask or not m.any():
return x._data