summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/item_selection.c27
-rw-r--r--numpy/core/tests/test_numeric.py33
2 files changed, 58 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 9351b5fc7..762563eb5 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -2203,9 +2203,13 @@ PyArray_Nonzero(PyArrayObject *self)
PyArrayObject *ret = NULL;
PyObject *ret_tuple;
npy_intp ret_dims[2];
- PyArray_NonzeroFunc *nonzero = PyArray_DESCR(self)->f->nonzero;
+
+ PyArray_NonzeroFunc *nonzero;
+ PyArray_Descr *dtype;
+
npy_intp nonzero_count;
npy_intp added_count = 0;
+ int needs_api;
int is_bool;
NpyIter *iter;
@@ -2213,6 +2217,10 @@ PyArray_Nonzero(PyArrayObject *self)
NpyIter_GetMultiIndexFunc *get_multi_index;
char **dataptr;
+ dtype = PyArray_DESCR(self);
+ nonzero = dtype->f->nonzero;
+ needs_api = PyDataType_FLAGCHK(dtype, NPY_NEEDS_PYAPI);
+
/* Special case - nonzero(zero_d) is nonzero(atleast_1d(zero_d)) */
if (ndim == 0) {
char const* msg;
@@ -2283,7 +2291,9 @@ PyArray_Nonzero(PyArrayObject *self)
goto finish;
}
- NPY_BEGIN_THREADS_THRESHOLDED(count);
+ if (!needs_api) {
+ NPY_BEGIN_THREADS_THRESHOLDED(count);
+ }
/* avoid function call for bool */
if (is_bool) {
@@ -2324,6 +2334,9 @@ PyArray_Nonzero(PyArrayObject *self)
}
*multi_index++ = j;
}
+ if (needs_api && PyErr_Occurred()) {
+ break;
+ }
data += stride;
}
}
@@ -2364,6 +2377,8 @@ PyArray_Nonzero(PyArrayObject *self)
Py_DECREF(ret);
return NULL;
}
+
+ needs_api = NpyIter_IterationNeedsAPI(iter);
NPY_BEGIN_THREADS_NDITER(iter);
@@ -2390,6 +2405,9 @@ PyArray_Nonzero(PyArrayObject *self)
get_multi_index(iter, multi_index);
multi_index += ndim;
}
+ if (needs_api && PyErr_Occurred()) {
+ break;
+ }
} while(iternext(iter));
}
@@ -2399,6 +2417,11 @@ PyArray_Nonzero(PyArrayObject *self)
NpyIter_Deallocate(iter);
finish:
+ if (PyErr_Occurred()) {
+ Py_DECREF(ret);
+ return NULL;
+ }
+
/* if executed `nonzero()` check for miscount due to side-effect */
if (!is_bool && added_count != nonzero_count) {
PyErr_SetString(PyExc_RuntimeError,
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 935b84234..3e85054b7 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -1254,6 +1254,39 @@ class TestNonzero(object):
a = np.array([[False], [TrueThenFalse()]])
assert_raises(RuntimeError, np.nonzero, a)
+ def test_nonzero_exception_safe(self):
+ # gh-13930
+
+ class ThrowsAfter:
+ def __init__(self, iters):
+ self.iters_left = iters
+
+ def __bool__(self):
+ if self.iters_left == 0:
+ raise ValueError("called `iters` times")
+
+ self.iters_left -= 1
+ return True
+
+ """
+ Test that a ValueError is raised instead of a SystemError
+
+ If the __bool__ function is called after the error state is set,
+ Python (cpython) will raise a SystemError.
+ """
+
+ # assert that an exception in first pass is handled correctly
+ a = np.array([ThrowsAfter(5)]*10)
+ assert_raises(ValueError, np.nonzero, a)
+
+ # raise exception in second pass for 1-dimensional loop
+ a = np.array([ThrowsAfter(15)]*10)
+ assert_raises(ValueError, np.nonzero, a)
+
+ # raise exception in second pass for n-dimensional loop
+ a = np.array([[ThrowsAfter(15)]]*10)
+ assert_raises(ValueError, np.nonzero, a)
+
class TestIndex(object):
def test_boolean(self):