summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-03-26 11:30:23 +0100
committerEric Wieser <wieser.eric@gmail.com>2017-03-28 22:23:14 +0100
commit17466ad1839718c091c629bb647e881b7922a148 (patch)
tree05335c2935024943fe6913f7d0982603b17de4e6 /numpy/core/numeric.py
parente3ed705e5d91b584e9191a20f3a4780d354271ff (diff)
downloadnumpy-17466ad1839718c091c629bb647e881b7922a148.tar.gz
MAINT: Rename _validate_axis, and document it
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py58
1 files changed, 50 insertions, 8 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index afa10747c..632266310 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -441,7 +441,7 @@ def count_nonzero(a, axis=None):
nullstr = a.dtype.type('')
return (a != nullstr).sum(axis=axis, dtype=np.intp)
- axis = asarray(_validate_axis(axis, a.ndim))
+ axis = asarray(normalize_axis_tuple(axis, a.ndim))
counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a)
if axis.size == 1:
@@ -1460,7 +1460,7 @@ def roll(a, shift, axis=None):
return roll(a.ravel(), shift, 0).reshape(a.shape)
else:
- axis = _validate_axis(axis, a.ndim, allow_duplicate=True)
+ axis = normalize_axis_tuple(axis, a.ndim, allow_duplicate=True)
broadcasted = broadcast(shift, axis)
if broadcasted.ndim > 1:
raise ValueError(
@@ -1542,15 +1542,57 @@ def rollaxis(a, axis, start=0):
return a.transpose(axes)
-def _validate_axis(axis, ndim, argname=None, allow_duplicate=False):
+def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
+ """
+ Normalizes an axis argument into a tuple of non-negative integer axes.
+
+ This handles shorthands such as ``1`` and converts them to ``(1,)``,
+ as well as performing the handling of negative indices covered by
+ `normalize_axis_index`.
+
+ By default, this forbids axes from being specified multiple times.
+
+ Used internally by multi-axis-checking logic.
+
+ .. versionadded:: 1.13.0
+
+ Parameters
+ ----------
+ axis : int, iterable of int
+ The un-normalized index or indices of the axis.
+ ndim : int
+ The number of dimensions of the array that `axis` should be normalized
+ against.
+ argname : str, optional
+ A prefix to put before the error message, typically the name of the
+ argument.
+ allow_duplicate : bool, optional
+ If False, the default, disallow an axis from being specified twice.
+
+ Returns
+ -------
+ normalized_axes : tuple of int
+ The normalized axis index, such that `0 <= normalized_axis < ndim`
+
+ Raises
+ ------
+ AxisError
+ If any axis provided is out of range
+ ValueError
+ If an axis is repeated
+
+ See also
+ --------
+ normalize_axis_index : normalizing a single scalar axis
+ """
try:
axis = [operator.index(axis)]
except TypeError:
- axis = list(axis)
- axis = [normalize_axis_index(ax, ndim, argname) for ax in axis]
+ axis = tuple(axis)
+ axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis)
if not allow_duplicate and len(set(axis)) != len(axis):
if argname:
- raise ValueError('repeated axis in `%s` argument' % argname)
+ raise ValueError('repeated axis in `{}` argument'.format(argname))
else:
raise ValueError('repeated axis')
return axis
@@ -1612,8 +1654,8 @@ def moveaxis(a, source, destination):
a = asarray(a)
transpose = a.transpose
- source = _validate_axis(source, a.ndim, 'source')
- destination = _validate_axis(destination, a.ndim, 'destination')
+ source = normalize_axis_tuple(source, a.ndim, 'source')
+ destination = normalize_axis_tuple(destination, a.ndim, 'destination')
if len(source) != len(destination):
raise ValueError('`source` and `destination` arguments must have '
'the same number of elements')