summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/numeric.py58
-rw-r--r--numpy/lib/function_base.py4
-rw-r--r--numpy/ma/core.py4
-rw-r--r--numpy/ma/extras.py4
4 files changed, 56 insertions, 14 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index afa10747c..632266310 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -441,7 +441,7 @@ def count_nonzero(a, axis=None):
nullstr = a.dtype.type('')
return (a != nullstr).sum(axis=axis, dtype=np.intp)
- axis = asarray(_validate_axis(axis, a.ndim))
+ axis = asarray(normalize_axis_tuple(axis, a.ndim))
counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a)
if axis.size == 1:
@@ -1460,7 +1460,7 @@ def roll(a, shift, axis=None):
return roll(a.ravel(), shift, 0).reshape(a.shape)
else:
- axis = _validate_axis(axis, a.ndim, allow_duplicate=True)
+ axis = normalize_axis_tuple(axis, a.ndim, allow_duplicate=True)
broadcasted = broadcast(shift, axis)
if broadcasted.ndim > 1:
raise ValueError(
@@ -1542,15 +1542,57 @@ def rollaxis(a, axis, start=0):
return a.transpose(axes)
-def _validate_axis(axis, ndim, argname=None, allow_duplicate=False):
+def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
+ """
+ Normalizes an axis argument into a tuple of non-negative integer axes.
+
+ This handles shorthands such as ``1`` and converts them to ``(1,)``,
+ as well as performing the handling of negative indices covered by
+ `normalize_axis_index`.
+
+ By default, this forbids axes from being specified multiple times.
+
+ Used internally by multi-axis-checking logic.
+
+ .. versionadded:: 1.13.0
+
+ Parameters
+ ----------
+ axis : int, iterable of int
+ The un-normalized index or indices of the axis.
+ ndim : int
+ The number of dimensions of the array that `axis` should be normalized
+ against.
+ argname : str, optional
+ A prefix to put before the error message, typically the name of the
+ argument.
+ allow_duplicate : bool, optional
+ If False, the default, disallow an axis from being specified twice.
+
+ Returns
+ -------
+ normalized_axes : tuple of int
+ The normalized axis index, such that `0 <= normalized_axis < ndim`
+
+ Raises
+ ------
+ AxisError
+ If any axis provided is out of range
+ ValueError
+ If an axis is repeated
+
+ See also
+ --------
+ normalize_axis_index : normalizing a single scalar axis
+ """
try:
axis = [operator.index(axis)]
except TypeError:
- axis = list(axis)
- axis = [normalize_axis_index(ax, ndim, argname) for ax in axis]
+ axis = tuple(axis)
+ axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis)
if not allow_duplicate and len(set(axis)) != len(axis):
if argname:
- raise ValueError('repeated axis in `%s` argument' % argname)
+ raise ValueError('repeated axis in `{}` argument'.format(argname))
else:
raise ValueError('repeated axis')
return axis
@@ -1612,8 +1654,8 @@ def moveaxis(a, source, destination):
a = asarray(a)
transpose = a.transpose
- source = _validate_axis(source, a.ndim, 'source')
- destination = _validate_axis(destination, a.ndim, 'destination')
+ source = normalize_axis_tuple(source, a.ndim, 'source')
+ destination = normalize_axis_tuple(destination, a.ndim, 'destination')
if len(source) != len(destination):
raise ValueError('`source` and `destination` arguments must have '
'the same number of elements')
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 5b3af311b..2e6de9ec7 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -1680,7 +1680,7 @@ def gradient(f, *varargs, **kwargs):
if axes is None:
axes = tuple(range(N))
else:
- axes = _nx._validate_axis(axes, N)
+ axes = _nx.normalize_axis_tuple(axes, N)
len_axes = len(axes)
n = len(varargs)
@@ -3972,7 +3972,7 @@ def _ureduce(a, func, **kwargs):
if axis is not None:
keepdim = list(a.shape)
nd = a.ndim
- axis = _nx._validate_axis(axis, nd)
+ axis = _nx.normalize_axis_tuple(axis, nd)
for ax in axis:
keepdim[ax] = 1
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 6bc1bc623..6b1ea7806 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -43,7 +43,7 @@ from numpy.compat import (
)
from numpy import expand_dims as n_expand_dims
from numpy.core.multiarray import normalize_axis_index
-from numpy.core.numeric import _validate_axis
+from numpy.core.numeric import normalize_axis_tuple
if sys.version_info[0] >= 3:
@@ -4370,7 +4370,7 @@ class MaskedArray(ndarray):
return np.array(self.size, dtype=np.intp, ndmin=self.ndim)
return self.size
- axes = _validate_axis(axis, self.ndim)
+ axes = normalize_axis_tuple(axis, self.ndim)
items = 1
for ax in axes:
items *= self.shape[ax]
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 8b37df902..4955d25eb 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -37,7 +37,7 @@ import numpy as np
from numpy import ndarray, array as nxarray
import numpy.core.umath as umath
from numpy.core.multiarray import normalize_axis_index
-from numpy.core.numeric import _validate_axis
+from numpy.core.numeric import normalize_axis_tuple
from numpy.lib.function_base import _ureduce
from numpy.lib.index_tricks import AxisConcatenator
@@ -820,7 +820,7 @@ def compress_nd(x, axis=None):
if axis is None:
axis = tuple(range(x.ndim))
else:
- axis = _validate_axis(axis, x.ndim)
+ axis = normalize_axis_tuple(axis, x.ndim)
# Nothing is masked: return x
if m is nomask or not m.any():