diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-03-26 11:30:23 +0100 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-03-28 22:23:14 +0100 |
commit | 17466ad1839718c091c629bb647e881b7922a148 (patch) | |
tree | 05335c2935024943fe6913f7d0982603b17de4e6 /numpy/core/numeric.py | |
parent | e3ed705e5d91b584e9191a20f3a4780d354271ff (diff) | |
download | numpy-17466ad1839718c091c629bb647e881b7922a148.tar.gz |
MAINT: Rename _validate_axis, and document it
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 58 |
1 files changed, 50 insertions, 8 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index afa10747c..632266310 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -441,7 +441,7 @@ def count_nonzero(a, axis=None): nullstr = a.dtype.type('') return (a != nullstr).sum(axis=axis, dtype=np.intp) - axis = asarray(_validate_axis(axis, a.ndim)) + axis = asarray(normalize_axis_tuple(axis, a.ndim)) counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a) if axis.size == 1: @@ -1460,7 +1460,7 @@ def roll(a, shift, axis=None): return roll(a.ravel(), shift, 0).reshape(a.shape) else: - axis = _validate_axis(axis, a.ndim, allow_duplicate=True) + axis = normalize_axis_tuple(axis, a.ndim, allow_duplicate=True) broadcasted = broadcast(shift, axis) if broadcasted.ndim > 1: raise ValueError( @@ -1542,15 +1542,57 @@ def rollaxis(a, axis, start=0): return a.transpose(axes) -def _validate_axis(axis, ndim, argname=None, allow_duplicate=False): +def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): + """ + Normalizes an axis argument into a tuple of non-negative integer axes. + + This handles shorthands such as ``1`` and converts them to ``(1,)``, + as well as performing the handling of negative indices covered by + `normalize_axis_index`. + + By default, this forbids axes from being specified multiple times. + + Used internally by multi-axis-checking logic. + + .. versionadded:: 1.13.0 + + Parameters + ---------- + axis : int, iterable of int + The un-normalized index or indices of the axis. + ndim : int + The number of dimensions of the array that `axis` should be normalized + against. + argname : str, optional + A prefix to put before the error message, typically the name of the + argument. + allow_duplicate : bool, optional + If False, the default, disallow an axis from being specified twice. + + Returns + ------- + normalized_axes : tuple of int + The normalized axis index, such that `0 <= normalized_axis < ndim` + + Raises + ------ + AxisError + If any axis provided is out of range + ValueError + If an axis is repeated + + See also + -------- + normalize_axis_index : normalizing a single scalar axis + """ try: axis = [operator.index(axis)] except TypeError: - axis = list(axis) - axis = [normalize_axis_index(ax, ndim, argname) for ax in axis] + axis = tuple(axis) + axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis) if not allow_duplicate and len(set(axis)) != len(axis): if argname: - raise ValueError('repeated axis in `%s` argument' % argname) + raise ValueError('repeated axis in `{}` argument'.format(argname)) else: raise ValueError('repeated axis') return axis @@ -1612,8 +1654,8 @@ def moveaxis(a, source, destination): a = asarray(a) transpose = a.transpose - source = _validate_axis(source, a.ndim, 'source') - destination = _validate_axis(destination, a.ndim, 'destination') + source = normalize_axis_tuple(source, a.ndim, 'source') + destination = normalize_axis_tuple(destination, a.ndim, 'destination') if len(source) != len(destination): raise ValueError('`source` and `destination` arguments must have ' 'the same number of elements') |