summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-03-26 10:22:21 +0100
committerEric Wieser <wieser.eric@gmail.com>2017-03-28 20:44:36 +0100
commit4feb97e66000b17baa15b8486a7c1f9608865e78 (patch)
tree0d2ab43d91bc01706a366fea710536ff1775a3ca /numpy
parent7bf911585d7f9d56480ee1b3d70fab39fc588350 (diff)
downloadnumpy-4feb97e66000b17baa15b8486a7c1f9608865e78.tar.gz
MAINT: have _validate_axis delegate to normalize_axis_index
They were only separate before due to a requirement of a better error message
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/numeric.py14
-rw-r--r--numpy/core/tests/test_numeric.py6
2 files changed, 10 insertions, 10 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 01dd46c3c..eabf7962a 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -441,7 +441,7 @@ def count_nonzero(a, axis=None):
nullstr = a.dtype.type('')
return (a != nullstr).sum(axis=axis, dtype=np.intp)
- axis = asarray(_validate_axis(axis, a.ndim, 'axis'))
+ axis = asarray(_validate_axis(axis, a.ndim))
counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a)
if axis.size == 1:
@@ -1544,17 +1544,17 @@ def rollaxis(a, axis, start=0):
return a.transpose(axes)
-def _validate_axis(axis, ndim, argname):
+def _validate_axis(axis, ndim, argname=None):
try:
axis = [operator.index(axis)]
except TypeError:
axis = list(axis)
- axis = [a + ndim if a < 0 else a for a in axis]
- if not builtins.all(0 <= a < ndim for a in axis):
- raise AxisError('invalid axis for this array in `%s` argument' %
- argname)
+ axis = [normalize_axis_index(ax, ndim, argname) for ax in axis]
if len(set(axis)) != len(axis):
- raise ValueError('repeated axis in `%s` argument' % argname)
+ if argname:
+ raise ValueError('repeated axis in `%s` argument' % argname)
+ else:
+ raise ValueError('repeated axis')
return axis
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 9f454e52e..56bfa04db 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2425,11 +2425,11 @@ class TestMoveaxis(TestCase):
def test_errors(self):
x = np.random.randn(1, 2, 3)
- assert_raises_regex(np.AxisError, 'invalid axis .* `source`',
+ assert_raises_regex(np.AxisError, 'source.*out of bounds',
np.moveaxis, x, 3, 0)
- assert_raises_regex(np.AxisError, 'invalid axis .* `source`',
+ assert_raises_regex(np.AxisError, 'source.*out of bounds',
np.moveaxis, x, -4, 0)
- assert_raises_regex(np.AxisError, 'invalid axis .* `destination`',
+ assert_raises_regex(np.AxisError, 'destination.*out of bounds',
np.moveaxis, x, 0, 5)
assert_raises_regex(ValueError, 'repeated axis in `source`',
np.moveaxis, x, [0, 0], [0, 1])