summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2017-02-21 15:58:08 -0500
committerGitHub <noreply@github.com>2017-02-21 15:58:08 -0500
commit2aabeafb97bea4e1bfa29d946fbf31e1104e7ae0 (patch)
tree2cd08a2211a3ec1f7403c17dd175aca73a93bccb /numpy/core/numeric.py
parent070b9660282288fa8bb376533667f31613373337 (diff)
parent8d6ec65c925ebef5e0567708de1d16df39077c9d (diff)
downloadnumpy-2aabeafb97bea4e1bfa29d946fbf31e1104e7ae0.tar.gz
Merge pull request #8584 from eric-wieser/resolve_axis
MAINT: Use the same exception for all bad axis requests
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py15
1 files changed, 6 insertions, 9 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 97d19f008..066697f3e 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -17,7 +17,7 @@ from .multiarray import (
inner, int_asbuffer, lexsort, matmul, may_share_memory,
min_scalar_type, ndarray, nditer, nested_iters, promote_types,
putmask, result_type, set_numeric_ops, shares_memory, vdot, where,
- zeros)
+ zeros, normalize_axis_index)
if sys.version_info[0] < 3:
from .multiarray import newbuffer, getbuffer
@@ -27,7 +27,7 @@ from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE,
ERR_DEFAULT, PINF, NAN)
from . import numerictypes
from .numerictypes import longlong, intc, int_, float_, complex_, bool_
-from ._internal import TooHardError
+from ._internal import TooHardError, AxisError
bitwise_not = invert
ufunc = type(sin)
@@ -65,7 +65,7 @@ __all__ = [
'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE',
'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul',
'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT',
- 'TooHardError',
+ 'TooHardError', 'AxisError'
]
@@ -1527,15 +1527,12 @@ def rollaxis(a, axis, start=0):
"""
n = a.ndim
- if axis < 0:
- axis += n
+ axis = normalize_axis_index(axis, n)
if start < 0:
start += n
msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
- if not (0 <= axis < n):
- raise ValueError(msg % ('axis', -n, 'axis', n, axis))
if not (0 <= start < n + 1):
- raise ValueError(msg % ('start', -n, 'start', n + 1, start))
+ raise AxisError(msg % ('start', -n, 'start', n + 1, start))
if axis < start:
# it's been removed
start -= 1
@@ -1554,7 +1551,7 @@ def _validate_axis(axis, ndim, argname):
axis = list(axis)
axis = [a + ndim if a < 0 else a for a in axis]
if not builtins.all(0 <= a < ndim for a in axis):
- raise ValueError('invalid axis for this array in `%s` argument' %
+ raise AxisError('invalid axis for this array in `%s` argument' %
argname)
if len(set(axis)) != len(axis):
raise ValueError('repeated axis in `%s` argument' % argname)