diff options
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 416 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 21 |
2 files changed, 253 insertions, 184 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index cd0ae1680..7016760b3 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -787,255 +787,303 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out, * over all but the desired sorting axis. */ static int -_new_sortlike(PyArrayObject *op, int axis, - NPY_SORTKIND swhich, - PyArray_PartitionFunc * part, - NPY_SELECTKIND pwhich, - npy_intp * kth, npy_intp nkth) +_new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort, + PyArray_PartitionFunc *part, npy_intp *kth, npy_intp nkth) { + npy_intp N = PyArray_DIM(op, axis); + npy_intp elsize = (npy_intp)PyArray_ITEMSIZE(op); + npy_intp astride = PyArray_STRIDE(op, axis); + int swap = PyArray_ISBYTESWAPPED(op); + int needcopy = !PyArray_ISALIGNED(op) || swap || astride != elsize; + int hasrefs = PyDataType_REFCHK(PyArray_DESCR(op)); + + PyArray_CopySwapNFunc *copyswapn = PyArray_DESCR(op)->f->copyswapn; + char *buffer = NULL; + PyArrayIterObject *it; - int needcopy = 0, swap; - npy_intp N, size; - int elsize; - npy_intp astride; - PyArray_SortFunc *sort = NULL; + npy_intp size; + + int ret = -1; NPY_BEGIN_THREADS_DEF; + /* Check if there is any sorting to do */ + if (N <= 1) { + return 0; + } + it = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)op, &axis); - swap = !PyArray_ISNOTSWAPPED(op); if (it == NULL) { return -1; } + size = it->size; NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(op)); - if (part == NULL) { - sort = PyArray_DESCR(op)->f->sort[swhich]; - } - - size = it->size; - N = PyArray_DIMS(op)[axis]; - elsize = PyArray_DESCR(op)->elsize; - astride = PyArray_STRIDES(op)[axis]; - needcopy = !(PyArray_FLAGS(op) & NPY_ARRAY_ALIGNED) || - (astride != (npy_intp) elsize) || swap; if (needcopy) { - char *buffer = PyDataMem_NEW(N*elsize); + buffer = PyDataMem_NEW(N * elsize); if (buffer == NULL) { goto fail; } + } - while (size--) { - _unaligned_strided_byte_copy(buffer, (npy_intp) elsize, it->dataptr, - astride, N, elsize); - if (swap) { - _strided_byte_swap(buffer, (npy_intp) elsize, N, elsize); - } - if (part == NULL) { - if (sort(buffer, N, op) < 0) { - PyDataMem_FREE(buffer); - goto fail; + while (size--) { + char *bufptr = it->dataptr; + + if (needcopy) { + if (hasrefs) { + /* + * For dtype's with objects, copyswapn Py_XINCREF's src + * and Py_XDECREF's dst. This would crash if called on + * an unitialized buffer, or leak a reference to each + * object if initialized. + * + * So, first do the copy with no refcounting... + */ + _unaligned_strided_byte_copy(buffer, elsize, + it->dataptr, astride, N, elsize); + /* ...then swap in-place if needed */ + if (swap) { + copyswapn(buffer, elsize, NULL, 0, N, swap, op); } } else { - npy_intp pivots[NPY_MAX_PIVOT_STACK]; - npy_intp npiv = 0; - npy_intp i; - for (i = 0; i < nkth; i++) { - if (part(buffer, N, kth[i], pivots, &npiv, op) < 0) { - PyDataMem_FREE(buffer); - goto fail; - } - } + copyswapn(buffer, elsize, it->dataptr, astride, N, swap, op); } - if (swap) { - _strided_byte_swap(buffer, (npy_intp) elsize, N, elsize); + bufptr = buffer; + } + /* + * TODO: If the input array is byte-swapped but contiguous and + * aligned, it could be swapped (and later unswapped) in-place + * rather than after copying to the buffer. Care would have to + * be taken to ensure that, if there is an error in the call to + * sort or part, the unswapping is still done before returning. + */ + + if (part == NULL) { + ret = sort(bufptr, N, op); +#if defined(NPY_PY3K) + /* Object comparisons may raise an exception in Python 3 */ + if (hasrefs && PyErr_Occurred()) { + ret = -1; + } +#endif + if (ret < 0) { + goto fail; } - _unaligned_strided_byte_copy(it->dataptr, astride, buffer, - (npy_intp) elsize, N, elsize); - PyArray_ITER_NEXT(it); } - PyDataMem_FREE(buffer); - } - else { - while (size--) { - if (part == NULL) { - if (sort(it->dataptr, N, op) < 0) { + else { + npy_intp pivots[NPY_MAX_PIVOT_STACK]; + npy_intp npiv = 0; + npy_intp i; + for (i = 0; i < nkth; ++i) { + ret = part(bufptr, N, kth[i], pivots, &npiv, op); +#if defined(NPY_PY3K) + /* Object comparisons may raise an exception in Python 3 */ + if (hasrefs && PyErr_Occurred()) { + ret = -1; + } +#endif + if (ret < 0) { goto fail; } } - else { - npy_intp pivots[NPY_MAX_PIVOT_STACK]; - npy_intp npiv = 0; - npy_intp i; - for (i = 0; i < nkth; i++) { - if (part(it->dataptr, N, kth[i], pivots, &npiv, op) < 0) { - goto fail; - } + } + + if (needcopy) { + if (hasrefs) { + if (swap) { + copyswapn(buffer, elsize, NULL, 0, N, swap, op); } + _unaligned_strided_byte_copy(it->dataptr, astride, + buffer, elsize, N, elsize); + } + else { + copyswapn(it->dataptr, astride, buffer, elsize, N, swap, op); } - PyArray_ITER_NEXT(it); } + + PyArray_ITER_NEXT(it); } + +fail: + PyDataMem_FREE(buffer); NPY_END_THREADS_DESCR(PyArray_DESCR(op)); + if (ret < 0 && !PyErr_Occurred()) { + /* Out of memory during sorting or buffer creation */ + PyErr_NoMemory(); + } Py_DECREF(it); - return 0; - fail: - /* Out of memory during sorting or buffer creation */ - NPY_END_THREADS; - PyErr_NoMemory(); - Py_DECREF(it); - return -1; + return ret; } static PyObject* -_new_argsortlike(PyArrayObject *op, int axis, - NPY_SORTKIND swhich, - PyArray_ArgPartitionFunc * argpart, - NPY_SELECTKIND pwhich, - npy_intp * kth, npy_intp nkth) +_new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort, + PyArray_ArgPartitionFunc *argpart, + npy_intp *kth, npy_intp nkth) { + npy_intp N = PyArray_DIM(op, axis); + npy_intp elsize = (npy_intp)PyArray_ITEMSIZE(op); + npy_intp astride = PyArray_STRIDE(op, axis); + int swap = PyArray_ISBYTESWAPPED(op); + int needcopy = !PyArray_ISALIGNED(op) || swap || astride != elsize; + int hasrefs = PyDataType_REFCHK(PyArray_DESCR(op)); + int needidxbuffer; - PyArrayIterObject *it = NULL; - PyArrayIterObject *rit = NULL; - PyArrayObject *ret; - npy_intp N, size, i; - npy_intp astride, rstride, *iptr; - int elsize; - int needcopy = 0, swap; - PyArray_ArgSortFunc *argsort = NULL; + PyArray_CopySwapNFunc *copyswapn = PyArray_DESCR(op)->f->copyswapn; + char *valbuffer = NULL; + npy_intp *idxbuffer = NULL; + + PyArrayObject *rop; + npy_intp rstride; + + PyArrayIterObject *it, *rit; + npy_intp size; + + int ret = -1; NPY_BEGIN_THREADS_DEF; - ret = (PyArrayObject *)PyArray_New(Py_TYPE(op), - PyArray_NDIM(op), - PyArray_DIMS(op), - NPY_INTP, - NULL, NULL, 0, 0, (PyObject *)op); - if (ret == NULL) { + rop = (PyArrayObject *)PyArray_New(Py_TYPE(op), PyArray_NDIM(op), + PyArray_DIMS(op), NPY_INTP, + NULL, NULL, 0, 0, (PyObject *)op); + if (rop == NULL) { return NULL; } + rstride = PyArray_STRIDE(rop, axis); + needidxbuffer = rstride != sizeof(npy_intp); + + /* Check if there is any argsorting to do */ + if (N <= 1) { + memset(PyArray_DATA(rop), 0, PyArray_NBYTES(rop)); + return (PyObject *)rop; + } + it = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)op, &axis); - rit = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)ret, &axis); - if (rit == NULL || it == NULL) { + rit = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)rop, &axis); + if (it == NULL || rit == NULL) { goto fail; } - swap = !PyArray_ISNOTSWAPPED(op); + size = it->size; NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(op)); - if (argpart == NULL) { - argsort = PyArray_DESCR(op)->f->argsort[swhich]; - } - else { - argpart = get_argpartition_func(PyArray_TYPE(op), pwhich); - } - size = it->size; - N = PyArray_DIMS(op)[axis]; - elsize = PyArray_DESCR(op)->elsize; - astride = PyArray_STRIDES(op)[axis]; - rstride = PyArray_STRIDE(ret,axis); - - needcopy = swap || !(PyArray_FLAGS(op) & NPY_ARRAY_ALIGNED) || - (astride != (npy_intp) elsize) || - (rstride != sizeof(npy_intp)); - if (needcopy) { - char *valbuffer, *indbuffer; - valbuffer = PyDataMem_NEW(N*elsize); + if (needcopy) { + valbuffer = PyDataMem_NEW(N * elsize); if (valbuffer == NULL) { goto fail; } - indbuffer = PyDataMem_NEW(N*sizeof(npy_intp)); - if (indbuffer == NULL) { - PyDataMem_FREE(valbuffer); + } + + if (needidxbuffer) { + idxbuffer = (npy_intp *)PyDataMem_NEW(N * sizeof(npy_intp)); + if (idxbuffer == NULL) { goto fail; } - while (size--) { - _unaligned_strided_byte_copy(valbuffer, (npy_intp) elsize, it->dataptr, - astride, N, elsize); - if (swap) { - _strided_byte_swap(valbuffer, (npy_intp) elsize, N, elsize); - } - iptr = (npy_intp *)indbuffer; - for (i = 0; i < N; i++) { - *iptr++ = i; - } - if (argpart == NULL) { - if (argsort(valbuffer, (npy_intp *)indbuffer, N, op) < 0) { - PyDataMem_FREE(valbuffer); - PyDataMem_FREE(indbuffer); - goto fail; + } + + while (size--) { + char *valptr = it->dataptr; + npy_intp *idxptr = (npy_intp *)rit->dataptr; + npy_intp *iptr, i; + + if (needcopy) { + if (hasrefs) { + /* + * For dtype's with objects, copyswapn Py_XINCREF's src + * and Py_XDECREF's dst. This would crash if called on + * an unitialized valbuffer, or leak a reference to + * each object item if initialized. + * + * So, first do the copy with no refcounting... + */ + _unaligned_strided_byte_copy(valbuffer, elsize, + it->dataptr, astride, N, elsize); + /* ...then swap in-place if needed */ + if (swap) { + copyswapn(valbuffer, elsize, NULL, 0, N, swap, op); } } else { - npy_intp pivots[NPY_MAX_PIVOT_STACK]; - npy_intp npiv = 0; - npy_intp i; - for (i = 0; i < nkth; i++) { - if (argpart(valbuffer, (npy_intp *)indbuffer, - N, kth[i], pivots, &npiv, op) < 0) { - PyDataMem_FREE(valbuffer); - PyDataMem_FREE(indbuffer); - goto fail; - } - } + copyswapn(valbuffer, elsize, + it->dataptr, astride, N, swap, op); } - _unaligned_strided_byte_copy(rit->dataptr, rstride, indbuffer, - sizeof(npy_intp), N, sizeof(npy_intp)); - PyArray_ITER_NEXT(it); - PyArray_ITER_NEXT(rit); + valptr = valbuffer; } - PyDataMem_FREE(valbuffer); - PyDataMem_FREE(indbuffer); - } - else { - while (size--) { - iptr = (npy_intp *)rit->dataptr; - for (i = 0; i < N; i++) { - *iptr++ = i; + + if (needidxbuffer) { + idxptr = idxbuffer; + } + + iptr = idxptr; + for (i = 0; i < N; ++i) { + *iptr++ = i; + } + + if (argpart == NULL) { + ret = argsort(valptr, idxptr, N, op); +#if defined(NPY_PY3K) + /* Object comparisons may raise an exception in Python 3 */ + if (hasrefs && PyErr_Occurred()) { + ret = -1; } - if (argpart == NULL) { - if (argsort(it->dataptr, (npy_intp *)rit->dataptr, - N, op) < 0) { +#endif + if (ret < 0) { + goto fail; + } + } + else { + npy_intp pivots[NPY_MAX_PIVOT_STACK]; + npy_intp npiv = 0; + + for (i = 0; i < nkth; ++i) { + ret = argpart(valptr, idxptr, N, kth[i], pivots, &npiv, op); +#if defined(NPY_PY3K) + /* Object comparisons may raise an exception in Python 3 */ + if (hasrefs && PyErr_Occurred()) { + ret = -1; + } +#endif + if (ret < 0) { goto fail; } } - else { - npy_intp pivots[NPY_MAX_PIVOT_STACK]; - npy_intp npiv = 0; - npy_intp i; - for (i = 0; i < nkth; i++) { - if (argpart(it->dataptr, (npy_intp *)rit->dataptr, - N, kth[i], pivots, &npiv, op) < 0) { - goto fail; - } - } + } + + if (needidxbuffer) { + char *rptr = rit->dataptr; + iptr = idxbuffer; + + for (i = 0; i < N; ++i) { + *(npy_intp *)rptr = *iptr++; + rptr += rstride; } - PyArray_ITER_NEXT(it); - PyArray_ITER_NEXT(rit); } + + PyArray_ITER_NEXT(it); + PyArray_ITER_NEXT(rit); } +fail: + PyDataMem_FREE(valbuffer); + PyDataMem_FREE(idxbuffer); NPY_END_THREADS_DESCR(PyArray_DESCR(op)); - - Py_DECREF(it); - Py_DECREF(rit); - return (PyObject *)ret; - - fail: - NPY_END_THREADS; - if (!PyErr_Occurred()) { - /* Out of memory during sorting or buffer creation */ - PyErr_NoMemory(); + if (ret < 0) { + if (!PyErr_Occurred()) { + /* Out of memory during sorting or buffer creation */ + PyErr_NoMemory(); + } + Py_XDECREF(rop); + rop = NULL; } - Py_DECREF(ret); Py_XDECREF(it); Py_XDECREF(rit); - return NULL; + + return (PyObject *)rop; } + /* Be sure to save this global_compare when necessary */ static PyArrayObject *global_obj; @@ -1126,7 +1174,8 @@ PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND which) /* Determine if we should use type-specific algorithm or not */ if (PyArray_DESCR(op)->f->sort[which] != NULL) { - return _new_sortlike(op, axis, which, NULL, 0, NULL, 0); + return _new_sortlike(op, axis, PyArray_DESCR(op)->f->sort[which], + NULL, NULL, 0); } if (PyArray_DESCR(op)->f->compare == NULL) { @@ -1293,8 +1342,8 @@ PyArray_Partition(PyArrayObject *op, PyArrayObject * ktharray, int axis, NPY_SEL if (kthrvl == NULL) return -1; - res = _new_sortlike(op, axis, 0, - part, which, + res = _new_sortlike(op, axis, NULL, + part, PyArray_DATA(kthrvl), PyArray_SIZE(kthrvl)); Py_DECREF(kthrvl); @@ -1419,8 +1468,9 @@ PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND which) } /* Determine if we should use new algorithm or not */ if (PyArray_DESCR(op2)->f->argsort[which] != NULL) { - ret = (PyArrayObject *)_new_argsortlike(op2, axis, which, - NULL, 0, NULL, 0); + ret = (PyArrayObject *)_new_argsortlike(op2, axis, + PyArray_DESCR(op2)->f->argsort[which], + NULL, NULL, 0); Py_DECREF(op2); return (PyObject *)ret; } @@ -1557,8 +1607,8 @@ PyArray_ArgPartition(PyArrayObject *op, PyArrayObject * ktharray, int axis, NPY_ return NULL; } - ret = (PyArrayObject *)_new_argsortlike(op2, axis, 0, - argpart, which, + ret = (PyArrayObject *)_new_argsortlike(op2, axis, NULL, + argpart, PyArray_DATA(kthrvl), PyArray_SIZE(kthrvl)); Py_DECREF(kthrvl); diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 7ecb56cae..95e87bbb9 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -835,7 +835,7 @@ class TestMethods(TestCase): assert_equal(c, a, msg) # test complex sorts. These use the same code as the scalars - # but the compare fuction differs. + # but the compare function differs. ai = a*1j + 1 bi = b*1j + 1 for kind in ['q', 'm', 'h'] : @@ -857,6 +857,16 @@ class TestMethods(TestCase): c.sort(kind=kind) assert_equal(c, ai, msg) + # test sorting of complex arrays requiring byte-swapping, gh-5441 + for endianess in '<>': + for dt in np.typecodes['Complex']: + dtype = '{0}{1}'.format(endianess, dt) + arr = np.array([1+3.j, 2+2.j, 3+1.j], dtype=dt) + c = arr.copy() + c.sort() + msg = 'byte-swapped complex sort, dtype={0}'.format(dt) + assert_equal(c, arr, msg) + # test string sorts. s = 'aaaaaaaa' a = np.array([s + chr(i) for i in range(101)]) @@ -1035,6 +1045,15 @@ class TestMethods(TestCase): assert_equal(ai.copy().argsort(kind=kind), a, msg) assert_equal(bi.copy().argsort(kind=kind), b, msg) + # test argsort of complex arrays requiring byte-swapping, gh-5441 + for endianess in '<>': + for dt in np.typecodes['Complex']: + dtype = '{0}{1}'.format(endianess, dt) + arr = np.array([1+3.j, 2+2.j, 3+1.j], dtype=dt) + msg = 'byte-swapped complex argsort, dtype={0}'.format(dt) + assert_equal(arr.argsort(), + np.arange(len(arr), dtype=np.intp), msg) + # test string argsorts. s = 'aaaaaaaa' a = np.array([s + chr(i) for i in range(101)]) |