diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2019-05-26 12:18:11 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2019-06-01 13:44:55 -0700 |
commit | 64edab687abce12c09b60ab78db2666738b37ef2 (patch) | |
tree | 7131fbf80266a87f512827080cf7f426ad7f13a6 /numpy/core | |
parent | 458b5bd38aa50d4c903ff1117ea811a29f774709 (diff) | |
download | numpy-64edab687abce12c09b60ab78db2666738b37ef2.tar.gz |
MAINT: Collect together the special-casing of 0d non-zero into one place
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 29 |
1 files changed, 20 insertions, 9 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index f4d2513ca..63c6487bf 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -2209,6 +2209,22 @@ PyArray_Nonzero(PyArrayObject *self) char **dataptr; int is_empty = 0; + + /* Special case - nonzero(zero_d) is nonzero(atleast1d(zero_d)) */ + if (ndim == 0) { + static npy_intp const zero_dim_shape[1] = {1}; + static npy_intp const zero_dim_strides[1] = {0}; + + PyArrayObject *self_1d = (PyArrayObject *)PyArray_NewFromDescrAndBase( + Py_TYPE(self), PyArray_DESCR(self), + 1, zero_dim_shape, zero_dim_strides, PyArray_BYTES(self), + PyArray_FLAGS(self), (PyObject *)self, (PyObject *)self); + if (self_1d == NULL) { + return NULL; + } + return PyArray_Nonzero(self_1d); + } + /* * First count the number of non-zeros in 'self'. */ @@ -2219,7 +2235,7 @@ PyArray_Nonzero(PyArrayObject *self) /* Allocate the result as a 2D array */ ret_dims[0] = nonzero_count; - ret_dims[1] = (ndim == 0) ? 1 : ndim; + ret_dims[1] = ndim; ret = (PyArrayObject *)PyArray_NewFromDescr( &PyArray_Type, PyArray_DescrFromType(NPY_INTP), 2, ret_dims, NULL, NULL, @@ -2229,11 +2245,11 @@ PyArray_Nonzero(PyArrayObject *self) } /* If it's a one-dimensional result, don't use an iterator */ - if (ndim <= 1) { + if (ndim == 1) { npy_intp * multi_index = (npy_intp *)PyArray_DATA(ret); char * data = PyArray_BYTES(self); - npy_intp stride = (ndim == 0) ? 0 : PyArray_STRIDE(self, 0); - npy_intp count = (ndim == 0) ? 1 : PyArray_DIM(self, 0); + npy_intp stride = PyArray_STRIDE(self, 0); + npy_intp count = PyArray_DIM(self, 0); NPY_BEGIN_THREADS_DEF; /* nothing to do */ @@ -2351,11 +2367,6 @@ PyArray_Nonzero(PyArrayObject *self) NpyIter_Deallocate(iter); finish: - /* Treat zero-dimensional as shape (1,) */ - if (ndim == 0) { - ndim = 1; - } - ret_tuple = PyTuple_New(ndim); if (ret_tuple == NULL) { Py_DECREF(ret); |