From b83dfabfe169f32533a84d76b7e35ae544cb15a0 Mon Sep 17 00:00:00 2001 From: Travis Oliphant Date: Sat, 31 Dec 2005 05:54:07 +0000 Subject: Added other sort methods (heap, merge) --- scipy/base/src/multiarraymodule.c | 95 ++++++++++++++++++++++++++++++++++----- 1 file changed, 83 insertions(+), 12 deletions(-) (limited to 'scipy/base/src/multiarraymodule.c') diff --git a/scipy/base/src/multiarraymodule.c b/scipy/base/src/multiarraymodule.c index af6acfc34..4299adfb2 100644 --- a/scipy/base/src/multiarraymodule.c +++ b/scipy/base/src/multiarraymodule.c @@ -1591,9 +1591,12 @@ _new_sort(PyArrayObject *op, int axis, PyArray_SORTKIND which) int elsize; intp astride; PyArray_SortFunc *sort; + BEGIN_THREADS_DEF it = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)op, axis); if (it == NULL) return -1; + + BEGIN_THREADS sort = op->descr->f->sort[which]; size = it->size; N = op->dimensions[axis]; @@ -1606,21 +1609,32 @@ _new_sort(PyArrayObject *op, int axis, PyArray_SORTKIND which) char *buffer; buffer = PyDataMem_NEW(N*elsize); while (size--) { - _strided_copy(buffer, (intp) elsize, it->dataptr, astride, - N, elsize); - sort(buffer, N, elsize, op); - _strided_copy(it->dataptr, astride, buffer, (intp) elsize, - N, elsize); + _strided_copy(buffer, (intp) elsize, it->dataptr, + astride, N, elsize); + if (sort(buffer, N, op) < 0) { + PyDataMem_FREE(buffer); goto fail; + } + _strided_copy(it->dataptr, astride, buffer, + (intp) elsize, N, elsize); PyArray_ITER_NEXT(it); } PyDataMem_FREE(buffer); } else { while (size--) { - sort(it->dataptr, N, elsize, op); + if (sort(it->dataptr, N, op) < 0) goto fail; PyArray_ITER_NEXT(it); } } + + END_THREADS + + Py_DECREF(it); + return 0; + + fail: + END_THREADS + Py_DECREF(it); return 0; } @@ -1637,6 +1651,7 @@ _new_argsort(PyArrayObject *op, int axis, PyArray_SORTKIND which) int elsize; intp astride, rstride, *iptr; PyArray_ArgSortFunc *argsort; + BEGIN_THREADS_DEF ret = PyArray_New(op->ob_type, op->nd, op->dimensions, PyArray_INTP, @@ -1646,6 +1661,9 @@ _new_argsort(PyArrayObject *op, int axis, PyArray_SORTKIND which) it = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)op, axis); rit = (PyArrayIterObject *)PyArray_IterAllButAxis(ret, axis); if (rit == NULL || it == NULL) goto fail; + + BEGIN_THREADS + argsort = op->descr->f->argsort[which]; size = it->size; N = op->dimensions[axis]; @@ -1661,13 +1679,15 @@ _new_argsort(PyArrayObject *op, int axis, PyArray_SORTKIND which) valbuffer = PyDataMem_NEW(N*(elsize+sizeof(intp))); indbuffer = valbuffer + (N*elsize); while (size--) { - _strided_copy(valbuffer, (intp) elsize, it->dataptr, astride, - N, elsize); + _strided_copy(valbuffer, (intp) elsize, it->dataptr, + astride, N, elsize); iptr = (intp *)indbuffer; for (i=0; idataptr, rstride, indbuffer, sizeof(intp), - N, sizeof(intp)); + if (argsort(valbuffer, (intp *)indbuffer, N, op) < 0) { + PyDataMem_FREE(valbuffer); goto fail; + } + _strided_copy(rit->dataptr, rstride, indbuffer, + sizeof(intp), N, sizeof(intp)); PyArray_ITER_NEXT(it); PyArray_ITER_NEXT(rit); } @@ -1677,17 +1697,23 @@ _new_argsort(PyArrayObject *op, int axis, PyArray_SORTKIND which) while (size--) { iptr = (intp *)rit->dataptr; for (i=0; idataptr, (intp *)rit->dataptr, N, elsize, op); + if (argsort(it->dataptr, (intp *)rit->dataptr, + N, op) < 0) goto fail; PyArray_ITER_NEXT(it); PyArray_ITER_NEXT(rit); } } + + END_THREADS Py_DECREF(it); Py_DECREF(rit); return ret; fail: + + END_THREADS + Py_DECREF(ret); Py_XDECREF(it); Py_XDECREF(rit); @@ -1765,6 +1791,8 @@ PyArray_Sort(PyArrayObject *op, int axis, PyArray_SORTKIND which) int i, n, m, elsize, orign; n = op->nd; + if ((n==0) || (PyArray_SIZE(op)==1)) return 0; + if (axis < 0) axis += n; if ((axis < 0) || (axis >= n)) { PyErr_Format(PyExc_ValueError, @@ -1851,6 +1879,16 @@ PyArray_ArgSort(PyArrayObject *op, int axis, PyArray_SORTKIND which) char *store_ptr; n = op->nd; + if ((n==0) || (PyArray_SIZE(op)==1)) { + ret = (PyArrayObject *)PyArray_New(op->ob_type, op->nd, + op->dimensions, + PyArray_INTP, + NULL, NULL, 0, 0, + (PyObject *)op); + if (ret == NULL) return NULL; + *((intp *)ret->data) = 0; + return (PyObject *)ret; + } if (axis < 0) axis += n; if ((axis < 0) || (axis >= n)) { PyErr_Format(PyExc_ValueError, @@ -3773,6 +3811,39 @@ PyArray_ByteorderConverter(PyObject *obj, char *endian) return PY_SUCCEED; } +/*MULTIARRAY_API + Convert object to sort kind +*/ +static int +PyArray_SortkindConverter(PyObject *obj, PyArray_SORTKIND *sortkind) +{ + char *str; + *sortkind = PyArray_QUICKSORT; + str = PyString_AsString(obj); + if (!str) return PY_FAIL; + if (strlen(str) < 1) { + PyErr_SetString(PyExc_ValueError, + "Sort kind string must be at least length 1"); + return PY_FAIL; + } + if (str[0] == 'q' || str[0] == 'Q') + *sortkind = PyArray_QUICKSORT; + else if (str[0] == 'h' || str[0] == 'H') + *sortkind = PyArray_HEAPSORT; + else if (str[0] == 'm' || str[0] == 'M') + *sortkind = PyArray_MERGESORT; + else if (str[0] == 't' || str[0] == 'T') + *sortkind = PyArray_TIMSORT; + else { + PyErr_Format(PyExc_ValueError, + "%s is an unrecognized kind of sort", + str); + return PY_FAIL; + } + return PY_SUCCEED; +} + + /* This function returns true if the two typecodes are equivalent (same basic kind and same itemsize). */ -- cgit v1.2.1