summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2017-03-29 10:44:15 -0400
committerGitHub <noreply@github.com>2017-03-29 10:44:15 -0400
commitab9b15ef0a80ea002d12498554853544bb25ad57 (patch)
tree4fd196fde340c23643ac54cccc2716e35b8dc176 /numpy/core/numeric.py
parent31a8fd34f9acf313a8139b5ad621cb289bbbdb1c (diff)
parent17466ad1839718c091c629bb647e881b7922a148 (diff)
downloadnumpy-ab9b15ef0a80ea002d12498554853544bb25ad57.tar.gz
Merge pull request #8843 from eric-wieser/more-AxisError
MAINT: Use AxisError in more places
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py81
1 files changed, 59 insertions, 22 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 01dd46c3c..632266310 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -441,7 +441,7 @@ def count_nonzero(a, axis=None):
nullstr = a.dtype.type('')
return (a != nullstr).sum(axis=axis, dtype=np.intp)
- axis = asarray(_validate_axis(axis, a.ndim, 'axis'))
+ axis = asarray(normalize_axis_tuple(axis, a.ndim))
counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a)
if axis.size == 1:
@@ -1460,16 +1460,14 @@ def roll(a, shift, axis=None):
return roll(a.ravel(), shift, 0).reshape(a.shape)
else:
+ axis = normalize_axis_tuple(axis, a.ndim, allow_duplicate=True)
broadcasted = broadcast(shift, axis)
if broadcasted.ndim > 1:
raise ValueError(
"'shift' and 'axis' should be scalars or 1D sequences")
shifts = {ax: 0 for ax in range(a.ndim)}
for sh, ax in broadcasted:
- if -a.ndim <= ax < a.ndim:
- shifts[ax % a.ndim] += sh
- else:
- raise ValueError("'axis' entry is out of bounds")
+ shifts[ax] += sh
rolls = [((slice(None), slice(None)),)] * a.ndim
for ax, offset in shifts.items():
@@ -1544,17 +1542,59 @@ def rollaxis(a, axis, start=0):
return a.transpose(axes)
-def _validate_axis(axis, ndim, argname):
+def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
+ """
+ Normalizes an axis argument into a tuple of non-negative integer axes.
+
+ This handles shorthands such as ``1`` and converts them to ``(1,)``,
+ as well as performing the handling of negative indices covered by
+ `normalize_axis_index`.
+
+ By default, this forbids axes from being specified multiple times.
+
+ Used internally by multi-axis-checking logic.
+
+ .. versionadded:: 1.13.0
+
+ Parameters
+ ----------
+ axis : int, iterable of int
+ The un-normalized index or indices of the axis.
+ ndim : int
+ The number of dimensions of the array that `axis` should be normalized
+ against.
+ argname : str, optional
+ A prefix to put before the error message, typically the name of the
+ argument.
+ allow_duplicate : bool, optional
+ If False, the default, disallow an axis from being specified twice.
+
+ Returns
+ -------
+ normalized_axes : tuple of int
+ The normalized axis index, such that `0 <= normalized_axis < ndim`
+
+ Raises
+ ------
+ AxisError
+ If any axis provided is out of range
+ ValueError
+ If an axis is repeated
+
+ See also
+ --------
+ normalize_axis_index : normalizing a single scalar axis
+ """
try:
axis = [operator.index(axis)]
except TypeError:
- axis = list(axis)
- axis = [a + ndim if a < 0 else a for a in axis]
- if not builtins.all(0 <= a < ndim for a in axis):
- raise AxisError('invalid axis for this array in `%s` argument' %
- argname)
- if len(set(axis)) != len(axis):
- raise ValueError('repeated axis in `%s` argument' % argname)
+ axis = tuple(axis)
+ axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis)
+ if not allow_duplicate and len(set(axis)) != len(axis):
+ if argname:
+ raise ValueError('repeated axis in `{}` argument'.format(argname))
+ else:
+ raise ValueError('repeated axis')
return axis
@@ -1614,8 +1654,8 @@ def moveaxis(a, source, destination):
a = asarray(a)
transpose = a.transpose
- source = _validate_axis(source, a.ndim, 'source')
- destination = _validate_axis(destination, a.ndim, 'destination')
+ source = normalize_axis_tuple(source, a.ndim, 'source')
+ destination = normalize_axis_tuple(destination, a.ndim, 'destination')
if len(source) != len(destination):
raise ValueError('`source` and `destination` arguments must have '
'the same number of elements')
@@ -1752,11 +1792,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
a = asarray(a)
b = asarray(b)
# Check axisa and axisb are within bounds
- axis_msg = "'axis{0}' out of bounds"
- if axisa < -a.ndim or axisa >= a.ndim:
- raise ValueError(axis_msg.format('a'))
- if axisb < -b.ndim or axisb >= b.ndim:
- raise ValueError(axis_msg.format('b'))
+ axisa = normalize_axis_index(axisa, a.ndim, msg_prefix='axisa')
+ axisb = normalize_axis_index(axisb, b.ndim, msg_prefix='axisb')
+
# Move working axis to the end of the shape
a = rollaxis(a, axisa, a.ndim)
b = rollaxis(b, axisb, b.ndim)
@@ -1770,8 +1808,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
if a.shape[-1] == 3 or b.shape[-1] == 3:
shape += (3,)
# Check axisc is within bounds
- if axisc < -len(shape) or axisc >= len(shape):
- raise ValueError(axis_msg.format('c'))
+ axisc = normalize_axis_index(axisc, len(shape), msg_prefix='axisc')
dtype = promote_types(a.dtype, b.dtype)
cp = empty(shape, dtype)