summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2017-10-18 09:32:40 -0600
committerGitHub <noreply@github.com>2017-10-18 09:32:40 -0600
commitaca47d713624bb09e73db014811f8c1be31dfdcd (patch)
tree4ed14c77c72b9a4fc727145abdcb7a978e26c849
parentce65f115c409a4a8089ace496e15bbfca566070b (diff)
parent39a7681bfe3fb78f549d5893333eecb7fd41fe29 (diff)
downloadnumpy-aca47d713624bb09e73db014811f8c1be31dfdcd.tar.gz
Merge pull request #9884 from eric-wieser/unravel_index-0d
BUG: Allow `unravel_index(0, ())` to return ()
-rw-r--r--numpy/core/src/multiarray/compiled_base.c34
-rw-r--r--numpy/lib/tests/test_index_tricks.py13
2 files changed, 36 insertions, 11 deletions
diff --git a/numpy/core/src/multiarray/compiled_base.c b/numpy/core/src/multiarray/compiled_base.c
index 36ef1d1c4..580578562 100644
--- a/numpy/core/src/multiarray/compiled_base.c
+++ b/numpy/core/src/multiarray/compiled_base.c
@@ -1135,8 +1135,11 @@ unravel_index_loop_corder(int unravel_ndim, npy_intp *unravel_dims,
}
NPY_END_ALLOW_THREADS;
if (invalid) {
- PyErr_SetString(PyExc_ValueError,
- "invalid entry in index array");
+ PyErr_Format(PyExc_ValueError,
+ "index %" NPY_INTP_FMT " is out of bounds for array with size "
+ "%" NPY_INTP_FMT,
+ val, unravel_size
+ );
return NPY_FAIL;
}
return NPY_SUCCEED;
@@ -1169,8 +1172,11 @@ unravel_index_loop_forder(int unravel_ndim, npy_intp *unravel_dims,
}
NPY_END_ALLOW_THREADS;
if (invalid) {
- PyErr_SetString(PyExc_ValueError,
- "invalid entry in index array");
+ PyErr_Format(PyExc_ValueError,
+ "index %" NPY_INTP_FMT " is out of bounds for array with size "
+ "%" NPY_INTP_FMT,
+ val, unravel_size
+ );
return NPY_FAIL;
}
return NPY_SUCCEED;
@@ -1202,12 +1208,6 @@ arr_unravel_index(PyObject *self, PyObject *args, PyObject *kwds)
goto fail;
}
- if (dimensions.len == 0) {
- PyErr_SetString(PyExc_ValueError,
- "dims must have at least one value");
- goto fail;
- }
-
unravel_size = PyArray_MultiplyList(dimensions.ptr, dimensions.len);
if (!PyArray_Check(indices0)) {
@@ -1328,6 +1328,20 @@ arr_unravel_index(PyObject *self, PyObject *args, PyObject *kwds)
goto fail;
}
+
+ if (dimensions.len == 0 && PyArray_NDIM(indices) != 0) {
+ /*
+ * There's no index meaning "take the only element 10 times"
+ * on a zero-d array, so we have no choice but to error. (See gh-580)
+ *
+ * Do this check after iterating, so we give a better error message
+ * for invalid indices.
+ */
+ PyErr_SetString(PyExc_ValueError,
+ "multiple indices are not supported for 0d arrays");
+ goto fail;
+ }
+
/* Now make a tuple of views, one per index */
ret_tuple = PyTuple_New(dimensions.len);
if (ret_tuple == NULL) {
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index 452b3d6a2..1d5efef86 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -3,7 +3,8 @@ from __future__ import division, absolute_import, print_function
import numpy as np
from numpy.testing import (
run_module_suite, assert_, assert_equal, assert_array_equal,
- assert_almost_equal, assert_array_almost_equal, assert_raises
+ assert_almost_equal, assert_array_almost_equal, assert_raises,
+ assert_raises_regex
)
from numpy.lib.index_tricks import (
mgrid, ndenumerate, fill_diagonal, diag_indices, diag_indices_from,
@@ -114,6 +115,16 @@ class TestRavelUnravelIndex(object):
assert_(y.flags.writeable)
+ def test_0d(self):
+ # gh-580
+ x = np.unravel_index(0, ())
+ assert_equal(x, ())
+
+ assert_raises_regex(ValueError, "0d array", np.unravel_index, [0], ())
+ assert_raises_regex(
+ ValueError, "out of bounds", np.unravel_index, [1], ())
+
+
class TestGrid(object):
def test_basic(self):
a = mgrid[-1:1:10j]