diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/numeric.py | 58 | ||||
-rw-r--r-- | numpy/lib/function_base.py | 4 | ||||
-rw-r--r-- | numpy/ma/core.py | 4 | ||||
-rw-r--r-- | numpy/ma/extras.py | 4 |
4 files changed, 56 insertions, 14 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index afa10747c..632266310 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -441,7 +441,7 @@ def count_nonzero(a, axis=None): nullstr = a.dtype.type('') return (a != nullstr).sum(axis=axis, dtype=np.intp) - axis = asarray(_validate_axis(axis, a.ndim)) + axis = asarray(normalize_axis_tuple(axis, a.ndim)) counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a) if axis.size == 1: @@ -1460,7 +1460,7 @@ def roll(a, shift, axis=None): return roll(a.ravel(), shift, 0).reshape(a.shape) else: - axis = _validate_axis(axis, a.ndim, allow_duplicate=True) + axis = normalize_axis_tuple(axis, a.ndim, allow_duplicate=True) broadcasted = broadcast(shift, axis) if broadcasted.ndim > 1: raise ValueError( @@ -1542,15 +1542,57 @@ def rollaxis(a, axis, start=0): return a.transpose(axes) -def _validate_axis(axis, ndim, argname=None, allow_duplicate=False): +def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): + """ + Normalizes an axis argument into a tuple of non-negative integer axes. + + This handles shorthands such as ``1`` and converts them to ``(1,)``, + as well as performing the handling of negative indices covered by + `normalize_axis_index`. + + By default, this forbids axes from being specified multiple times. + + Used internally by multi-axis-checking logic. + + .. versionadded:: 1.13.0 + + Parameters + ---------- + axis : int, iterable of int + The un-normalized index or indices of the axis. + ndim : int + The number of dimensions of the array that `axis` should be normalized + against. + argname : str, optional + A prefix to put before the error message, typically the name of the + argument. + allow_duplicate : bool, optional + If False, the default, disallow an axis from being specified twice. + + Returns + ------- + normalized_axes : tuple of int + The normalized axis index, such that `0 <= normalized_axis < ndim` + + Raises + ------ + AxisError + If any axis provided is out of range + ValueError + If an axis is repeated + + See also + -------- + normalize_axis_index : normalizing a single scalar axis + """ try: axis = [operator.index(axis)] except TypeError: - axis = list(axis) - axis = [normalize_axis_index(ax, ndim, argname) for ax in axis] + axis = tuple(axis) + axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis) if not allow_duplicate and len(set(axis)) != len(axis): if argname: - raise ValueError('repeated axis in `%s` argument' % argname) + raise ValueError('repeated axis in `{}` argument'.format(argname)) else: raise ValueError('repeated axis') return axis @@ -1612,8 +1654,8 @@ def moveaxis(a, source, destination): a = asarray(a) transpose = a.transpose - source = _validate_axis(source, a.ndim, 'source') - destination = _validate_axis(destination, a.ndim, 'destination') + source = normalize_axis_tuple(source, a.ndim, 'source') + destination = normalize_axis_tuple(destination, a.ndim, 'destination') if len(source) != len(destination): raise ValueError('`source` and `destination` arguments must have ' 'the same number of elements') diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 5b3af311b..2e6de9ec7 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -1680,7 +1680,7 @@ def gradient(f, *varargs, **kwargs): if axes is None: axes = tuple(range(N)) else: - axes = _nx._validate_axis(axes, N) + axes = _nx.normalize_axis_tuple(axes, N) len_axes = len(axes) n = len(varargs) @@ -3972,7 +3972,7 @@ def _ureduce(a, func, **kwargs): if axis is not None: keepdim = list(a.shape) nd = a.ndim - axis = _nx._validate_axis(axis, nd) + axis = _nx.normalize_axis_tuple(axis, nd) for ax in axis: keepdim[ax] = 1 diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 6bc1bc623..6b1ea7806 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -43,7 +43,7 @@ from numpy.compat import ( ) from numpy import expand_dims as n_expand_dims from numpy.core.multiarray import normalize_axis_index -from numpy.core.numeric import _validate_axis +from numpy.core.numeric import normalize_axis_tuple if sys.version_info[0] >= 3: @@ -4370,7 +4370,7 @@ class MaskedArray(ndarray): return np.array(self.size, dtype=np.intp, ndmin=self.ndim) return self.size - axes = _validate_axis(axis, self.ndim) + axes = normalize_axis_tuple(axis, self.ndim) items = 1 for ax in axes: items *= self.shape[ax] diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 8b37df902..4955d25eb 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -37,7 +37,7 @@ import numpy as np from numpy import ndarray, array as nxarray import numpy.core.umath as umath from numpy.core.multiarray import normalize_axis_index -from numpy.core.numeric import _validate_axis +from numpy.core.numeric import normalize_axis_tuple from numpy.lib.function_base import _ureduce from numpy.lib.index_tricks import AxisConcatenator @@ -820,7 +820,7 @@ def compress_nd(x, axis=None): if axis is None: axis = tuple(range(x.ndim)) else: - axis = _validate_axis(axis, x.ndim) + axis = normalize_axis_tuple(axis, x.ndim) # Nothing is masked: return x if m is nomask or not m.any(): |