summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/item_selection.c38
-rw-r--r--numpy/core/src/multiarray/sequence.c37
-rw-r--r--numpy/core/tests/test_multiarray.py4
-rw-r--r--numpy/core/tests/test_numeric.py13
4 files changed, 53 insertions, 39 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 811653e85..3b5d76362 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -2068,6 +2068,8 @@ PyArray_CountNonzero(PyArrayObject *self)
char *data;
npy_intp stride, count;
npy_intp nonzero_count = 0;
+ int needs_api = 0;
+ PyArray_Descr *dtype;
NpyIter *iter;
NpyIter_IterNextFunc *iternext;
@@ -2076,22 +2078,38 @@ PyArray_CountNonzero(PyArrayObject *self)
NPY_BEGIN_THREADS_DEF;
/* Special low-overhead version specific to the boolean type */
- if (PyArray_DESCR(self)->type_num == NPY_BOOL) {
+ dtype = PyArray_DESCR(self);
+ if (dtype->type_num == NPY_BOOL) {
return count_boolean_trues(PyArray_NDIM(self), PyArray_DATA(self),
PyArray_DIMS(self), PyArray_STRIDES(self));
}
-
nonzero = PyArray_DESCR(self)->f->nonzero;
/* If it's a trivial one-dimensional loop, don't use an iterator */
if (PyArray_TRIVIALLY_ITERABLE(self)) {
+ needs_api = PyDataType_FLAGCHK(dtype, NPY_NEEDS_PYAPI);
PyArray_PREPARE_TRIVIAL_ITERATION(self, count, data, stride);
- while (count--) {
- if (nonzero(data, self)) {
- ++nonzero_count;
+ if (needs_api){
+ while (count--) {
+ if (nonzero(data, self)) {
+ ++nonzero_count;
+ }
+ if (PyErr_Occurred()) {
+ return -1;
+ }
+ data += stride;
}
- data += stride;
+ }
+ else {
+ NPY_BEGIN_THREADS_THRESHOLDED(count);
+ while (count--) {
+ if (nonzero(data, self)) {
+ ++nonzero_count;
+ }
+ data += stride;
+ }
+ NPY_END_THREADS;
}
return nonzero_count;
@@ -2116,6 +2134,7 @@ PyArray_CountNonzero(PyArrayObject *self)
if (iter == NULL) {
return -1;
}
+ needs_api = NpyIter_IterationNeedsAPI(iter);
/* Get the pointers for inner loop iteration */
iternext = NpyIter_GetIterNext(iter, NULL);
@@ -2140,16 +2159,21 @@ PyArray_CountNonzero(PyArrayObject *self)
if (nonzero(data, self)) {
++nonzero_count;
}
+ if (needs_api && PyErr_Occurred()) {
+ nonzero_count = -1;
+ goto finish;
+ }
data += stride;
}
} while(iternext(iter));
+finish:
NPY_END_THREADS;
NpyIter_Deallocate(iter);
- return PyErr_Occurred() ? -1 : nonzero_count;
+ return nonzero_count;
}
/*NUMPY_API
diff --git a/numpy/core/src/multiarray/sequence.c b/numpy/core/src/multiarray/sequence.c
index 55b72c198..4769bdad9 100644
--- a/numpy/core/src/multiarray/sequence.c
+++ b/numpy/core/src/multiarray/sequence.c
@@ -15,9 +15,7 @@
#include "mapping.h"
#include "sequence.h"
-
-static int
-array_any_nonzero(PyArrayObject *mp);
+#include "calculation.h"
/*************************************************************************
**************** Implement Sequence Protocol **************************
@@ -32,16 +30,18 @@ array_contains(PyArrayObject *self, PyObject *el)
{
/* equivalent to (self == el).any() */
- PyObject *res;
int ret;
+ PyObject *res, *any;
res = PyArray_EnsureAnyArray(PyObject_RichCompare((PyObject *)self,
el, Py_EQ));
if (res == NULL) {
return -1;
}
- ret = array_any_nonzero((PyArrayObject *)res);
+ any = PyArray_Any((PyArrayObject *)res, NPY_MAXDIMS, NULL);
Py_DECREF(res);
+ ret = PyObject_IsTrue(any);
+ Py_DECREF(any);
return ret;
}
@@ -61,30 +61,3 @@ NPY_NO_EXPORT PySequenceMethods array_as_sequence = {
/****************** End of Sequence Protocol ****************************/
-/*
- * Helpers
- */
-
-/* Array evaluates as "TRUE" if any of the elements are non-zero*/
-static int
-array_any_nonzero(PyArrayObject *arr)
-{
- npy_intp counter;
- PyArrayIterObject *it;
- npy_bool anyTRUE = NPY_FALSE;
-
- it = (PyArrayIterObject *)PyArray_IterNew((PyObject *)arr);
- if (it == NULL) {
- return anyTRUE;
- }
- counter = it->size;
- while (counter--) {
- if (PyArray_DESCR(arr)->f->nonzero(it->dataptr, arr)) {
- anyTRUE = NPY_TRUE;
- break;
- }
- PyArray_ITER_NEXT(it);
- }
- Py_DECREF(it);
- return anyTRUE;
-}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 88456a816..ea5421ec2 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2839,6 +2839,10 @@ class TestMethods(object):
e = np.array(['1+1j'], 'U')
assert_raises(TypeError, complex, e)
+class TestCequenceMethods(object):
+ def test_array_contains(self):
+ assert_(4.0 in np.arange(16.).reshape(4,4))
+ assert_(20.0 not in np.arange(16.).reshape(4,4))
class TestBinop(object):
def test_inplace(self):
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index e8c637179..b53c4338e 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -1120,6 +1120,19 @@ class TestNonzero(object):
assert_equal(m.nonzero(), tgt)
+ def test_nonzero_invalid_object(self):
+ # gh-9295
+ a = np.array([np.array([1, 2]), 3])
+ assert_raises(ValueError, np.nonzero, a)
+
+ class BoolErrors:
+ def __bool__(self):
+ raise ValueError("Not allowed")
+ def __nonzero__(self):
+ raise ValueError("Not allowed")
+
+ assert_raises(ValueError, np.nonzero, np.array([BoolErrors()]))
+
class TestIndex(object):
def test_boolean(self):