summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2019-05-26 12:18:11 -0700
committerEric Wieser <wieser.eric@gmail.com>2019-06-01 13:44:55 -0700
commit64edab687abce12c09b60ab78db2666738b37ef2 (patch)
tree7131fbf80266a87f512827080cf7f426ad7f13a6 /numpy/core
parent458b5bd38aa50d4c903ff1117ea811a29f774709 (diff)
downloadnumpy-64edab687abce12c09b60ab78db2666738b37ef2.tar.gz
MAINT: Collect together the special-casing of 0d non-zero into one place
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/item_selection.c29
1 files changed, 20 insertions, 9 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index f4d2513ca..63c6487bf 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -2209,6 +2209,22 @@ PyArray_Nonzero(PyArrayObject *self)
char **dataptr;
int is_empty = 0;
+
+ /* Special case - nonzero(zero_d) is nonzero(atleast1d(zero_d)) */
+ if (ndim == 0) {
+ static npy_intp const zero_dim_shape[1] = {1};
+ static npy_intp const zero_dim_strides[1] = {0};
+
+ PyArrayObject *self_1d = (PyArrayObject *)PyArray_NewFromDescrAndBase(
+ Py_TYPE(self), PyArray_DESCR(self),
+ 1, zero_dim_shape, zero_dim_strides, PyArray_BYTES(self),
+ PyArray_FLAGS(self), (PyObject *)self, (PyObject *)self);
+ if (self_1d == NULL) {
+ return NULL;
+ }
+ return PyArray_Nonzero(self_1d);
+ }
+
/*
* First count the number of non-zeros in 'self'.
*/
@@ -2219,7 +2235,7 @@ PyArray_Nonzero(PyArrayObject *self)
/* Allocate the result as a 2D array */
ret_dims[0] = nonzero_count;
- ret_dims[1] = (ndim == 0) ? 1 : ndim;
+ ret_dims[1] = ndim;
ret = (PyArrayObject *)PyArray_NewFromDescr(
&PyArray_Type, PyArray_DescrFromType(NPY_INTP),
2, ret_dims, NULL, NULL,
@@ -2229,11 +2245,11 @@ PyArray_Nonzero(PyArrayObject *self)
}
/* If it's a one-dimensional result, don't use an iterator */
- if (ndim <= 1) {
+ if (ndim == 1) {
npy_intp * multi_index = (npy_intp *)PyArray_DATA(ret);
char * data = PyArray_BYTES(self);
- npy_intp stride = (ndim == 0) ? 0 : PyArray_STRIDE(self, 0);
- npy_intp count = (ndim == 0) ? 1 : PyArray_DIM(self, 0);
+ npy_intp stride = PyArray_STRIDE(self, 0);
+ npy_intp count = PyArray_DIM(self, 0);
NPY_BEGIN_THREADS_DEF;
/* nothing to do */
@@ -2351,11 +2367,6 @@ PyArray_Nonzero(PyArrayObject *self)
NpyIter_Deallocate(iter);
finish:
- /* Treat zero-dimensional as shape (1,) */
- if (ndim == 0) {
- ndim = 1;
- }
-
ret_tuple = PyTuple_New(ndim);
if (ret_tuple == NULL) {
Py_DECREF(ret);