summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2017-02-21 15:58:08 -0500
committerGitHub <noreply@github.com>2017-02-21 15:58:08 -0500
commit2aabeafb97bea4e1bfa29d946fbf31e1104e7ae0 (patch)
tree2cd08a2211a3ec1f7403c17dd175aca73a93bccb /numpy
parent070b9660282288fa8bb376533667f31613373337 (diff)
parent8d6ec65c925ebef5e0567708de1d16df39077c9d (diff)
downloadnumpy-2aabeafb97bea4e1bfa29d946fbf31e1104e7ae0.tar.gz
Merge pull request #8584 from eric-wieser/resolve_axis
MAINT: Use the same exception for all bad axis requests
Diffstat (limited to 'numpy')
-rw-r--r--numpy/add_newdocs.py45
-rw-r--r--numpy/core/_internal.py3
-rw-r--r--numpy/core/numeric.py15
-rw-r--r--numpy/core/shape_base.py7
-rw-r--r--numpy/core/src/multiarray/common.h36
-rw-r--r--numpy/core/src/multiarray/conversion_utils.c20
-rw-r--r--numpy/core/src/multiarray/ctors.c8
-rw-r--r--numpy/core/src/multiarray/item_selection.c25
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c29
-rw-r--r--numpy/core/src/multiarray/shape.c7
-rw-r--r--numpy/core/src/umath/ufunc_object.c16
-rw-r--r--numpy/core/tests/test_multiarray.py16
-rw-r--r--numpy/core/tests/test_numeric.py16
-rw-r--r--numpy/core/tests/test_shape_base.py8
-rw-r--r--numpy/core/tests/test_ufunc.py8
-rw-r--r--numpy/lib/function_base.py13
-rw-r--r--numpy/lib/shape_base.py9
-rw-r--r--numpy/lib/tests/test_function_base.py4
-rw-r--r--numpy/linalg/linalg.py10
-rw-r--r--numpy/linalg/tests/test_linalg.py4
-rw-r--r--numpy/ma/core.py15
-rw-r--r--numpy/ma/extras.py11
-rw-r--r--numpy/ma/tests/test_core.py8
-rw-r--r--numpy/polynomial/chebyshev.py11
-rw-r--r--numpy/polynomial/hermite.py11
-rw-r--r--numpy/polynomial/hermite_e.py11
-rw-r--r--numpy/polynomial/laguerre.py11
-rw-r--r--numpy/polynomial/legendre.py11
-rw-r--r--numpy/polynomial/polynomial.py11
29 files changed, 199 insertions, 200 deletions
diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py
index 09f4e40c4..3916d1304 100644
--- a/numpy/add_newdocs.py
+++ b/numpy/add_newdocs.py
@@ -6728,6 +6728,51 @@ add_newdoc('numpy.core.multiarray', 'busday_count',
53
""")
+add_newdoc('numpy.core.multiarray', 'normalize_axis_index',
+ """
+ normalize_axis_index(axis, ndim)
+
+ Normalizes an axis index, `axis`, such that is a valid positive index into
+ the shape of array with `ndim` dimensions. Raises an AxisError with an
+ appropriate message if this is not possible.
+
+ Used internally by all axis-checking logic.
+
+ .. versionadded:: 1.13.0
+
+ Parameters
+ ----------
+ axis : int
+ The un-normalized index of the axis. Can be negative
+ ndim : int
+ The number of dimensions of the array that `axis` should be normalized
+ against
+
+ Returns
+ -------
+ normalized_axis : int
+ The normalized axis index, such that `0 <= normalized_axis < ndim`
+
+ Raises
+ ------
+ AxisError
+ If the axis index is invalid, when `-ndim <= axis < ndim` is false.
+
+ Examples
+ --------
+ >>> normalize_axis_index(0, ndim=3)
+ 0
+ >>> normalize_axis_index(1, ndim=3)
+ 1
+ >>> normalize_axis_index(-1, ndim=3)
+ 2
+
+ >>> normalize_axis_index(3, ndim=3)
+ Traceback (most recent call last):
+ ...
+ AxisError: axis 3 is out of bounds for array of dimension 3
+ """)
+
##############################################################################
#
# nd_grid instances
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index 741c8bb5f..d73cdcc55 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -630,3 +630,6 @@ def _gcd(a, b):
# Exception used in shares_memory()
class TooHardError(RuntimeError):
pass
+
+class AxisError(ValueError, IndexError):
+ pass
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 97d19f008..066697f3e 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -17,7 +17,7 @@ from .multiarray import (
inner, int_asbuffer, lexsort, matmul, may_share_memory,
min_scalar_type, ndarray, nditer, nested_iters, promote_types,
putmask, result_type, set_numeric_ops, shares_memory, vdot, where,
- zeros)
+ zeros, normalize_axis_index)
if sys.version_info[0] < 3:
from .multiarray import newbuffer, getbuffer
@@ -27,7 +27,7 @@ from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE,
ERR_DEFAULT, PINF, NAN)
from . import numerictypes
from .numerictypes import longlong, intc, int_, float_, complex_, bool_
-from ._internal import TooHardError
+from ._internal import TooHardError, AxisError
bitwise_not = invert
ufunc = type(sin)
@@ -65,7 +65,7 @@ __all__ = [
'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE',
'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul',
'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT',
- 'TooHardError',
+ 'TooHardError', 'AxisError'
]
@@ -1527,15 +1527,12 @@ def rollaxis(a, axis, start=0):
"""
n = a.ndim
- if axis < 0:
- axis += n
+ axis = normalize_axis_index(axis, n)
if start < 0:
start += n
msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
- if not (0 <= axis < n):
- raise ValueError(msg % ('axis', -n, 'axis', n, axis))
if not (0 <= start < n + 1):
- raise ValueError(msg % ('start', -n, 'start', n + 1, start))
+ raise AxisError(msg % ('start', -n, 'start', n + 1, start))
if axis < start:
# it's been removed
start -= 1
@@ -1554,7 +1551,7 @@ def _validate_axis(axis, ndim, argname):
axis = list(axis)
axis = [a + ndim if a < 0 else a for a in axis]
if not builtins.all(0 <= a < ndim for a in axis):
- raise ValueError('invalid axis for this array in `%s` argument' %
+ raise AxisError('invalid axis for this array in `%s` argument' %
argname)
if len(set(axis)) != len(axis):
raise ValueError('repeated axis in `%s` argument' % argname)
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py
index 70afdb746..58b0dcaac 100644
--- a/numpy/core/shape_base.py
+++ b/numpy/core/shape_base.py
@@ -5,6 +5,7 @@ __all__ = ['atleast_1d', 'atleast_2d', 'atleast_3d', 'vstack', 'hstack',
from . import numeric as _nx
from .numeric import asanyarray, newaxis
+from .multiarray import normalize_axis_index
def atleast_1d(*arys):
"""
@@ -347,11 +348,7 @@ def stack(arrays, axis=0):
raise ValueError('all input arrays must have the same shape')
result_ndim = arrays[0].ndim + 1
- if not -result_ndim <= axis < result_ndim:
- msg = 'axis {0} out of bounds [-{1}, {1})'.format(axis, result_ndim)
- raise IndexError(msg)
- if axis < 0:
- axis += result_ndim
+ axis = normalize_axis_index(axis, result_ndim)
sl = (slice(None),) * axis + (_nx.newaxis,)
expanded_arrays = [arr[sl] for arr in arrays]
diff --git a/numpy/core/src/multiarray/common.h b/numpy/core/src/multiarray/common.h
index 5e14b80a7..625ca9d76 100644
--- a/numpy/core/src/multiarray/common.h
+++ b/numpy/core/src/multiarray/common.h
@@ -134,6 +134,42 @@ check_and_adjust_index(npy_intp *index, npy_intp max_item, int axis,
return 0;
}
+/*
+ * Returns -1 and sets an exception if *axis is an invalid axis for
+ * an array of dimension ndim, otherwise adjusts it in place to be
+ * 0 <= *axis < ndim, and returns 0.
+ */
+static NPY_INLINE int
+check_and_adjust_axis(int *axis, int ndim)
+{
+ /* Check that index is valid, taking into account negative indices */
+ if (NPY_UNLIKELY((*axis < -ndim) || (*axis >= ndim))) {
+ /*
+ * Load the exception type, if we don't already have it. Unfortunately
+ * we don't have access to npy_cache_import here
+ */
+ static PyObject *AxisError_cls = NULL;
+ if (AxisError_cls == NULL) {
+ PyObject *mod = PyImport_ImportModule("numpy.core._internal");
+
+ if (mod != NULL) {
+ AxisError_cls = PyObject_GetAttrString(mod, "AxisError");
+ Py_DECREF(mod);
+ }
+ }
+
+ PyErr_Format(AxisError_cls,
+ "axis %d is out of bounds for array of dimension %d",
+ *axis, ndim);
+ return -1;
+ }
+ /* adjust negative indices */
+ if (*axis < 0) {
+ *axis += ndim;
+ }
+ return 0;
+}
+
/*
* return true if pointer is aligned to 'alignment'
diff --git a/numpy/core/src/multiarray/conversion_utils.c b/numpy/core/src/multiarray/conversion_utils.c
index c016bb8d1..8ed08a366 100644
--- a/numpy/core/src/multiarray/conversion_utils.c
+++ b/numpy/core/src/multiarray/conversion_utils.c
@@ -259,17 +259,10 @@ PyArray_ConvertMultiAxis(PyObject *axis_in, int ndim, npy_bool *out_axis_flags)
PyObject *tmp = PyTuple_GET_ITEM(axis_in, i);
int axis = PyArray_PyIntAsInt_ErrMsg(tmp,
"integers are required for the axis tuple elements");
- int axis_orig = axis;
if (error_converting(axis)) {
return NPY_FAIL;
}
- if (axis < 0) {
- axis += ndim;
- }
- if (axis < 0 || axis >= ndim) {
- PyErr_Format(PyExc_ValueError,
- "'axis' entry %d is out of bounds [-%d, %d)",
- axis_orig, ndim, ndim);
+ if (check_and_adjust_axis(&axis, ndim) < 0) {
return NPY_FAIL;
}
if (out_axis_flags[axis]) {
@@ -284,20 +277,16 @@ PyArray_ConvertMultiAxis(PyObject *axis_in, int ndim, npy_bool *out_axis_flags)
}
/* Try to interpret axis as an integer */
else {
- int axis, axis_orig;
+ int axis;
memset(out_axis_flags, 0, ndim);
axis = PyArray_PyIntAsInt_ErrMsg(axis_in,
"an integer is required for the axis");
- axis_orig = axis;
if (error_converting(axis)) {
return NPY_FAIL;
}
- if (axis < 0) {
- axis += ndim;
- }
/*
* Special case letting axis={-1,0} slip through for scalars,
* for backwards compatibility reasons.
@@ -306,10 +295,7 @@ PyArray_ConvertMultiAxis(PyObject *axis_in, int ndim, npy_bool *out_axis_flags)
return NPY_SUCCEED;
}
- if (axis < 0 || axis >= ndim) {
- PyErr_Format(PyExc_ValueError,
- "'axis' entry %d is out of bounds [-%d, %d)",
- axis_orig, ndim, ndim);
+ if (check_and_adjust_axis(&axis, ndim) < 0) {
return NPY_FAIL;
}
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c
index 349b59c5f..ee6b66eef 100644
--- a/numpy/core/src/multiarray/ctors.c
+++ b/numpy/core/src/multiarray/ctors.c
@@ -2793,7 +2793,6 @@ PyArray_CheckAxis(PyArrayObject *arr, int *axis, int flags)
{
PyObject *temp1, *temp2;
int n = PyArray_NDIM(arr);
- int axis_orig = *axis;
if (*axis == NPY_MAXDIMS || n == 0) {
if (n != 1) {
@@ -2831,12 +2830,7 @@ PyArray_CheckAxis(PyArrayObject *arr, int *axis, int flags)
temp2 = (PyObject *)temp1;
}
n = PyArray_NDIM((PyArrayObject *)temp2);
- if (*axis < 0) {
- *axis += n;
- }
- if ((*axis < 0) || (*axis >= n)) {
- PyErr_Format(PyExc_ValueError,
- "axis(=%d) out of bounds", axis_orig);
+ if (check_and_adjust_axis(axis, n) < 0) {
Py_DECREF(temp2);
return NULL;
}
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 08b9c5965..3c0f0782e 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -1101,16 +1101,12 @@ NPY_NO_EXPORT int
PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND which)
{
PyArray_SortFunc *sort;
- int axis_orig = axis;
- int n = PyArray_NDIM(op);
+ int n = PyArray_NDIM(op);
- if (axis < 0) {
- axis += n;
- }
- if (axis < 0 || axis >= n) {
- PyErr_Format(PyExc_ValueError, "axis(=%d) out of bounds", axis_orig);
+ if (check_and_adjust_axis(&axis, n) < 0) {
return -1;
}
+
if (PyArray_FailUnlessWriteable(op, "sort array") < 0) {
return -1;
}
@@ -1212,17 +1208,13 @@ PyArray_Partition(PyArrayObject *op, PyArrayObject * ktharray, int axis,
PyArrayObject *kthrvl;
PyArray_PartitionFunc *part;
PyArray_SortFunc *sort;
- int axis_orig = axis;
int n = PyArray_NDIM(op);
int ret;
- if (axis < 0) {
- axis += n;
- }
- if (axis < 0 || axis >= n) {
- PyErr_Format(PyExc_ValueError, "axis(=%d) out of bounds", axis_orig);
+ if (check_and_adjust_axis(&axis, n) < 0) {
return -1;
}
+
if (PyArray_FailUnlessWriteable(op, "partition array") < 0) {
return -1;
}
@@ -1455,12 +1447,7 @@ PyArray_LexSort(PyObject *sort_keys, int axis)
*((npy_intp *)(PyArray_DATA(ret))) = 0;
goto finish;
}
- if (axis < 0) {
- axis += nd;
- }
- if ((axis < 0) || (axis >= nd)) {
- PyErr_Format(PyExc_ValueError,
- "axis(=%d) out of bounds", axis);
+ if (check_and_adjust_axis(&axis, nd) < 0) {
goto fail;
}
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index f00de46c4..1c8d9b5e4 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -327,7 +327,6 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis)
PyArray_Descr *dtype = NULL;
PyArrayObject *ret = NULL;
PyArrayObject_fields *sliding_view = NULL;
- int orig_axis = axis;
if (narrays <= 0) {
PyErr_SetString(PyExc_ValueError,
@@ -345,13 +344,7 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis)
}
/* Handle standard Python negative indexing */
- if (axis < 0) {
- axis += ndim;
- }
-
- if (axis < 0 || axis >= ndim) {
- PyErr_Format(PyExc_IndexError,
- "axis %d out of bounds [0, %d)", orig_axis, ndim);
+ if (check_and_adjust_axis(&axis, ndim) < 0) {
return NULL;
}
@@ -4109,6 +4102,24 @@ array_may_share_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *
return array_shares_memory_impl(args, kwds, NPY_MAY_SHARE_BOUNDS, 0);
}
+static PyObject *
+normalize_axis_index(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds)
+{
+ static char *kwlist[] = {"axis", "ndim", NULL};
+ int axis;
+ int ndim;
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "ii", kwlist,
+ &axis, &ndim)) {
+ return NULL;
+ }
+
+ if(check_and_adjust_axis(&axis, ndim) < 0) {
+ return NULL;
+ }
+
+ return PyInt_FromLong(axis);
+}
static struct PyMethodDef array_module_methods[] = {
{"_get_ndarray_c_version",
@@ -4284,6 +4295,8 @@ static struct PyMethodDef array_module_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"unpackbits", (PyCFunction)io_unpack,
METH_VARARGS | METH_KEYWORDS, NULL},
+ {"normalize_axis_index", (PyCFunction)normalize_axis_index,
+ METH_VARARGS | METH_KEYWORDS, NULL},
{NULL, NULL, 0, NULL} /* sentinel */
};
diff --git a/numpy/core/src/multiarray/shape.c b/numpy/core/src/multiarray/shape.c
index 3bee562be..5207513bf 100644
--- a/numpy/core/src/multiarray/shape.c
+++ b/numpy/core/src/multiarray/shape.c
@@ -705,12 +705,7 @@ PyArray_Transpose(PyArrayObject *ap, PyArray_Dims *permute)
}
for (i = 0; i < n; i++) {
axis = axes[i];
- if (axis < 0) {
- axis = PyArray_NDIM(ap) + axis;
- }
- if (axis < 0 || axis >= PyArray_NDIM(ap)) {
- PyErr_SetString(PyExc_ValueError,
- "invalid axis for this array");
+ if (check_and_adjust_axis(&axis, PyArray_NDIM(ap)) < 0) {
return NULL;
}
if (reverse_permutation[axis] != -1) {
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 0bae2d591..af4ce12db 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -4036,12 +4036,7 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args,
Py_DECREF(mp);
return NULL;
}
- if (axis < 0) {
- axis += ndim;
- }
- if (axis < 0 || axis >= ndim) {
- PyErr_SetString(PyExc_ValueError,
- "'axis' entry is out of bounds");
+ if (check_and_adjust_axis(&axis, ndim) < 0) {
Py_XDECREF(otype);
Py_DECREF(mp);
return NULL;
@@ -4058,18 +4053,11 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args,
Py_DECREF(mp);
return NULL;
}
- if (axis < 0) {
- axis += ndim;
- }
/* Special case letting axis={0 or -1} slip through for scalars */
if (ndim == 0 && (axis == 0 || axis == -1)) {
axis = 0;
}
- else if (axis < 0 || axis >= ndim) {
- PyErr_SetString(PyExc_ValueError,
- "'axis' entry is out of bounds");
- Py_XDECREF(otype);
- Py_DECREF(mp);
+ else if (check_and_adjust_axis(&axis, ndim) < 0) {
return NULL;
}
axes[0] = (int)axis;
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 8229f1e1a..fa5051ba7 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2013,13 +2013,13 @@ class TestMethods(TestCase):
d = np.array([2, 1])
d.partition(0, kind=k)
assert_raises(ValueError, d.partition, 2)
- assert_raises(ValueError, d.partition, 3, axis=1)
+ assert_raises(np.AxisError, d.partition, 3, axis=1)
assert_raises(ValueError, np.partition, d, 2)
- assert_raises(ValueError, np.partition, d, 2, axis=1)
+ assert_raises(np.AxisError, np.partition, d, 2, axis=1)
assert_raises(ValueError, d.argpartition, 2)
- assert_raises(ValueError, d.argpartition, 3, axis=1)
+ assert_raises(np.AxisError, d.argpartition, 3, axis=1)
assert_raises(ValueError, np.argpartition, d, 2)
- assert_raises(ValueError, np.argpartition, d, 2, axis=1)
+ assert_raises(np.AxisError, np.argpartition, d, 2, axis=1)
d = np.arange(10).reshape((2, 5))
d.partition(1, axis=0, kind=k)
d.partition(4, axis=1, kind=k)
@@ -3522,8 +3522,8 @@ class TestArgmin(TestCase):
class TestMinMax(TestCase):
def test_scalar(self):
- assert_raises(ValueError, np.amax, 1, 1)
- assert_raises(ValueError, np.amin, 1, 1)
+ assert_raises(np.AxisError, np.amax, 1, 1)
+ assert_raises(np.AxisError, np.amin, 1, 1)
assert_equal(np.amax(1, axis=0), 1)
assert_equal(np.amin(1, axis=0), 1)
@@ -3531,7 +3531,7 @@ class TestMinMax(TestCase):
assert_equal(np.amin(1, axis=None), 1)
def test_axis(self):
- assert_raises(ValueError, np.amax, [1, 2, 3], 1000)
+ assert_raises(np.AxisError, np.amax, [1, 2, 3], 1000)
assert_equal(np.amax([[1, 2, 3]], axis=1), 3)
def test_datetime(self):
@@ -3793,7 +3793,7 @@ class TestLexsort(TestCase):
def test_invalid_axis(self): # gh-7528
x = np.linspace(0., 1., 42*3).reshape(42, 3)
- assert_raises(ValueError, np.lexsort, x, axis=2)
+ assert_raises(np.AxisError, np.lexsort, x, axis=2)
class TestIO(object):
"""Test tofile, fromfile, tobytes, and fromstring"""
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 4aa6bed33..906280e15 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -1010,7 +1010,7 @@ class TestNonzero(TestCase):
assert_raises(ValueError, np.count_nonzero, m, axis=(1, 1))
assert_raises(TypeError, np.count_nonzero, m, axis='foo')
- assert_raises(ValueError, np.count_nonzero, m, axis=3)
+ assert_raises(np.AxisError, np.count_nonzero, m, axis=3)
assert_raises(TypeError, np.count_nonzero,
m, axis=np.array([[1], [2]]))
@@ -2323,10 +2323,10 @@ class TestRollaxis(TestCase):
def test_exceptions(self):
a = np.arange(1*2*3*4).reshape(1, 2, 3, 4)
- assert_raises(ValueError, np.rollaxis, a, -5, 0)
- assert_raises(ValueError, np.rollaxis, a, 0, -5)
- assert_raises(ValueError, np.rollaxis, a, 4, 0)
- assert_raises(ValueError, np.rollaxis, a, 0, 5)
+ assert_raises(np.AxisError, np.rollaxis, a, -5, 0)
+ assert_raises(np.AxisError, np.rollaxis, a, 0, -5)
+ assert_raises(np.AxisError, np.rollaxis, a, 4, 0)
+ assert_raises(np.AxisError, np.rollaxis, a, 0, 5)
def test_results(self):
a = np.arange(1*2*3*4).reshape(1, 2, 3, 4).copy()
@@ -2413,11 +2413,11 @@ class TestMoveaxis(TestCase):
def test_errors(self):
x = np.random.randn(1, 2, 3)
- assert_raises_regex(ValueError, 'invalid axis .* `source`',
+ assert_raises_regex(np.AxisError, 'invalid axis .* `source`',
np.moveaxis, x, 3, 0)
- assert_raises_regex(ValueError, 'invalid axis .* `source`',
+ assert_raises_regex(np.AxisError, 'invalid axis .* `source`',
np.moveaxis, x, -4, 0)
- assert_raises_regex(ValueError, 'invalid axis .* `destination`',
+ assert_raises_regex(np.AxisError, 'invalid axis .* `destination`',
np.moveaxis, x, 0, 5)
assert_raises_regex(ValueError, 'repeated axis in `source`',
np.moveaxis, x, [0, 0], [0, 1])
diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py
index ac8dc1eea..727608a17 100644
--- a/numpy/core/tests/test_shape_base.py
+++ b/numpy/core/tests/test_shape_base.py
@@ -184,8 +184,8 @@ class TestConcatenate(TestCase):
for ndim in [1, 2, 3]:
a = np.ones((1,)*ndim)
np.concatenate((a, a), axis=0) # OK
- assert_raises(IndexError, np.concatenate, (a, a), axis=ndim)
- assert_raises(IndexError, np.concatenate, (a, a), axis=-(ndim + 1))
+ assert_raises(np.AxisError, np.concatenate, (a, a), axis=ndim)
+ assert_raises(np.AxisError, np.concatenate, (a, a), axis=-(ndim + 1))
# Scalars cannot be concatenated
assert_raises(ValueError, concatenate, (0,))
@@ -294,8 +294,8 @@ def test_stack():
expected_shapes = [(10, 3), (3, 10), (3, 10), (10, 3)]
for axis, expected_shape in zip(axes, expected_shapes):
assert_equal(np.stack(arrays, axis).shape, expected_shape)
- assert_raises_regex(IndexError, 'out of bounds', stack, arrays, axis=2)
- assert_raises_regex(IndexError, 'out of bounds', stack, arrays, axis=-3)
+ assert_raises_regex(np.AxisError, 'out of bounds', stack, arrays, axis=2)
+ assert_raises_regex(np.AxisError, 'out of bounds', stack, arrays, axis=-3)
# all shapes for 2d input
arrays = [np.random.randn(3, 4) for _ in range(10)]
axes = [0, 1, 2, -1, -2, -3]
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index 3fea68700..f7b66f90c 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -703,14 +703,14 @@ class TestUfunc(TestCase):
def test_axis_out_of_bounds(self):
a = np.array([False, False])
- assert_raises(ValueError, a.all, axis=1)
+ assert_raises(np.AxisError, a.all, axis=1)
a = np.array([False, False])
- assert_raises(ValueError, a.all, axis=-2)
+ assert_raises(np.AxisError, a.all, axis=-2)
a = np.array([False, False])
- assert_raises(ValueError, a.any, axis=1)
+ assert_raises(np.AxisError, a.any, axis=1)
a = np.array([False, False])
- assert_raises(ValueError, a.any, axis=-2)
+ assert_raises(np.AxisError, a.any, axis=-2)
def test_scalar_reduction(self):
# The functions 'sum', 'prod', etc allow specifying axis=0
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index ae1420b72..4d1ffbccc 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -25,7 +25,7 @@ from numpy.core.numerictypes import typecodes, number
from numpy.lib.twodim_base import diag
from .utils import deprecate
from numpy.core.multiarray import (
- _insert, add_docstring, digitize, bincount,
+ _insert, add_docstring, digitize, bincount, normalize_axis_index,
interp as compiled_interp, interp_complex as compiled_interp_complex
)
from numpy.core.umath import _add_newdoc_ufunc as add_newdoc_ufunc
@@ -4828,14 +4828,7 @@ def insert(arr, obj, values, axis=None):
arr = arr.ravel()
ndim = arr.ndim
axis = ndim - 1
- else:
- if ndim > 0 and (axis < -ndim or axis >= ndim):
- raise IndexError(
- "axis %i is out of bounds for an array of "
- "dimension %i" % (axis, ndim))
- if (axis < 0):
- axis += ndim
- if (ndim == 0):
+ elif ndim == 0:
# 2013-09-24, 1.9
warnings.warn(
"in the future the special handling of scalars will be removed "
@@ -4846,6 +4839,8 @@ def insert(arr, obj, values, axis=None):
return wrap(arr)
else:
return arr
+ else:
+ axis = normalize_axis_index(axis, ndim)
slobj = [slice(None)]*ndim
N = arr.shape[axis]
newshape = list(arr.shape)
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 58e13533b..62798286f 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -7,6 +7,7 @@ from numpy.core.numeric import (
asarray, zeros, outer, concatenate, isscalar, array, asanyarray
)
from numpy.core.fromnumeric import product, reshape, transpose
+from numpy.core.multiarray import normalize_axis_index
from numpy.core import vstack, atleast_3d
from numpy.lib.index_tricks import ndindex
from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells
@@ -96,10 +97,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
# handle negative axes
arr = asanyarray(arr)
nd = arr.ndim
- if not (-nd <= axis < nd):
- raise IndexError('axis {0} out of bounds [-{1}, {1})'.format(axis, nd))
- if axis < 0:
- axis += nd
+ axis = normalize_axis_index(axis, nd)
# arr, with the iteration axis at the end
in_dims = list(range(nd))
@@ -289,8 +287,7 @@ def expand_dims(a, axis):
"""
a = asarray(a)
shape = a.shape
- if axis < 0:
- axis = axis + len(shape) + 1
+ axis = normalize_axis_index(axis, a.ndim + 1)
return a.reshape(shape[:axis] + (1,) + shape[axis:])
row_stack = vstack
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index f69c24d59..d914260ad 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -466,8 +466,8 @@ class TestInsert(TestCase):
insert(a, 1, a[:, 2,:], axis=1))
# invalid axis value
- assert_raises(IndexError, insert, a, 1, a[:, 2, :], axis=3)
- assert_raises(IndexError, insert, a, 1, a[:, 2, :], axis=-4)
+ assert_raises(np.AxisError, insert, a, 1, a[:, 2, :], axis=3)
+ assert_raises(np.AxisError, insert, a, 1, a[:, 2, :], axis=-4)
# negative axis value
a = np.arange(24).reshape((2, 3, 4))
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 7b4bbf416..6002c63b9 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -25,6 +25,7 @@ from numpy.core import (
finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product, abs,
broadcast, atleast_2d, intp, asanyarray, isscalar, object_
)
+from numpy.core.multiarray import normalize_axis_index
from numpy.lib import triu, asfarray
from numpy.linalg import lapack_lite, _umath_linalg
from numpy.matrixlib.defmatrix import matrix_power
@@ -2236,13 +2237,8 @@ def norm(x, ord=None, axis=None, keepdims=False):
return add.reduce(absx, axis=axis, keepdims=keepdims) ** (1.0 / ord)
elif len(axis) == 2:
row_axis, col_axis = axis
- if row_axis < 0:
- row_axis += nd
- if col_axis < 0:
- col_axis += nd
- if not (0 <= row_axis < nd and 0 <= col_axis < nd):
- raise ValueError('Invalid axis %r for an array with shape %r' %
- (axis, x.shape))
+ row_axis = normalize_axis_index(row_axis, nd)
+ col_axis = normalize_axis_index(col_axis, nd)
if row_axis == col_axis:
raise ValueError('Duplicate axes given.')
if ord == 2:
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 6f289f51f..fc4f98ed7 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -1198,8 +1198,8 @@ class _TestNorm(object):
assert_raises(ValueError, norm, B, order, (1, 2))
# Invalid axis
- assert_raises(ValueError, norm, B, None, 3)
- assert_raises(ValueError, norm, B, None, (2, 3))
+ assert_raises(np.AxisError, norm, B, None, 3)
+ assert_raises(np.AxisError, norm, B, None, (2, 3))
assert_raises(ValueError, norm, B, None, (0, 1, 2))
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index a6f474b95..1b25725d1 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -41,6 +41,7 @@ from numpy.compat import (
getargspec, formatargspec, long, basestring, unicode, bytes, sixu
)
from numpy import expand_dims as n_expand_dims
+from numpy.core.multiarray import normalize_axis_index
if sys.version_info[0] >= 3:
@@ -3902,7 +3903,9 @@ class MaskedArray(ndarray):
axis = None
try:
mask = mask.view((bool_, len(self.dtype))).all(axis)
- except ValueError:
+ except (ValueError, np.AxisError):
+ # TODO: what error are we trying to catch here?
+ # invalid axis, or invalid view?
mask = np.all([[f[n].all() for n in mask.dtype.names]
for f in mask], axis=axis)
check._mask = mask
@@ -3938,7 +3941,9 @@ class MaskedArray(ndarray):
axis = None
try:
mask = mask.view((bool_, len(self.dtype))).all(axis)
- except ValueError:
+ except (ValueError, np.AxisError):
+ # TODO: what error are we trying to catch here?
+ # invalid axis, or invalid view?
mask = np.all([[f[n].all() for n in mask.dtype.names]
for f in mask], axis=axis)
check._mask = mask
@@ -4340,7 +4345,7 @@ class MaskedArray(ndarray):
if self.shape is ():
if axis not in (None, 0):
- raise ValueError("'axis' entry is out of bounds")
+ raise np.AxisError("'axis' entry is out of bounds")
return 1
elif axis is None:
if kwargs.get('keepdims', False):
@@ -4348,11 +4353,9 @@ class MaskedArray(ndarray):
return self.size
axes = axis if isinstance(axis, tuple) else (axis,)
- axes = tuple(a if a >= 0 else self.ndim + a for a in axes)
+ axes = tuple(normalize_axis_index(a, self.ndim) for a in axes)
if len(axes) != len(set(axes)):
raise ValueError("duplicate value in 'axis'")
- if builtins.any(a < 0 or a >= self.ndim for a in axes):
- raise ValueError("'axis' entry is out of bounds")
items = 1
for ax in axes:
items *= self.shape[ax]
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 29a15633d..7149b525b 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -36,6 +36,7 @@ from .core import (
import numpy as np
from numpy import ndarray, array as nxarray
import numpy.core.umath as umath
+from numpy.core.multiarray import normalize_axis_index
from numpy.lib.function_base import _ureduce
from numpy.lib.index_tricks import AxisConcatenator
@@ -380,11 +381,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
"""
arr = array(arr, copy=False, subok=True)
nd = arr.ndim
- if axis < 0:
- axis += nd
- if (axis >= nd):
- raise ValueError("axis must be less than arr.ndim; axis=%d, rank=%d."
- % (axis, nd))
+ axis = normalize_axis_index(axis, nd)
ind = [0] * (nd - 1)
i = np.zeros(nd, 'O')
indlist = list(range(nd))
@@ -717,8 +714,8 @@ def _median(a, axis=None, out=None, overwrite_input=False):
if axis is None:
axis = 0
- elif axis < 0:
- axis += asorted.ndim
+ else:
+ axis = normalize_axis_index(axis, asorted.ndim)
if asorted.ndim == 1:
counts = count(asorted)
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index f72ddc5ea..9d8002ed0 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1030,7 +1030,7 @@ class TestMaskedArrayArithmetic(TestCase):
res = count(ott, 0)
assert_(isinstance(res, ndarray))
assert_(res.dtype.type is np.intp)
- assert_raises(ValueError, ott.count, axis=1)
+ assert_raises(np.AxisError, ott.count, axis=1)
def test_minmax_func(self):
# Tests minimum and maximum.
@@ -4409,7 +4409,7 @@ class TestOptionalArgs(TestCase):
assert_equal(count(a, axis=(0,1), keepdims=True), 4*ones((1,1,4)))
assert_equal(count(a, axis=-2), 2*ones((2,4)))
assert_raises(ValueError, count, a, axis=(1,1))
- assert_raises(ValueError, count, a, axis=3)
+ assert_raises(np.AxisError, count, a, axis=3)
# check the 'nomask' path
a = np.ma.array(d, mask=nomask)
@@ -4423,13 +4423,13 @@ class TestOptionalArgs(TestCase):
assert_equal(count(a, axis=(0,1), keepdims=True), 6*ones((1,1,4)))
assert_equal(count(a, axis=-2), 3*ones((2,4)))
assert_raises(ValueError, count, a, axis=(1,1))
- assert_raises(ValueError, count, a, axis=3)
+ assert_raises(np.AxisError, count, a, axis=3)
# check the 'masked' singleton
assert_equal(count(np.ma.masked), 0)
# check 0-d arrays do not allow axis > 0
- assert_raises(ValueError, count, np.ma.array(1), axis=1)
+ assert_raises(np.AxisError, count, np.ma.array(1), axis=1)
class TestMaskedConstant(TestCase):
diff --git a/numpy/polynomial/chebyshev.py b/numpy/polynomial/chebyshev.py
index 3babb8fc2..49d0302e0 100644
--- a/numpy/polynomial/chebyshev.py
+++ b/numpy/polynomial/chebyshev.py
@@ -90,6 +90,7 @@ from __future__ import division, absolute_import, print_function
import warnings
import numpy as np
import numpy.linalg as la
+from numpy.core.multiarray import normalize_axis_index
from . import polyutils as pu
from ._polybase import ABCPolyBase
@@ -936,10 +937,7 @@ def chebder(c, m=1, scl=1, axis=0):
raise ValueError("The order of derivation must be non-negative")
if iaxis != axis:
raise ValueError("The axis must be integer")
- if not -c.ndim <= iaxis < c.ndim:
- raise ValueError("The axis is out of range")
- if iaxis < 0:
- iaxis += c.ndim
+ iaxis = normalize_axis_index(iaxis, c.ndim)
if cnt == 0:
return c
@@ -1064,10 +1062,7 @@ def chebint(c, m=1, k=[], lbnd=0, scl=1, axis=0):
raise ValueError("Too many integration constants")
if iaxis != axis:
raise ValueError("The axis must be integer")
- if not -c.ndim <= iaxis < c.ndim:
- raise ValueError("The axis is out of range")
- if iaxis < 0:
- iaxis += c.ndim
+ iaxis = normalize_axis_index(iaxis, c.ndim)
if cnt == 0:
return c
diff --git a/numpy/polynomial/hermite.py b/numpy/polynomial/hermite.py
index 0ebae2027..a03fe722c 100644
--- a/numpy/polynomial/hermite.py
+++ b/numpy/polynomial/hermite.py
@@ -62,6 +62,7 @@ from __future__ import division, absolute_import, print_function
import warnings
import numpy as np
import numpy.linalg as la
+from numpy.core.multiarray import normalize_axis_index
from . import polyutils as pu
from ._polybase import ABCPolyBase
@@ -700,10 +701,7 @@ def hermder(c, m=1, scl=1, axis=0):
raise ValueError("The order of derivation must be non-negative")
if iaxis != axis:
raise ValueError("The axis must be integer")
- if not -c.ndim <= iaxis < c.ndim:
- raise ValueError("The axis is out of range")
- if iaxis < 0:
- iaxis += c.ndim
+ iaxis = normalize_axis_index(iaxis, c.ndim)
if cnt == 0:
return c
@@ -822,10 +820,7 @@ def hermint(c, m=1, k=[], lbnd=0, scl=1, axis=0):
raise ValueError("Too many integration constants")
if iaxis != axis:
raise ValueError("The axis must be integer")
- if not -c.ndim <= iaxis < c.ndim:
- raise ValueError("The axis is out of range")
- if iaxis < 0:
- iaxis += c.ndim
+ iaxis = normalize_axis_index(iaxis, c.ndim)
if cnt == 0:
return c
diff --git a/numpy/polynomial/hermite_e.py b/numpy/polynomial/hermite_e.py
index a09b66670..2a29d61cf 100644
--- a/numpy/polynomial/hermite_e.py
+++ b/numpy/polynomial/hermite_e.py
@@ -62,6 +62,7 @@ from __future__ import division, absolute_import, print_function
import warnings
import numpy as np
import numpy.linalg as la
+from numpy.core.multiarray import normalize_axis_index
from . import polyutils as pu
from ._polybase import ABCPolyBase
@@ -699,10 +700,7 @@ def hermeder(c, m=1, scl=1, axis=0):
raise ValueError("The order of derivation must be non-negative")
if iaxis != axis:
raise ValueError("The axis must be integer")
- if not -c.ndim <= iaxis < c.ndim:
- raise ValueError("The axis is out of range")
- if iaxis < 0:
- iaxis += c.ndim
+ iaxis = normalize_axis_index(iaxis, c.ndim)
if cnt == 0:
return c
@@ -821,10 +819,7 @@ def hermeint(c, m=1, k=[], lbnd=0, scl=1, axis=0):
raise ValueError("Too many integration constants")
if iaxis != axis:
raise ValueError("The axis must be integer")
- if not -c.ndim <= iaxis < c.ndim:
- raise ValueError("The axis is out of range")
- if iaxis < 0:
- iaxis += c.ndim
+ iaxis = normalize_axis_index(iaxis, c.ndim)
if cnt == 0:
return c
diff --git a/numpy/polynomial/laguerre.py b/numpy/polynomial/laguerre.py
index dfa997254..c9e1302e1 100644
--- a/numpy/polynomial/laguerre.py
+++ b/numpy/polynomial/laguerre.py
@@ -62,6 +62,7 @@ from __future__ import division, absolute_import, print_function
import warnings
import numpy as np
import numpy.linalg as la
+from numpy.core.multiarray import normalize_axis_index
from . import polyutils as pu
from ._polybase import ABCPolyBase
@@ -697,10 +698,7 @@ def lagder(c, m=1, scl=1, axis=0):
raise ValueError("The order of derivation must be non-negative")
if iaxis != axis:
raise ValueError("The axis must be integer")
- if not -c.ndim <= iaxis < c.ndim:
- raise ValueError("The axis is out of range")
- if iaxis < 0:
- iaxis += c.ndim
+ iaxis = normalize_axis_index(iaxis, c.ndim)
if cnt == 0:
return c
@@ -822,10 +820,7 @@ def lagint(c, m=1, k=[], lbnd=0, scl=1, axis=0):
raise ValueError("Too many integration constants")
if iaxis != axis:
raise ValueError("The axis must be integer")
- if not -c.ndim <= iaxis < c.ndim:
- raise ValueError("The axis is out of range")
- if iaxis < 0:
- iaxis += c.ndim
+ iaxis = normalize_axis_index(iaxis, c.ndim)
if cnt == 0:
return c
diff --git a/numpy/polynomial/legendre.py b/numpy/polynomial/legendre.py
index fdaa56e0c..fa578360e 100644
--- a/numpy/polynomial/legendre.py
+++ b/numpy/polynomial/legendre.py
@@ -86,6 +86,7 @@ from __future__ import division, absolute_import, print_function
import warnings
import numpy as np
import numpy.linalg as la
+from numpy.core.multiarray import normalize_axis_index
from . import polyutils as pu
from ._polybase import ABCPolyBase
@@ -736,10 +737,7 @@ def legder(c, m=1, scl=1, axis=0):
raise ValueError("The order of derivation must be non-negative")
if iaxis != axis:
raise ValueError("The axis must be integer")
- if not -c.ndim <= iaxis < c.ndim:
- raise ValueError("The axis is out of range")
- if iaxis < 0:
- iaxis += c.ndim
+ iaxis = normalize_axis_index(iaxis, c.ndim)
if cnt == 0:
return c
@@ -864,10 +862,7 @@ def legint(c, m=1, k=[], lbnd=0, scl=1, axis=0):
raise ValueError("Too many integration constants")
if iaxis != axis:
raise ValueError("The axis must be integer")
- if not -c.ndim <= iaxis < c.ndim:
- raise ValueError("The axis is out of range")
- if iaxis < 0:
- iaxis += c.ndim
+ iaxis = normalize_axis_index(iaxis, c.ndim)
if cnt == 0:
return c
diff --git a/numpy/polynomial/polynomial.py b/numpy/polynomial/polynomial.py
index 19b085eaf..c357b48c9 100644
--- a/numpy/polynomial/polynomial.py
+++ b/numpy/polynomial/polynomial.py
@@ -66,6 +66,7 @@ __all__ = [
import warnings
import numpy as np
import numpy.linalg as la
+from numpy.core.multiarray import normalize_axis_index
from . import polyutils as pu
from ._polybase import ABCPolyBase
@@ -540,10 +541,7 @@ def polyder(c, m=1, scl=1, axis=0):
raise ValueError("The order of derivation must be non-negative")
if iaxis != axis:
raise ValueError("The axis must be integer")
- if not -c.ndim <= iaxis < c.ndim:
- raise ValueError("The axis is out of range")
- if iaxis < 0:
- iaxis += c.ndim
+ iaxis = normalize_axis_index(iaxis, c.ndim)
if cnt == 0:
return c
@@ -658,10 +656,7 @@ def polyint(c, m=1, k=[], lbnd=0, scl=1, axis=0):
raise ValueError("Too many integration constants")
if iaxis != axis:
raise ValueError("The axis must be integer")
- if not -c.ndim <= iaxis < c.ndim:
- raise ValueError("The axis is out of range")
- if iaxis < 0:
- iaxis += c.ndim
+ iaxis = normalize_axis_index(iaxis, c.ndim)
if cnt == 0:
return c