diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-03-26 10:22:21 +0100 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-03-28 20:44:36 +0100 |
commit | 4feb97e66000b17baa15b8486a7c1f9608865e78 (patch) | |
tree | 0d2ab43d91bc01706a366fea710536ff1775a3ca /numpy | |
parent | 7bf911585d7f9d56480ee1b3d70fab39fc588350 (diff) | |
download | numpy-4feb97e66000b17baa15b8486a7c1f9608865e78.tar.gz |
MAINT: have _validate_axis delegate to normalize_axis_index
They were only separate before due to a requirement of a better error message
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/numeric.py | 14 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 6 |
2 files changed, 10 insertions, 10 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 01dd46c3c..eabf7962a 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -441,7 +441,7 @@ def count_nonzero(a, axis=None): nullstr = a.dtype.type('') return (a != nullstr).sum(axis=axis, dtype=np.intp) - axis = asarray(_validate_axis(axis, a.ndim, 'axis')) + axis = asarray(_validate_axis(axis, a.ndim)) counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a) if axis.size == 1: @@ -1544,17 +1544,17 @@ def rollaxis(a, axis, start=0): return a.transpose(axes) -def _validate_axis(axis, ndim, argname): +def _validate_axis(axis, ndim, argname=None): try: axis = [operator.index(axis)] except TypeError: axis = list(axis) - axis = [a + ndim if a < 0 else a for a in axis] - if not builtins.all(0 <= a < ndim for a in axis): - raise AxisError('invalid axis for this array in `%s` argument' % - argname) + axis = [normalize_axis_index(ax, ndim, argname) for ax in axis] if len(set(axis)) != len(axis): - raise ValueError('repeated axis in `%s` argument' % argname) + if argname: + raise ValueError('repeated axis in `%s` argument' % argname) + else: + raise ValueError('repeated axis') return axis diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 9f454e52e..56bfa04db 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -2425,11 +2425,11 @@ class TestMoveaxis(TestCase): def test_errors(self): x = np.random.randn(1, 2, 3) - assert_raises_regex(np.AxisError, 'invalid axis .* `source`', + assert_raises_regex(np.AxisError, 'source.*out of bounds', np.moveaxis, x, 3, 0) - assert_raises_regex(np.AxisError, 'invalid axis .* `source`', + assert_raises_regex(np.AxisError, 'source.*out of bounds', np.moveaxis, x, -4, 0) - assert_raises_regex(np.AxisError, 'invalid axis .* `destination`', + assert_raises_regex(np.AxisError, 'destination.*out of bounds', np.moveaxis, x, 0, 5) assert_raises_regex(ValueError, 'repeated axis in `source`', np.moveaxis, x, [0, 0], [0, 1]) |