From 4feb97e66000b17baa15b8486a7c1f9608865e78 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Sun, 26 Mar 2017 10:22:21 +0100 Subject: MAINT: have _validate_axis delegate to normalize_axis_index They were only separate before due to a requirement of a better error message --- numpy/core/numeric.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'numpy/core/numeric.py') diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 01dd46c3c..eabf7962a 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -441,7 +441,7 @@ def count_nonzero(a, axis=None): nullstr = a.dtype.type('') return (a != nullstr).sum(axis=axis, dtype=np.intp) - axis = asarray(_validate_axis(axis, a.ndim, 'axis')) + axis = asarray(_validate_axis(axis, a.ndim)) counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a) if axis.size == 1: @@ -1544,17 +1544,17 @@ def rollaxis(a, axis, start=0): return a.transpose(axes) -def _validate_axis(axis, ndim, argname): +def _validate_axis(axis, ndim, argname=None): try: axis = [operator.index(axis)] except TypeError: axis = list(axis) - axis = [a + ndim if a < 0 else a for a in axis] - if not builtins.all(0 <= a < ndim for a in axis): - raise AxisError('invalid axis for this array in `%s` argument' % - argname) + axis = [normalize_axis_index(ax, ndim, argname) for ax in axis] if len(set(axis)) != len(axis): - raise ValueError('repeated axis in `%s` argument' % argname) + if argname: + raise ValueError('repeated axis in `%s` argument' % argname) + else: + raise ValueError('repeated axis') return axis -- cgit v1.2.1 From 09d4d35bb979dd0d5d0781946f59f362765eca66 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Sun, 26 Mar 2017 10:23:53 +0100 Subject: MAINT: Use normalize_axis_index in cross Previously we did not, as this wouldn't say which axis argument was the cause of the problem. --- numpy/core/numeric.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) (limited to 'numpy/core/numeric.py') diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index eabf7962a..e1352ee6c 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1752,11 +1752,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): a = asarray(a) b = asarray(b) # Check axisa and axisb are within bounds - axis_msg = "'axis{0}' out of bounds" - if axisa < -a.ndim or axisa >= a.ndim: - raise ValueError(axis_msg.format('a')) - if axisb < -b.ndim or axisb >= b.ndim: - raise ValueError(axis_msg.format('b')) + axisa = normalize_axis_index(axisa, a.ndim, msg_prefix='axisa') + axisb = normalize_axis_index(axisb, b.ndim, msg_prefix='axisb') + # Move working axis to the end of the shape a = rollaxis(a, axisa, a.ndim) b = rollaxis(b, axisb, b.ndim) @@ -1770,8 +1768,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): if a.shape[-1] == 3 or b.shape[-1] == 3: shape += (3,) # Check axisc is within bounds - if axisc < -len(shape) or axisc >= len(shape): - raise ValueError(axis_msg.format('c')) + axisc = normalize_axis_index(axisc, len(shape), msg_prefix='axisc') dtype = promote_types(a.dtype, b.dtype) cp = empty(shape, dtype) -- cgit v1.2.1 From b6850e9c00055666fe9bd5f984750743ffcb85d2 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Sun, 26 Mar 2017 10:50:54 +0100 Subject: MAINT: use normalize_axis_index in np.roll Interestingly np.roll allows axes to be repeated. --- numpy/core/numeric.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'numpy/core/numeric.py') diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index e1352ee6c..afa10747c 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1460,16 +1460,14 @@ def roll(a, shift, axis=None): return roll(a.ravel(), shift, 0).reshape(a.shape) else: + axis = _validate_axis(axis, a.ndim, allow_duplicate=True) broadcasted = broadcast(shift, axis) if broadcasted.ndim > 1: raise ValueError( "'shift' and 'axis' should be scalars or 1D sequences") shifts = {ax: 0 for ax in range(a.ndim)} for sh, ax in broadcasted: - if -a.ndim <= ax < a.ndim: - shifts[ax % a.ndim] += sh - else: - raise ValueError("'axis' entry is out of bounds") + shifts[ax] += sh rolls = [((slice(None), slice(None)),)] * a.ndim for ax, offset in shifts.items(): @@ -1544,13 +1542,13 @@ def rollaxis(a, axis, start=0): return a.transpose(axes) -def _validate_axis(axis, ndim, argname=None): +def _validate_axis(axis, ndim, argname=None, allow_duplicate=False): try: axis = [operator.index(axis)] except TypeError: axis = list(axis) axis = [normalize_axis_index(ax, ndim, argname) for ax in axis] - if len(set(axis)) != len(axis): + if not allow_duplicate and len(set(axis)) != len(axis): if argname: raise ValueError('repeated axis in `%s` argument' % argname) else: -- cgit v1.2.1 From 17466ad1839718c091c629bb647e881b7922a148 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Sun, 26 Mar 2017 11:30:23 +0100 Subject: MAINT: Rename _validate_axis, and document it --- numpy/core/numeric.py | 58 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 8 deletions(-) (limited to 'numpy/core/numeric.py') diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index afa10747c..632266310 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -441,7 +441,7 @@ def count_nonzero(a, axis=None): nullstr = a.dtype.type('') return (a != nullstr).sum(axis=axis, dtype=np.intp) - axis = asarray(_validate_axis(axis, a.ndim)) + axis = asarray(normalize_axis_tuple(axis, a.ndim)) counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a) if axis.size == 1: @@ -1460,7 +1460,7 @@ def roll(a, shift, axis=None): return roll(a.ravel(), shift, 0).reshape(a.shape) else: - axis = _validate_axis(axis, a.ndim, allow_duplicate=True) + axis = normalize_axis_tuple(axis, a.ndim, allow_duplicate=True) broadcasted = broadcast(shift, axis) if broadcasted.ndim > 1: raise ValueError( @@ -1542,15 +1542,57 @@ def rollaxis(a, axis, start=0): return a.transpose(axes) -def _validate_axis(axis, ndim, argname=None, allow_duplicate=False): +def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): + """ + Normalizes an axis argument into a tuple of non-negative integer axes. + + This handles shorthands such as ``1`` and converts them to ``(1,)``, + as well as performing the handling of negative indices covered by + `normalize_axis_index`. + + By default, this forbids axes from being specified multiple times. + + Used internally by multi-axis-checking logic. + + .. versionadded:: 1.13.0 + + Parameters + ---------- + axis : int, iterable of int + The un-normalized index or indices of the axis. + ndim : int + The number of dimensions of the array that `axis` should be normalized + against. + argname : str, optional + A prefix to put before the error message, typically the name of the + argument. + allow_duplicate : bool, optional + If False, the default, disallow an axis from being specified twice. + + Returns + ------- + normalized_axes : tuple of int + The normalized axis index, such that `0 <= normalized_axis < ndim` + + Raises + ------ + AxisError + If any axis provided is out of range + ValueError + If an axis is repeated + + See also + -------- + normalize_axis_index : normalizing a single scalar axis + """ try: axis = [operator.index(axis)] except TypeError: - axis = list(axis) - axis = [normalize_axis_index(ax, ndim, argname) for ax in axis] + axis = tuple(axis) + axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis) if not allow_duplicate and len(set(axis)) != len(axis): if argname: - raise ValueError('repeated axis in `%s` argument' % argname) + raise ValueError('repeated axis in `{}` argument'.format(argname)) else: raise ValueError('repeated axis') return axis @@ -1612,8 +1654,8 @@ def moveaxis(a, source, destination): a = asarray(a) transpose = a.transpose - source = _validate_axis(source, a.ndim, 'source') - destination = _validate_axis(destination, a.ndim, 'destination') + source = normalize_axis_tuple(source, a.ndim, 'source') + destination = normalize_axis_tuple(destination, a.ndim, 'destination') if len(source) != len(destination): raise ValueError('`source` and `destination` arguments must have ' 'the same number of elements') -- cgit v1.2.1