diff options
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 38 |
1 files changed, 20 insertions, 18 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index f4d2513ca..c2fa7cbfe 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -2207,7 +2207,21 @@ PyArray_Nonzero(PyArrayObject *self) NpyIter_IterNextFunc *iternext; NpyIter_GetMultiIndexFunc *get_multi_index; char **dataptr; - int is_empty = 0; + + /* Special case - nonzero(zero_d) is nonzero(atleast1d(zero_d)) */ + if (ndim == 0) { + static npy_intp const zero_dim_shape[1] = {1}; + static npy_intp const zero_dim_strides[1] = {0}; + + PyArrayObject *self_1d = (PyArrayObject *)PyArray_NewFromDescrAndBase( + Py_TYPE(self), PyArray_DESCR(self), + 1, zero_dim_shape, zero_dim_strides, PyArray_BYTES(self), + PyArray_FLAGS(self), (PyObject *)self, (PyObject *)self); + if (self_1d == NULL) { + return NULL; + } + return PyArray_Nonzero(self_1d); + } /* * First count the number of non-zeros in 'self'. @@ -2219,7 +2233,7 @@ PyArray_Nonzero(PyArrayObject *self) /* Allocate the result as a 2D array */ ret_dims[0] = nonzero_count; - ret_dims[1] = (ndim == 0) ? 1 : ndim; + ret_dims[1] = ndim; ret = (PyArrayObject *)PyArray_NewFromDescr( &PyArray_Type, PyArray_DescrFromType(NPY_INTP), 2, ret_dims, NULL, NULL, @@ -2229,11 +2243,11 @@ PyArray_Nonzero(PyArrayObject *self) } /* If it's a one-dimensional result, don't use an iterator */ - if (ndim <= 1) { + if (ndim == 1) { npy_intp * multi_index = (npy_intp *)PyArray_DATA(ret); char * data = PyArray_BYTES(self); - npy_intp stride = (ndim == 0) ? 0 : PyArray_STRIDE(self, 0); - npy_intp count = (ndim == 0) ? 1 : PyArray_DIM(self, 0); + npy_intp stride = PyArray_STRIDE(self, 0); + npy_intp count = PyArray_DIM(self, 0); NPY_BEGIN_THREADS_DEF; /* nothing to do */ @@ -2351,29 +2365,17 @@ PyArray_Nonzero(PyArrayObject *self) NpyIter_Deallocate(iter); finish: - /* Treat zero-dimensional as shape (1,) */ - if (ndim == 0) { - ndim = 1; - } - ret_tuple = PyTuple_New(ndim); if (ret_tuple == NULL) { Py_DECREF(ret); return NULL; } - for (i = 0; i < PyArray_NDIM(ret); ++i) { - if (PyArray_DIMS(ret)[i] == 0) { - is_empty = 1; - break; - } - } - /* Create views into ret, one for each dimension */ for (i = 0; i < ndim; ++i) { npy_intp stride = ndim * NPY_SIZEOF_INTP; /* the result is an empty array, the view must point to valid memory */ - npy_intp data_offset = is_empty ? 0 : i * NPY_SIZEOF_INTP; + npy_intp data_offset = nonzero_count == 0 ? 0 : i * NPY_SIZEOF_INTP; PyArrayObject *view = (PyArrayObject *)PyArray_NewFromDescrAndBase( Py_TYPE(ret), PyArray_DescrFromType(NPY_INTP), |