summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- numpy/core/src/multiarray/item_selection.c 416
-rw-r--r-- numpy/core/tests/test_multiarray.py 21
2 files changed, 253 insertions, 184 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index cd0ae1680..7016760b3 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -787,255 +787,303 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
* over all but the desired sorting axis.
*/
static int
-_new_sortlike(PyArrayObject *op, int axis,
- NPY_SORTKIND swhich,
- PyArray_PartitionFunc * part,
- NPY_SELECTKIND pwhich,
- npy_intp * kth, npy_intp nkth)
+_new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
+ PyArray_PartitionFunc *part, npy_intp *kth, npy_intp nkth)
{
+ npy_intp N = PyArray_DIM(op, axis);
+ npy_intp elsize = (npy_intp)PyArray_ITEMSIZE(op);
+ npy_intp astride = PyArray_STRIDE(op, axis);
+ int swap = PyArray_ISBYTESWAPPED(op);
+ int needcopy = !PyArray_ISALIGNED(op) || swap || astride != elsize;
+ int hasrefs = PyDataType_REFCHK(PyArray_DESCR(op));
+
+ PyArray_CopySwapNFunc *copyswapn = PyArray_DESCR(op)->f->copyswapn;
+ char *buffer = NULL;
+
PyArrayIterObject *it;
- int needcopy = 0, swap;
- npy_intp N, size;
- int elsize;
- npy_intp astride;
- PyArray_SortFunc *sort = NULL;
+ npy_intp size;
+
+ int ret = -1;
NPY_BEGIN_THREADS_DEF;
+ /* Check if there is any sorting to do */
+ if (N <= 1) {
+ return 0;
+ }
+
it = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)op, &axis);
- swap = !PyArray_ISNOTSWAPPED(op);
if (it == NULL) {
return -1;
}
+ size = it->size;
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(op));
- if (part == NULL) {
- sort = PyArray_DESCR(op)->f->sort[swhich];
- }
-
- size = it->size;
- N = PyArray_DIMS(op)[axis];
- elsize = PyArray_DESCR(op)->elsize;
- astride = PyArray_STRIDES(op)[axis];
- needcopy = !(PyArray_FLAGS(op) & NPY_ARRAY_ALIGNED) ||
- (astride != (npy_intp) elsize) || swap;
if (needcopy) {
- char *buffer = PyDataMem_NEW(N*elsize);
+ buffer = PyDataMem_NEW(N * elsize);
if (buffer == NULL) {
goto fail;
}
+ }
- while (size--) {
- _unaligned_strided_byte_copy(buffer, (npy_intp) elsize, it->dataptr,
- astride, N, elsize);
- if (swap) {
- _strided_byte_swap(buffer, (npy_intp) elsize, N, elsize);
- }
- if (part == NULL) {
- if (sort(buffer, N, op) < 0) {
- PyDataMem_FREE(buffer);
- goto fail;
+ while (size--) {
+ char *bufptr = it->dataptr;
+
+ if (needcopy) {
+ if (hasrefs) {
+ /*
+ * For dtypes with objects, copyswapn Py_XINCREF's src
+ * and Py_XDECREF's dst. This would crash if called on
+ * an uninitialized buffer, or leak a reference to each
+ * object if initialized.
+ *
+ * So, first do the copy with no refcounting...
+ */
+ _unaligned_strided_byte_copy(buffer, elsize,
+ it->dataptr, astride, N, elsize);
+ /* ...then swap in-place if needed */
+ if (swap) {
+ copyswapn(buffer, elsize, NULL, 0, N, swap, op);
}
}
else {
- npy_intp pivots[NPY_MAX_PIVOT_STACK];
- npy_intp npiv = 0;
- npy_intp i;
- for (i = 0; i < nkth; i++) {
- if (part(buffer, N, kth[i], pivots, &npiv, op) < 0) {
- PyDataMem_FREE(buffer);
- goto fail;
- }
- }
+ copyswapn(buffer, elsize, it->dataptr, astride, N, swap, op);
}
- if (swap) {
- _strided_byte_swap(buffer, (npy_intp) elsize, N, elsize);
+ bufptr = buffer;
+ }
+ /*
+ * TODO: If the input array is byte-swapped but contiguous and
+ * aligned, it could be swapped (and later unswapped) in-place
+ * rather than after copying to the buffer. Care would have to
+ * be taken to ensure that, if there is an error in the call to
+ * sort or part, the unswapping is still done before returning.
+ */
+
+ if (part == NULL) {
+ ret = sort(bufptr, N, op);
+#if defined(NPY_PY3K)
+ /* Object comparisons may raise an exception in Python 3 */
+ if (hasrefs && PyErr_Occurred()) {
+ ret = -1;
+ }
+#endif
+ if (ret < 0) {
+ goto fail;
}
- _unaligned_strided_byte_copy(it->dataptr, astride, buffer,
- (npy_intp) elsize, N, elsize);
- PyArray_ITER_NEXT(it);
}
- PyDataMem_FREE(buffer);
- }
- else {
- while (size--) {
- if (part == NULL) {
- if (sort(it->dataptr, N, op) < 0) {
+ else {
+ npy_intp pivots[NPY_MAX_PIVOT_STACK];
+ npy_intp npiv = 0;
+ npy_intp i;
+ for (i = 0; i < nkth; ++i) {
+ ret = part(bufptr, N, kth[i], pivots, &npiv, op);
+#if defined(NPY_PY3K)
+ /* Object comparisons may raise an exception in Python 3 */
+ if (hasrefs && PyErr_Occurred()) {
+ ret = -1;
+ }
+#endif
+ if (ret < 0) {
goto fail;
}
}
- else {
- npy_intp pivots[NPY_MAX_PIVOT_STACK];
- npy_intp npiv = 0;
- npy_intp i;
- for (i = 0; i < nkth; i++) {
- if (part(it->dataptr, N, kth[i], pivots, &npiv, op) < 0) {
- goto fail;
- }
+ }
+
+ if (needcopy) {
+ if (hasrefs) {
+ if (swap) {
+ copyswapn(buffer, elsize, NULL, 0, N, swap, op);
}
+ _unaligned_strided_byte_copy(it->dataptr, astride,
+ buffer, elsize, N, elsize);
+ }
+ else {
+ copyswapn(it->dataptr, astride, buffer, elsize, N, swap, op);
}
- PyArray_ITER_NEXT(it);
}
+
+ PyArray_ITER_NEXT(it);
}
+
+fail:
+ PyDataMem_FREE(buffer);
NPY_END_THREADS_DESCR(PyArray_DESCR(op));
+ if (ret < 0 && !PyErr_Occurred()) {
+ /* Out of memory during sorting or buffer creation */
+ PyErr_NoMemory();
+ }
Py_DECREF(it);
- return 0;
- fail:
- /* Out of memory during sorting or buffer creation */
- NPY_END_THREADS;
- PyErr_NoMemory();
- Py_DECREF(it);
- return -1;
+ return ret;
}
static PyObject*
-_new_argsortlike(PyArrayObject *op, int axis,
- NPY_SORTKIND swhich,
- PyArray_ArgPartitionFunc * argpart,
- NPY_SELECTKIND pwhich,
- npy_intp * kth, npy_intp nkth)
+_new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
+ PyArray_ArgPartitionFunc *argpart,
+ npy_intp *kth, npy_intp nkth)
{
+ npy_intp N = PyArray_DIM(op, axis);
+ npy_intp elsize = (npy_intp)PyArray_ITEMSIZE(op);
+ npy_intp astride = PyArray_STRIDE(op, axis);
+ int swap = PyArray_ISBYTESWAPPED(op);
+ int needcopy = !PyArray_ISALIGNED(op) || swap || astride != elsize;
+ int hasrefs = PyDataType_REFCHK(PyArray_DESCR(op));
+ int needidxbuffer;
- PyArrayIterObject *it = NULL;
- PyArrayIterObject *rit = NULL;
- PyArrayObject *ret;
- npy_intp N, size, i;
- npy_intp astride, rstride, *iptr;
- int elsize;
- int needcopy = 0, swap;
- PyArray_ArgSortFunc *argsort = NULL;
+ PyArray_CopySwapNFunc *copyswapn = PyArray_DESCR(op)->f->copyswapn;
+ char *valbuffer = NULL;
+ npy_intp *idxbuffer = NULL;
+
+ PyArrayObject *rop;
+ npy_intp rstride;
+
+ PyArrayIterObject *it, *rit;
+ npy_intp size;
+
+ int ret = -1;
NPY_BEGIN_THREADS_DEF;
- ret = (PyArrayObject *)PyArray_New(Py_TYPE(op),
- PyArray_NDIM(op),
- PyArray_DIMS(op),
- NPY_INTP,
- NULL, NULL, 0, 0, (PyObject *)op);
- if (ret == NULL) {
+ rop = (PyArrayObject *)PyArray_New(Py_TYPE(op), PyArray_NDIM(op),
+ PyArray_DIMS(op), NPY_INTP,
+ NULL, NULL, 0, 0, (PyObject *)op);
+ if (rop == NULL) {
return NULL;
}
+ rstride = PyArray_STRIDE(rop, axis);
+ needidxbuffer = rstride != sizeof(npy_intp);
+
+ /* Check if there is any argsorting to do */
+ if (N <= 1) {
+ memset(PyArray_DATA(rop), 0, PyArray_NBYTES(rop));
+ return (PyObject *)rop;
+ }
+
it = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)op, &axis);
- rit = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)ret, &axis);
- if (rit == NULL || it == NULL) {
+ rit = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)rop, &axis);
+ if (it == NULL || rit == NULL) {
goto fail;
}
- swap = !PyArray_ISNOTSWAPPED(op);
+ size = it->size;
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(op));
- if (argpart == NULL) {
- argsort = PyArray_DESCR(op)->f->argsort[swhich];
- }
- else {
- argpart = get_argpartition_func(PyArray_TYPE(op), pwhich);
- }
- size = it->size;
- N = PyArray_DIMS(op)[axis];
- elsize = PyArray_DESCR(op)->elsize;
- astride = PyArray_STRIDES(op)[axis];
- rstride = PyArray_STRIDE(ret,axis);
-
- needcopy = swap || !(PyArray_FLAGS(op) & NPY_ARRAY_ALIGNED) ||
- (astride != (npy_intp) elsize) ||
- (rstride != sizeof(npy_intp));
- if (needcopy) {
- char *valbuffer, *indbuffer;
- valbuffer = PyDataMem_NEW(N*elsize);
+ if (needcopy) {
+ valbuffer = PyDataMem_NEW(N * elsize);
if (valbuffer == NULL) {
goto fail;
}
- indbuffer = PyDataMem_NEW(N*sizeof(npy_intp));
- if (indbuffer == NULL) {
- PyDataMem_FREE(valbuffer);
+ }
+
+ if (needidxbuffer) {
+ idxbuffer = (npy_intp *)PyDataMem_NEW(N * sizeof(npy_intp));
+ if (idxbuffer == NULL) {
goto fail;
}
- while (size--) {
- _unaligned_strided_byte_copy(valbuffer, (npy_intp) elsize, it->dataptr,
- astride, N, elsize);
- if (swap) {
- _strided_byte_swap(valbuffer, (npy_intp) elsize, N, elsize);
- }
- iptr = (npy_intp *)indbuffer;
- for (i = 0; i < N; i++) {
- *iptr++ = i;
- }
- if (argpart == NULL) {
- if (argsort(valbuffer, (npy_intp *)indbuffer, N, op) < 0) {
- PyDataMem_FREE(valbuffer);
- PyDataMem_FREE(indbuffer);
- goto fail;
+ }
+
+ while (size--) {
+ char *valptr = it->dataptr;
+ npy_intp *idxptr = (npy_intp *)rit->dataptr;
+ npy_intp *iptr, i;
+
+ if (needcopy) {
+ if (hasrefs) {
+ /*
+ * For dtypes with objects, copyswapn Py_XINCREF's src
+ * and Py_XDECREF's dst. This would crash if called on
+ * an uninitialized valbuffer, or leak a reference to
+ * each object item if initialized.
+ *
+ * So, first do the copy with no refcounting...
+ */
+ _unaligned_strided_byte_copy(valbuffer, elsize,
+ it->dataptr, astride, N, elsize);
+ /* ...then swap in-place if needed */
+ if (swap) {
+ copyswapn(valbuffer, elsize, NULL, 0, N, swap, op);
}
}
else {
- npy_intp pivots[NPY_MAX_PIVOT_STACK];
- npy_intp npiv = 0;
- npy_intp i;
- for (i = 0; i < nkth; i++) {
- if (argpart(valbuffer, (npy_intp *)indbuffer,
- N, kth[i], pivots, &npiv, op) < 0) {
- PyDataMem_FREE(valbuffer);
- PyDataMem_FREE(indbuffer);
- goto fail;
- }
- }
+ copyswapn(valbuffer, elsize,
+ it->dataptr, astride, N, swap, op);
}
- _unaligned_strided_byte_copy(rit->dataptr, rstride, indbuffer,
- sizeof(npy_intp), N, sizeof(npy_intp));
- PyArray_ITER_NEXT(it);
- PyArray_ITER_NEXT(rit);
+ valptr = valbuffer;
}
- PyDataMem_FREE(valbuffer);
- PyDataMem_FREE(indbuffer);
- }
- else {
- while (size--) {
- iptr = (npy_intp *)rit->dataptr;
- for (i = 0; i < N; i++) {
- *iptr++ = i;
+
+ if (needidxbuffer) {
+ idxptr = idxbuffer;
+ }
+
+ iptr = idxptr;
+ for (i = 0; i < N; ++i) {
+ *iptr++ = i;
+ }
+
+ if (argpart == NULL) {
+ ret = argsort(valptr, idxptr, N, op);
+#if defined(NPY_PY3K)
+ /* Object comparisons may raise an exception in Python 3 */
+ if (hasrefs && PyErr_Occurred()) {
+ ret = -1;
}
- if (argpart == NULL) {
- if (argsort(it->dataptr, (npy_intp *)rit->dataptr,
- N, op) < 0) {
+#endif
+ if (ret < 0) {
+ goto fail;
+ }
+ }
+ else {
+ npy_intp pivots[NPY_MAX_PIVOT_STACK];
+ npy_intp npiv = 0;
+
+ for (i = 0; i < nkth; ++i) {
+ ret = argpart(valptr, idxptr, N, kth[i], pivots, &npiv, op);
+#if defined(NPY_PY3K)
+ /* Object comparisons may raise an exception in Python 3 */
+ if (hasrefs && PyErr_Occurred()) {
+ ret = -1;
+ }
+#endif
+ if (ret < 0) {
goto fail;
}
}
- else {
- npy_intp pivots[NPY_MAX_PIVOT_STACK];
- npy_intp npiv = 0;
- npy_intp i;
- for (i = 0; i < nkth; i++) {
- if (argpart(it->dataptr, (npy_intp *)rit->dataptr,
- N, kth[i], pivots, &npiv, op) < 0) {
- goto fail;
- }
- }
+ }
+
+ if (needidxbuffer) {
+ char *rptr = rit->dataptr;
+ iptr = idxbuffer;
+
+ for (i = 0; i < N; ++i) {
+ *(npy_intp *)rptr = *iptr++;
+ rptr += rstride;
}
- PyArray_ITER_NEXT(it);
- PyArray_ITER_NEXT(rit);
}
+
+ PyArray_ITER_NEXT(it);
+ PyArray_ITER_NEXT(rit);
}
+fail:
+ PyDataMem_FREE(valbuffer);
+ PyDataMem_FREE(idxbuffer);
NPY_END_THREADS_DESCR(PyArray_DESCR(op));
-
- Py_DECREF(it);
- Py_DECREF(rit);
- return (PyObject *)ret;
-
- fail:
- NPY_END_THREADS;
- if (!PyErr_Occurred()) {
- /* Out of memory during sorting or buffer creation */
- PyErr_NoMemory();
+ if (ret < 0) {
+ if (!PyErr_Occurred()) {
+ /* Out of memory during sorting or buffer creation */
+ PyErr_NoMemory();
+ }
+ Py_XDECREF(rop);
+ rop = NULL;
}
- Py_DECREF(ret);
Py_XDECREF(it);
Py_XDECREF(rit);
- return NULL;
+
+ return (PyObject *)rop;
}
+
/* Be sure to save this global_compare when necessary */
static PyArrayObject *global_obj;
@@ -1126,7 +1174,8 @@ PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND which)
/* Determine if we should use type-specific algorithm or not */
if (PyArray_DESCR(op)->f->sort[which] != NULL) {
- return _new_sortlike(op, axis, which, NULL, 0, NULL, 0);
+ return _new_sortlike(op, axis, PyArray_DESCR(op)->f->sort[which],
+ NULL, NULL, 0);
}
if (PyArray_DESCR(op)->f->compare == NULL) {
@@ -1293,8 +1342,8 @@ PyArray_Partition(PyArrayObject *op, PyArrayObject * ktharray, int axis, NPY_SEL
if (kthrvl == NULL)
return -1;
- res = _new_sortlike(op, axis, 0,
- part, which,
+ res = _new_sortlike(op, axis, NULL,
+ part,
PyArray_DATA(kthrvl),
PyArray_SIZE(kthrvl));
Py_DECREF(kthrvl);
@@ -1419,8 +1468,9 @@ PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND which)
}
/* Determine if we should use new algorithm or not */
if (PyArray_DESCR(op2)->f->argsort[which] != NULL) {
- ret = (PyArrayObject *)_new_argsortlike(op2, axis, which,
- NULL, 0, NULL, 0);
+ ret = (PyArrayObject *)_new_argsortlike(op2, axis,
+ PyArray_DESCR(op2)->f->argsort[which],
+ NULL, NULL, 0);
Py_DECREF(op2);
return (PyObject *)ret;
}
@@ -1557,8 +1607,8 @@ PyArray_ArgPartition(PyArrayObject *op, PyArrayObject * ktharray, int axis, NPY_
return NULL;
}
- ret = (PyArrayObject *)_new_argsortlike(op2, axis, 0,
- argpart, which,
+ ret = (PyArrayObject *)_new_argsortlike(op2, axis, NULL,
+ argpart,
PyArray_DATA(kthrvl),
PyArray_SIZE(kthrvl));
Py_DECREF(kthrvl);
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 7ecb56cae..95e87bbb9 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -835,7 +835,7 @@ class TestMethods(TestCase):
assert_equal(c, a, msg)
# test complex sorts. These use the same code as the scalars
- # but the compare fuction differs.
+ # but the compare function differs.
ai = a*1j + 1
bi = b*1j + 1
for kind in ['q', 'm', 'h'] :
@@ -857,6 +857,16 @@ class TestMethods(TestCase):
c.sort(kind=kind)
assert_equal(c, ai, msg)
+ # test sorting of complex arrays requiring byte-swapping, gh-5441
+ for endianess in '<>':
+ for dt in np.typecodes['Complex']:
+ dtype = '{0}{1}'.format(endianess, dt)
+ arr = np.array([1+3.j, 2+2.j, 3+1.j], dtype=dtype)
+ c = arr.copy()
+ c.sort()
+ msg = 'byte-swapped complex sort, dtype={0}'.format(dt)
+ assert_equal(c, arr, msg)
+
# test string sorts.
s = 'aaaaaaaa'
a = np.array([s + chr(i) for i in range(101)])
@@ -1035,6 +1045,15 @@ class TestMethods(TestCase):
assert_equal(ai.copy().argsort(kind=kind), a, msg)
assert_equal(bi.copy().argsort(kind=kind), b, msg)
+ # test argsort of complex arrays requiring byte-swapping, gh-5441
+ for endianess in '<>':
+ for dt in np.typecodes['Complex']:
+ dtype = '{0}{1}'.format(endianess, dt)
+ arr = np.array([1+3.j, 2+2.j, 3+1.j], dtype=dtype)
+ msg = 'byte-swapped complex argsort, dtype={0}'.format(dt)
+ assert_equal(arr.argsort(),
+ np.arange(len(arr), dtype=np.intp), msg)
+
# test string argsorts.
s = 'aaaaaaaa'
a = np.array([s + chr(i) for i in range(101)])