diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2017-09-24 13:43:51 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-09-24 13:43:51 -0600 |
commit | c2372aff1ea85e8e1d64d8b5ced7e362fd0f4a5a (patch) | |
tree | f7e51f64240912eb4627ff9c657045f1ce716da7 | |
parent | cfe8772057e16994f140f587ac6c2d0e12b744f5 (diff) | |
parent | 8d4942bd6a682d0c6a86c503c6af0c08a5333adb (diff) | |
download | numpy-c2372aff1ea85e8e1d64d8b5ced7e362fd0f4a5a.tar.gz |
Merge pull request #9300 from ahaldane/fix_obj_nonzero
BUG: PyArray_CountNonzero does not check for exceptions
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 38 | ||||
-rw-r--r-- | numpy/core/src/multiarray/sequence.c | 37 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 13 |
4 files changed, 53 insertions, 39 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 811653e85..3b5d76362 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -2068,6 +2068,8 @@ PyArray_CountNonzero(PyArrayObject *self) char *data; npy_intp stride, count; npy_intp nonzero_count = 0; + int needs_api = 0; + PyArray_Descr *dtype; NpyIter *iter; NpyIter_IterNextFunc *iternext; @@ -2076,22 +2078,38 @@ PyArray_CountNonzero(PyArrayObject *self) NPY_BEGIN_THREADS_DEF; /* Special low-overhead version specific to the boolean type */ - if (PyArray_DESCR(self)->type_num == NPY_BOOL) { + dtype = PyArray_DESCR(self); + if (dtype->type_num == NPY_BOOL) { return count_boolean_trues(PyArray_NDIM(self), PyArray_DATA(self), PyArray_DIMS(self), PyArray_STRIDES(self)); } - nonzero = PyArray_DESCR(self)->f->nonzero; /* If it's a trivial one-dimensional loop, don't use an iterator */ if (PyArray_TRIVIALLY_ITERABLE(self)) { + needs_api = PyDataType_FLAGCHK(dtype, NPY_NEEDS_PYAPI); PyArray_PREPARE_TRIVIAL_ITERATION(self, count, data, stride); - while (count--) { - if (nonzero(data, self)) { - ++nonzero_count; + if (needs_api){ + while (count--) { + if (nonzero(data, self)) { + ++nonzero_count; + } + if (PyErr_Occurred()) { + return -1; + } + data += stride; } - data += stride; + } + else { + NPY_BEGIN_THREADS_THRESHOLDED(count); + while (count--) { + if (nonzero(data, self)) { + ++nonzero_count; + } + data += stride; + } + NPY_END_THREADS; } return nonzero_count; @@ -2116,6 +2134,7 @@ PyArray_CountNonzero(PyArrayObject *self) if (iter == NULL) { return -1; } + needs_api = NpyIter_IterationNeedsAPI(iter); /* Get the pointers for inner loop iteration */ iternext = NpyIter_GetIterNext(iter, NULL); @@ -2140,16 +2159,21 @@ PyArray_CountNonzero(PyArrayObject *self) if (nonzero(data, self)) { ++nonzero_count; } + if (needs_api && PyErr_Occurred()) { + nonzero_count = -1; + goto finish; + } data += stride; } } while(iternext(iter)); +finish: NPY_END_THREADS; NpyIter_Deallocate(iter); - return PyErr_Occurred() ? -1 : nonzero_count; + return nonzero_count; } /*NUMPY_API diff --git a/numpy/core/src/multiarray/sequence.c b/numpy/core/src/multiarray/sequence.c index 55b72c198..4769bdad9 100644 --- a/numpy/core/src/multiarray/sequence.c +++ b/numpy/core/src/multiarray/sequence.c @@ -15,9 +15,7 @@ #include "mapping.h" #include "sequence.h" - -static int -array_any_nonzero(PyArrayObject *mp); +#include "calculation.h" /************************************************************************* **************** Implement Sequence Protocol ************************** @@ -32,16 +30,18 @@ array_contains(PyArrayObject *self, PyObject *el) { /* equivalent to (self == el).any() */ - PyObject *res; int ret; + PyObject *res, *any; res = PyArray_EnsureAnyArray(PyObject_RichCompare((PyObject *)self, el, Py_EQ)); if (res == NULL) { return -1; } - ret = array_any_nonzero((PyArrayObject *)res); + any = PyArray_Any((PyArrayObject *)res, NPY_MAXDIMS, NULL); Py_DECREF(res); + ret = PyObject_IsTrue(any); + Py_DECREF(any); return ret; } @@ -61,30 +61,3 @@ NPY_NO_EXPORT PySequenceMethods array_as_sequence = { /****************** End of Sequence Protocol ****************************/ -/* - * Helpers - */ - -/* Array evaluates as "TRUE" if any of the elements are non-zero*/ -static int -array_any_nonzero(PyArrayObject *arr) -{ - npy_intp counter; - PyArrayIterObject *it; - npy_bool anyTRUE = NPY_FALSE; - - it = (PyArrayIterObject *)PyArray_IterNew((PyObject *)arr); - if (it == NULL) { - return anyTRUE; - } - counter = it->size; - while (counter--) { - if (PyArray_DESCR(arr)->f->nonzero(it->dataptr, arr)) { - anyTRUE = NPY_TRUE; - break; - } - PyArray_ITER_NEXT(it); - } - Py_DECREF(it); - return anyTRUE; -} diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 88456a816..ea5421ec2 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2839,6 +2839,10 @@ class TestMethods(object): e = np.array(['1+1j'], 'U') assert_raises(TypeError, complex, e) +class TestCequenceMethods(object): + def test_array_contains(self): + assert_(4.0 in np.arange(16.).reshape(4,4)) + assert_(20.0 not in np.arange(16.).reshape(4,4)) class TestBinop(object): def test_inplace(self): diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index e8c637179..b53c4338e 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -1120,6 +1120,19 @@ class TestNonzero(object): assert_equal(m.nonzero(), tgt) + def test_nonzero_invalid_object(self): + # gh-9295 + a = np.array([np.array([1, 2]), 3]) + assert_raises(ValueError, np.nonzero, a) + + class BoolErrors: + def __bool__(self): + raise ValueError("Not allowed") + def __nonzero__(self): + raise ValueError("Not allowed") + + assert_raises(ValueError, np.nonzero, np.array([BoolErrors()])) + class TestIndex(object): def test_boolean(self): |