diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2013-08-12 06:28:41 -0700 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2013-08-12 06:28:41 -0700 |
commit | f8efcb6f191621e7c59f5385c85aeaa83be3669d (patch) | |
tree | d0029c507a53d4faedc3be104c8f8200eb21e157 /numpy/core | |
parent | 028007e05fd62d37fc98daf91ae0aafc1ea95cb8 (diff) | |
parent | 4d9cd695486fa095c6bff3238341a85cbdb47d0e (diff) | |
download | numpy-f8efcb6f191621e7c59f5385c85aeaa83be3669d.tar.gz |
Merge pull request #3360 from juliantaylor/selection-algo
add quickselect algorithm and expose it via partition
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/bento.info | 1 | ||||
-rw-r--r-- | numpy/core/code_generators/numpy_api.py | 4 | ||||
-rw-r--r-- | numpy/core/fromnumeric.py | 147 | ||||
-rw-r--r-- | numpy/core/include/numpy/ndarraytypes.h | 12 | ||||
-rw-r--r-- | numpy/core/setup.py | 3 | ||||
-rw-r--r-- | numpy/core/src/multiarray/conversion_utils.c | 39 | ||||
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 419 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 136 | ||||
-rw-r--r-- | numpy/core/src/npysort/selection.c.src | 459 | ||||
-rw-r--r-- | numpy/core/src/private/npy_partition.h | 559 | ||||
-rw-r--r-- | numpy/core/src/private/npy_partition.h.src | 131 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 290 |
12 files changed, 2177 insertions, 23 deletions
diff --git a/numpy/core/bento.info b/numpy/core/bento.info index 9c8476eee..9fe2229ee 100644 --- a/numpy/core/bento.info +++ b/numpy/core/bento.info @@ -13,6 +13,7 @@ Library: src/npysort/quicksort.c.src, src/npysort/mergesort.c.src, src/npysort/heapsort.c.src + src/npysort/selection.c.src, Extension: multiarray Sources: src/multiarray/multiarraymodule_onefile.c diff --git a/numpy/core/code_generators/numpy_api.py b/numpy/core/code_generators/numpy_api.py index 152d0f948..134d6199e 100644 --- a/numpy/core/code_generators/numpy_api.py +++ b/numpy/core/code_generators/numpy_api.py @@ -335,6 +335,10 @@ multiarray_funcs_api = { 'PyArray_MapIterSwapAxes': 293, 'PyArray_MapIterArray': 294, 'PyArray_MapIterNext': 295, + # End 1.7 API + 'PyArray_Partition': 296, + 'PyArray_ArgPartition': 297, + 'PyArray_SelectkindConverter': 298, } ufunc_types_api = { diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py index 78957e16e..87fb5d818 100644 --- a/numpy/core/fromnumeric.py +++ b/numpy/core/fromnumeric.py @@ -16,7 +16,8 @@ _dt_ = nt.sctype2char # functions that are now methods __all__ = ['take', 'reshape', 'choose', 'repeat', 'put', - 'swapaxes', 'transpose', 'sort', 'argsort', 'argmax', 'argmin', + 'swapaxes', 'transpose', 'sort', 'argsort', 'partition', 'argpartition', + 'argmax', 'argmin', 'searchsorted', 'alen', 'resize', 'diagonal', 'trace', 'ravel', 'nonzero', 'shape', 'compress', 'clip', 'sum', 'product', 'prod', 'sometrue', 'alltrue', @@ -534,6 +535,150 @@ def transpose(a, axes=None): return transpose(axes) +def partition(a, kth, axis=-1, kind='introselect', order=None): + """ + Return a partitioned copy of an array. + + Creates a copy of the array with its elements rearranges in such a way that + the value of the element in kth position is in the position it would be in + a sorted array. All elements smaller than the kth element are moved before + this element and all equal or greater are moved behind it. 
The ordering of + the elements in the two partitions is undefined. + + .. versionadded:: 1.8.0 + + Parameters + ---------- + a : array_like + Array to be partitioned. + kth : int or sequence of ints + Element index to partition by. The kth value of the element will be in + its final sorted position and all smaller elements will be moved before + it and all equal or greater elements behind it. + The order of all elements in the partitions is undefined. + If provided with a sequence of kth it will partition all elements + indexed by kth of them into their sorted position at once. + axis : int or None, optional + Axis along which to sort. If None, the array is flattened before + sorting. The default is -1, which sorts along the last axis. + kind : {'introselect'}, optional + Selection algorithm. Default is 'introselect'. + order : list, optional + When `a` is a structured array, this argument specifies which fields + to compare first, second, and so on. This list does not need to + include all of the fields. + + Returns + ------- + partitioned_array : ndarray + Array of the same type and shape as `a`. + + See Also + -------- + ndarray.partition : Method to partition an array in-place. + argpartition : Indirect partition. + + Notes + ----- + The various selection algorithms are characterized by their average speed, + worst case performance, work space size, and whether they are stable. A + stable sort keeps items with the same key in the same relative order. The + available algorithms have the following properties: + + ================= ======= ============= ============ ======= + kind speed worst case work space stable + ================= ======= ============= ============ ======= + 'introselect' 1 O(n) 0 no + ================= ======= ============= ============ ======= + + All the partition algorithms make temporary copies of the data when + partitioning along any but the last axis. 
Consequently, partitioning + along the last axis is faster and uses less space than partitioning + along any other axis. + + The sort order for complex numbers is lexicographic. If both the real + and imaginary parts are non-nan then the order is determined by the + real parts except when they are equal, in which case the order is + determined by the imaginary parts. + + Examples + -------- + >>> a = np.array([3, 4, 2, 1]) + >>> np.partition(a, 3) + array([2, 1, 3, 4]) + + >>> np.partition(a, (1, 3)) + array([1, 2, 3, 4]) + + """ + if axis is None: + a = asanyarray(a).flatten() + axis = 0 + else: + a = asanyarray(a).copy() + a.partition(kth, axis=axis, kind=kind, order=order) + return a + + +def argpartition(a, kth, axis=-1, kind='introselect', order=None): + """ + Perform an indirect partition along the given axis using the algorithm + specified by the `kind` keyword. It returns an array of indices of the + same shape as `a` that index data along the given axis in partitioned + order. + + .. versionadded:: 1.8.0 + + Parameters + ---------- + a : array_like + Array to sort. + kth : int or sequence of ints + Element index to partition by. The kth element will be in its final + sorted position and all smaller elements will be moved before it and + all larger elements behind it. + The order all elements in the partitions is undefined. + If provided with a sequence of kth it will partition all of them into + their sorted position at once. + axis : int or None, optional + Axis along which to sort. The default is -1 (the last axis). If None, + the flattened array is used. + kind : {'introselect'}, optional + Selection algorithm. Default is 'introselect' + order : list, optional + When `a` is an array with fields defined, this argument specifies + which fields to compare first, second, etc. Not all fields need be + specified. + + Returns + ------- + index_array : ndarray, int + Array of indices that partition`a` along the specified axis. 
+ In other words, ``a[index_array]`` yields a sorted `a`. + + See Also + -------- + partition : Describes partition algorithms used. + ndarray.partition : Inplace partition. + + Notes + ----- + See `partition` for notes on the different selection algorithms. + + Examples + -------- + One dimensional array: + + >>> x = np.array([3, 4, 2, 1]) + >>> x[np.argpartition(x, 3)] + array([2, 1, 3, 4]) + >>> x[np.argpartition(x, (1, 3))] + array([1, 2, 3, 4]) + + """ + return a.argpartition(kth, axis, kind=kind, order=order) + + def sort(a, axis=-1, kind='quicksort', order=None): """ Return a sorted copy of an array. diff --git a/numpy/core/include/numpy/ndarraytypes.h b/numpy/core/include/numpy/ndarraytypes.h index fc8dca950..fd64cd75f 100644 --- a/numpy/core/include/numpy/ndarraytypes.h +++ b/numpy/core/include/numpy/ndarraytypes.h @@ -156,6 +156,12 @@ typedef enum { typedef enum { + NPY_INTROSELECT=0, +} NPY_SELECTKIND; +#define NPY_NSELECTS (NPY_INTROSELECT + 1) + + +typedef enum { NPY_SEARCHLEFT=0, NPY_SEARCHRIGHT=1 } NPY_SEARCHSIDE; @@ -391,6 +397,12 @@ typedef int (PyArray_FillFunc)(void *, npy_intp, void *); typedef int (PyArray_SortFunc)(void *, npy_intp, void *); typedef int (PyArray_ArgSortFunc)(void *, npy_intp *, npy_intp, void *); +typedef int (PyArray_PartitionFunc)(void *, npy_intp, npy_intp, + npy_intp *, npy_intp *, + void *); +typedef int (PyArray_ArgPartitionFunc)(void *, npy_intp *, npy_intp, npy_intp, + npy_intp *, npy_intp *, + void *); typedef int (PyArray_FillWithScalarFunc)(void *, npy_intp, void *, void *); diff --git a/numpy/core/setup.py b/numpy/core/setup.py index 0c58abfa7..03d26c2b6 100644 --- a/numpy/core/setup.py +++ b/numpy/core/setup.py @@ -693,7 +693,8 @@ def configuration(parent_package='',top_path=None): config.add_library('npysort', sources = [join('src', 'npysort', 'quicksort.c.src'), join('src', 'npysort', 'mergesort.c.src'), - join('src', 'npysort', 'heapsort.c.src')]) + join('src', 'npysort', 'heapsort.c.src'), + join('src', 
'npysort', 'selection.c.src')]) ####################################################################### diff --git a/numpy/core/src/multiarray/conversion_utils.c b/numpy/core/src/multiarray/conversion_utils.c index 79093467f..a17156e53 100644 --- a/numpy/core/src/multiarray/conversion_utils.c +++ b/numpy/core/src/multiarray/conversion_utils.c @@ -406,6 +406,45 @@ PyArray_SortkindConverter(PyObject *obj, NPY_SORTKIND *sortkind) } /*NUMPY_API + * Convert object to select kind + */ +NPY_NO_EXPORT int +PyArray_SelectkindConverter(PyObject *obj, NPY_SELECTKIND *selectkind) +{ + char *str; + PyObject *tmp = NULL; + + if (PyUnicode_Check(obj)) { + obj = tmp = PyUnicode_AsASCIIString(obj); + } + + *selectkind = NPY_INTROSELECT; + str = PyBytes_AsString(obj); + if (!str) { + Py_XDECREF(tmp); + return NPY_FAIL; + } + if (strlen(str) < 1) { + PyErr_SetString(PyExc_ValueError, + "Select kind string must be at least length 1"); + Py_XDECREF(tmp); + return NPY_FAIL; + } + if (strcmp(str, "introselect") == 0) { + *selectkind = NPY_INTROSELECT; + } + else { + PyErr_Format(PyExc_ValueError, + "%s is an unrecognized kind of select", + str); + Py_XDECREF(tmp); + return NPY_FAIL; + } + Py_XDECREF(tmp); + return NPY_SUCCEED; +} + +/*NUMPY_API * Convert object to searchsorted side */ NPY_NO_EXPORT int diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 9be9f7140..b2be4f464 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -20,6 +20,7 @@ #include "item_selection.h" #include "npy_sort.h" +#include "npy_partition.h" /*NUMPY_API * Take @@ -768,14 +769,19 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out, * over all but the desired sorting axis. 
*/ static int -_new_sort(PyArrayObject *op, int axis, NPY_SORTKIND which) +_new_sortlike(PyArrayObject *op, int axis, + NPY_SORTKIND swhich, + PyArray_PartitionFunc * part, + NPY_SELECTKIND pwhich, + npy_intp * kth, npy_intp nkth) { PyArrayIterObject *it; int needcopy = 0, swap; npy_intp N, size; int elsize; npy_intp astride; - PyArray_SortFunc *sort; + PyArray_SortFunc *sort = NULL; + NPY_BEGIN_THREADS_DEF; it = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)op, &axis); @@ -785,7 +791,10 @@ _new_sort(PyArrayObject *op, int axis, NPY_SORTKIND which) } NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(op)); - sort = PyArray_DESCR(op)->f->sort[which]; + if (part == NULL) { + sort = PyArray_DESCR(op)->f->sort[swhich]; + } + size = it->size; N = PyArray_DIMS(op)[axis]; elsize = PyArray_DESCR(op)->elsize; @@ -805,9 +814,22 @@ _new_sort(PyArrayObject *op, int axis, NPY_SORTKIND which) if (swap) { _strided_byte_swap(buffer, (npy_intp) elsize, N, elsize); } - if (sort(buffer, N, op) < 0) { - PyDataMem_FREE(buffer); - goto fail; + if (part == NULL) { + if (sort(buffer, N, op) < 0) { + PyDataMem_FREE(buffer); + goto fail; + } + } + else { + npy_intp pivots[NPY_MAX_PIVOT_STACK]; + npy_intp npiv = 0; + npy_intp i; + for (i = 0; i < nkth; i++) { + if (part(buffer, N, kth[i], pivots, &npiv, op) < 0) { + PyDataMem_FREE(buffer); + goto fail; + } + } } if (swap) { _strided_byte_swap(buffer, (npy_intp) elsize, N, elsize); @@ -820,8 +842,20 @@ _new_sort(PyArrayObject *op, int axis, NPY_SORTKIND which) } else { while (size--) { - if (sort(it->dataptr, N, op) < 0) { - goto fail; + if (part == NULL) { + if (sort(it->dataptr, N, op) < 0) { + goto fail; + } + } + else { + npy_intp pivots[NPY_MAX_PIVOT_STACK]; + npy_intp npiv = 0; + npy_intp i; + for (i = 0; i < nkth; i++) { + if (part(it->dataptr, N, kth[i], pivots, &npiv, op) < 0) { + goto fail; + } + } } PyArray_ITER_NEXT(it); } @@ -839,7 +873,11 @@ _new_sort(PyArrayObject *op, int axis, NPY_SORTKIND which) } static PyObject* 
-_new_argsort(PyArrayObject *op, int axis, NPY_SORTKIND which) +_new_argsortlike(PyArrayObject *op, int axis, + NPY_SORTKIND swhich, + PyArray_ArgPartitionFunc * argpart, + NPY_SELECTKIND pwhich, + npy_intp * kth, npy_intp nkth) { PyArrayIterObject *it = NULL; @@ -849,7 +887,8 @@ _new_argsort(PyArrayObject *op, int axis, NPY_SORTKIND which) npy_intp astride, rstride, *iptr; int elsize; int needcopy = 0, swap; - PyArray_ArgSortFunc *argsort; + PyArray_ArgSortFunc *argsort = NULL; + NPY_BEGIN_THREADS_DEF; ret = (PyArrayObject *)PyArray_New(Py_TYPE(op), @@ -868,7 +907,12 @@ _new_argsort(PyArrayObject *op, int axis, NPY_SORTKIND which) swap = !PyArray_ISNOTSWAPPED(op); NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(op)); - argsort = PyArray_DESCR(op)->f->argsort[which]; + if (argpart == NULL) { + argsort = PyArray_DESCR(op)->f->argsort[swhich]; + } + else { + argpart = get_argpartition_func(PyArray_TYPE(op), pwhich); + } size = it->size; N = PyArray_DIMS(op)[axis]; elsize = PyArray_DESCR(op)->elsize; @@ -900,10 +944,25 @@ _new_argsort(PyArrayObject *op, int axis, NPY_SORTKIND which) for (i = 0; i < N; i++) { *iptr++ = i; } - if (argsort(valbuffer, (npy_intp *)indbuffer, N, op) < 0) { - PyDataMem_FREE(valbuffer); - PyDataMem_FREE(indbuffer); - goto fail; + if (argpart == NULL) { + if (argsort(valbuffer, (npy_intp *)indbuffer, N, op) < 0) { + PyDataMem_FREE(valbuffer); + PyDataMem_FREE(indbuffer); + goto fail; + } + } + else { + npy_intp pivots[NPY_MAX_PIVOT_STACK]; + npy_intp npiv = 0; + npy_intp i; + for (i = 0; i < nkth; i++) { + if (argpart(valbuffer, (npy_intp *)indbuffer, + N, kth[i], pivots, &npiv, op) < 0) { + PyDataMem_FREE(valbuffer); + PyDataMem_FREE(indbuffer); + goto fail; + } + } } _unaligned_strided_byte_copy(rit->dataptr, rstride, indbuffer, sizeof(npy_intp), N, sizeof(npy_intp)); @@ -919,8 +978,22 @@ _new_argsort(PyArrayObject *op, int axis, NPY_SORTKIND which) for (i = 0; i < N; i++) { *iptr++ = i; } - if (argsort(it->dataptr, (npy_intp *)rit->dataptr, N, op) < 
0) { - goto fail; + if (argpart == NULL) { + if (argsort(it->dataptr, (npy_intp *)rit->dataptr, + N, op) < 0) { + goto fail; + } + } + else { + npy_intp pivots[NPY_MAX_PIVOT_STACK]; + npy_intp npiv = 0; + npy_intp i; + for (i = 0; i < nkth; i++) { + if (argpart(it->dataptr, (npy_intp *)rit->dataptr, + N, kth[i], pivots, &npiv, op) < 0) { + goto fail; + } + } } PyArray_ITER_NEXT(it); PyArray_ITER_NEXT(rit); @@ -945,7 +1018,6 @@ _new_argsort(PyArrayObject *op, int axis, NPY_SORTKIND which) return NULL; } - /* Be sure to save this global_compare when necessary */ static PyArrayObject *global_obj; @@ -1036,7 +1108,7 @@ PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND which) /* Determine if we should use type-specific algorithm or not */ if (PyArray_DESCR(op)->f->sort[which] != NULL) { - return _new_sort(op, axis, which); + return _new_sortlike(op, axis, which, NULL, 0, NULL, 0); } if (PyArray_DESCR(op)->f->compare == NULL) { @@ -1113,6 +1185,175 @@ PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND which) } +/* + * make kth array positive, ravel and sort it + */ +static PyArrayObject * +partition_prep_kth_array(PyArrayObject * ktharray, + PyArrayObject * op, + int axis) +{ + const npy_intp * shape = PyArray_SHAPE(op); + PyArrayObject * kthrvl; + npy_intp * kth; + npy_intp nkth, i; + + if (!PyArray_CanCastSafely(PyArray_TYPE(ktharray), NPY_INTP)) { + if (DEPRECATE("Calling partition with a non integer index" + " will result in an error in the future") < 0) { + return NULL; + } + } + + if (PyArray_NDIM(ktharray) > 1) { + PyErr_Format(PyExc_ValueError, "kth array must have dimension <= 1"); + return NULL; + } + kthrvl = PyArray_Cast(ktharray, NPY_INTP); + + if (kthrvl == NULL) + return NULL; + + kth = PyArray_DATA(kthrvl); + nkth = PyArray_SIZE(kthrvl); + + for (i = 0; i < nkth; i++) { + if (kth[i] < 0) { + kth[i] += shape[axis]; + } + if (PyArray_SIZE(op) != 0 && ((kth[i] < 0) || (kth[i] >= shape[axis]))) { + PyErr_Format(PyExc_ValueError, "kth(=%zd) out of 
bounds (%zd)", + kth[i], shape[axis]); + Py_XDECREF(kthrvl); + return NULL; + } + } + + /* + * sort the array of kths so the partitions will + * not trample on each other + */ + PyArray_Sort(kthrvl, -1, NPY_QUICKSORT); + + return kthrvl; +} + + + +/*NUMPY_API + * Partition an array in-place + */ +NPY_NO_EXPORT int +PyArray_Partition(PyArrayObject *op, PyArrayObject * ktharray, int axis, NPY_SELECTKIND which) +{ + PyArrayObject *ap = NULL, *store_arr = NULL; + char *ip; + npy_intp i, n, m; + int elsize, orign; + int res = 0; + int axis_orig = axis; + int (*sort)(void *, size_t, size_t, npy_comparator); + PyArray_PartitionFunc * part = get_partition_func(PyArray_TYPE(op), which); + + n = PyArray_NDIM(op); + if ((n == 0)) { + return 0; + } + if (axis < 0) { + axis += n; + } + if ((axis < 0) || (axis >= n)) { + PyErr_Format(PyExc_ValueError, "axis(=%d) out of bounds", axis_orig); + return -1; + } + if (PyArray_FailUnlessWriteable(op, "sort array") < 0) { + return -1; + } + + if (part) { + PyArrayObject * kthrvl = partition_prep_kth_array(ktharray, op, axis); + if (kthrvl == NULL) + return -1; + + res = _new_sortlike(op, axis, 0, + part, which, + PyArray_DATA(kthrvl), + PyArray_SIZE(kthrvl)); + Py_DECREF(kthrvl); + return res; + } + + + if (PyArray_DESCR(op)->f->compare == NULL) { + PyErr_SetString(PyExc_TypeError, + "type does not have compare function"); + return -1; + } + + SWAPAXES2(op); + + /* select not implemented, use quicksort, slower but equivalent */ + switch (which) { + case NPY_INTROSELECT : + sort = npy_quicksort; + break; + default: + PyErr_SetString(PyExc_TypeError, + "requested sort kind is not supported"); + goto fail; + } + + ap = (PyArrayObject *)PyArray_FromAny((PyObject *)op, + NULL, 1, 0, + NPY_ARRAY_DEFAULT | NPY_ARRAY_UPDATEIFCOPY, NULL); + if (ap == NULL) { + goto fail; + } + elsize = PyArray_DESCR(ap)->elsize; + m = PyArray_DIMS(ap)[PyArray_NDIM(ap)-1]; + if (m == 0) { + goto finish; + } + n = PyArray_SIZE(ap)/m; + + /* Store global -- allows 
re-entry -- restore before leaving*/ + store_arr = global_obj; + global_obj = ap; + /* we don't need to care about kth here as we are using a full sort */ + for (ip = PyArray_DATA(ap), i = 0; i < n; i++, ip += elsize*m) { + res = sort(ip, m, elsize, sortCompare); + if (res < 0) { + break; + } + } + global_obj = store_arr; + + if (PyErr_Occurred()) { + goto fail; + } + else if (res == -NPY_ENOMEM) { + PyErr_NoMemory(); + goto fail; + } + else if (res == -NPY_ECOMP) { + PyErr_SetString(PyExc_TypeError, + "sort comparison failed"); + goto fail; + } + + + finish: + Py_DECREF(ap); /* Should update op if needed */ + SWAPBACK2(op); + return 0; + + fail: + Py_XDECREF(ap); + SWAPBACK2(op); + return -1; +} + + static char *global_data; static int @@ -1160,7 +1401,8 @@ PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND which) } /* Determine if we should use new algorithm or not */ if (PyArray_DESCR(op2)->f->argsort[which] != NULL) { - ret = (PyArrayObject *)_new_argsort(op2, axis, which); + ret = (PyArrayObject *)_new_argsortlike(op2, axis, which, + NULL, 0, NULL, 0); Py_DECREF(op2); return (PyObject *)ret; } @@ -1255,6 +1497,143 @@ PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND which) /*NUMPY_API + * ArgPartition an array + */ +NPY_NO_EXPORT PyObject * +PyArray_ArgPartition(PyArrayObject *op, PyArrayObject * ktharray, int axis, NPY_SELECTKIND which) +{ + PyArrayObject *ap = NULL, *ret = NULL, *store, *op2; + npy_intp *ip; + npy_intp i, j, n, m, orign; + int argsort_elsize; + char *store_ptr; + int res = 0; + int (*sort)(void *, size_t, size_t, npy_comparator); + PyArray_ArgPartitionFunc * argpart = + get_argpartition_func(PyArray_TYPE(op), which); + + n = PyArray_NDIM(op); + if ((n == 0) || (PyArray_SIZE(op) == 1)) { + ret = (PyArrayObject *)PyArray_New(Py_TYPE(op), PyArray_NDIM(op), + PyArray_DIMS(op), + NPY_INTP, + NULL, NULL, 0, 0, + (PyObject *)op); + if (ret == NULL) { + return NULL; + } + *((npy_intp *)PyArray_DATA(ret)) = 0; + return (PyObject *)ret; 
+ } + + /* Creates new reference op2 */ + if ((op2=(PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) { + return NULL; + } + + /* Determine if we should use new algorithm or not */ + if (argpart) { + PyArrayObject * kthrvl = partition_prep_kth_array(ktharray, op2, axis); + if (kthrvl == NULL) { + Py_DECREF(op2); + return NULL; + } + + ret = (PyArrayObject *)_new_argsortlike(op2, axis, 0, + argpart, which, + PyArray_DATA(kthrvl), + PyArray_SIZE(kthrvl)); + Py_DECREF(kthrvl); + Py_DECREF(op2); + return (PyObject *)ret; + } + + if (PyArray_DESCR(op2)->f->compare == NULL) { + PyErr_SetString(PyExc_TypeError, + "type does not have compare function"); + Py_DECREF(op2); + op = NULL; + goto fail; + } + + /* select not implemented, use quicksort, slower but equivalent */ + switch (which) { + case NPY_INTROSELECT : + sort = npy_quicksort; + break; + default: + PyErr_SetString(PyExc_TypeError, + "requested sort kind is not supported"); + Py_DECREF(op2); + op = NULL; + goto fail; + } + + /* ap will contain the reference to op2 */ + SWAPAXES(ap, op2); + op = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)ap, + NPY_NOTYPE, + 1, 0); + Py_DECREF(ap); + if (op == NULL) { + return NULL; + } + ret = (PyArrayObject *)PyArray_New(Py_TYPE(op), PyArray_NDIM(op), + PyArray_DIMS(op), NPY_INTP, + NULL, NULL, 0, 0, (PyObject *)op); + if (ret == NULL) { + goto fail; + } + ip = (npy_intp *)PyArray_DATA(ret); + argsort_elsize = PyArray_DESCR(op)->elsize; + m = PyArray_DIMS(op)[PyArray_NDIM(op)-1]; + if (m == 0) { + goto finish; + } + n = PyArray_SIZE(op)/m; + store_ptr = global_data; + global_data = PyArray_DATA(op); + store = global_obj; + global_obj = op; + /* we don't need to care about kth here as we are using a full sort */ + for (i = 0; i < n; i++, ip += m, global_data += m*argsort_elsize) { + for (j = 0; j < m; j++) { + ip[j] = j; + } + res = sort((char *)ip, m, sizeof(npy_intp), argsort_static_compare); + if (res < 0) { + break; + } + } + global_data = store_ptr; + 
global_obj = store; + + if (PyErr_Occurred()) { + goto fail; + } + else if (res == -NPY_ENOMEM) { + PyErr_NoMemory(); + goto fail; + } + else if (res == -NPY_ECOMP) { + PyErr_SetString(PyExc_TypeError, + "sort comparison failed"); + goto fail; + } + + finish: + Py_DECREF(op); + SWAPBACK(op, ret); + return (PyObject *)op; + + fail: + Py_XDECREF(op); + Py_XDECREF(ret); + return NULL; +} + + +/*NUMPY_API *LexSort an array providing indices that will sort a collection of arrays *lexicographically. The first key is sorted on first, followed by the second key *-- requires that arg"merge"sort is available for each sort_key diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index e5e21c98d..e4858eef7 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -1180,6 +1180,75 @@ array_sort(PyArrayObject *self, PyObject *args, PyObject *kwds) } static PyObject * +array_partition(PyArrayObject *self, PyObject *args, PyObject *kwds) +{ + int axis=-1; + int val; + NPY_SELECTKIND sortkind = NPY_INTROSELECT; + PyObject *order = NULL; + PyArray_Descr *saved = NULL; + PyArray_Descr *newd; + static char *kwlist[] = {"kth", "axis", "kind", "order", NULL}; + PyArrayObject * ktharray; + PyObject * kthobj; + + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iO&O", kwlist, + &kthobj, + &axis, + PyArray_SelectkindConverter, &sortkind, + &order)) { + return NULL; + } + + if (order == Py_None) { + order = NULL; + } + if (order != NULL) { + PyObject *new_name; + PyObject *_numpy_internal; + saved = PyArray_DESCR(self); + if (!PyDataType_HASFIELDS(saved)) { + PyErr_SetString(PyExc_ValueError, "Cannot specify " \ + "order when the array has no fields."); + return NULL; + } + _numpy_internal = PyImport_ImportModule("numpy.core._internal"); + if (_numpy_internal == NULL) { + return NULL; + } + new_name = PyObject_CallMethod(_numpy_internal, "_newnames", + "OO", saved, order); + Py_DECREF(_numpy_internal); + if (new_name == 
NULL) { + return NULL; + } + newd = PyArray_DescrNew(saved); + Py_DECREF(newd->names); + newd->names = new_name; + ((PyArrayObject_fields *)self)->descr = newd; + } + + ktharray = (PyArrayObject *)PyArray_FromAny(kthobj, NULL, 0, 1, + NPY_ARRAY_DEFAULT, NULL); + if (ktharray == NULL) + return NULL; + + val = PyArray_Partition(self, ktharray, axis, sortkind); + Py_DECREF(ktharray); + + if (order != NULL) { + Py_XDECREF(PyArray_DESCR(self)); + ((PyArrayObject_fields *)self)->descr = saved; + } + if (val < 0) { + return NULL; + } + Py_INCREF(Py_None); + return Py_None; +} + +static PyObject * array_argsort(PyArrayObject *self, PyObject *args, PyObject *kwds) { int axis = -1; @@ -1229,6 +1298,67 @@ array_argsort(PyArrayObject *self, PyObject *args, PyObject *kwds) return PyArray_Return((PyArrayObject *)res); } + +static PyObject * +array_argpartition(PyArrayObject *self, PyObject *args, PyObject *kwds) +{ + int axis = -1; + NPY_SELECTKIND sortkind = NPY_INTROSELECT; + PyObject *order = NULL, *res; + PyArray_Descr *newd, *saved=NULL; + static char *kwlist[] = {"kth", "axis", "kind", "order", NULL}; + PyObject * kthobj; + PyArrayObject * ktharray; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O&O&O", kwlist, + &kthobj, + PyArray_AxisConverter, &axis, + PyArray_SelectkindConverter, &sortkind, + &order)) { + return NULL; + } + if (order == Py_None) { + order = NULL; + } + if (order != NULL) { + PyObject *new_name; + PyObject *_numpy_internal; + saved = PyArray_DESCR(self); + if (!PyDataType_HASFIELDS(saved)) { + PyErr_SetString(PyExc_ValueError, "Cannot specify " + "order when the array has no fields."); + return NULL; + } + _numpy_internal = PyImport_ImportModule("numpy.core._internal"); + if (_numpy_internal == NULL) { + return NULL; + } + new_name = PyObject_CallMethod(_numpy_internal, "_newnames", + "OO", saved, order); + Py_DECREF(_numpy_internal); + if (new_name == NULL) { + return NULL; + } + newd = PyArray_DescrNew(saved); + newd->names = new_name; + 
((PyArrayObject_fields *)self)->descr = newd; + } + + ktharray = (PyArrayObject *)PyArray_FromAny(kthobj, NULL, 0, 1, + NPY_ARRAY_DEFAULT, NULL); + if (ktharray == NULL) + return NULL; + + res = PyArray_ArgPartition(self, ktharray, axis, sortkind); + Py_DECREF(ktharray); + + if (order != NULL) { + Py_XDECREF(PyArray_DESCR(self)); + ((PyArrayObject_fields *)self)->descr = saved; + } + return PyArray_Return((PyArrayObject *)res); +} + static PyObject * array_searchsorted(PyArrayObject *self, PyObject *args, PyObject *kwds) { @@ -2203,6 +2333,9 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = { {"argmin", (PyCFunction)array_argmin, METH_VARARGS | METH_KEYWORDS, NULL}, + {"argpartition", + (PyCFunction)array_argpartition, + METH_VARARGS | METH_KEYWORDS, NULL}, {"argsort", (PyCFunction)array_argsort, METH_VARARGS | METH_KEYWORDS, NULL}, @@ -2272,6 +2405,9 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = { {"nonzero", (PyCFunction)array_nonzero, METH_VARARGS, NULL}, + {"partition", + (PyCFunction)array_partition, + METH_VARARGS | METH_KEYWORDS, NULL}, {"prod", (PyCFunction)array_prod, METH_VARARGS | METH_KEYWORDS, NULL}, diff --git a/numpy/core/src/npysort/selection.c.src b/numpy/core/src/npysort/selection.c.src new file mode 100644 index 000000000..a41aa87f1 --- /dev/null +++ b/numpy/core/src/npysort/selection.c.src @@ -0,0 +1,459 @@ +/* -*- c -*- */ + +/* + * + * The code is loosely based on the quickselect from + * Nicolas Devillard - 1998 public domain + * http://ndevilla.free.fr/median/median/ + * + * Quick select with median of 3 pivot is usually the fastest, + * but the worst case scenario can be quadratic complexcity, + * e.g. np.roll(np.arange(x), x / 2) + * To avoid this if it recurses too much it falls back to the + * worst case linear median of median of group 5 pivot strategy. 
+ */ + + +#define NPY_NO_DEPRECATED_API NPY_API_VERSION + +#include "npy_sort.h" +#include "npysort_common.h" +#include "numpy/npy_math.h" +#include "npy_partition.h" +#include <stdlib.h> + +#define NOT_USED NPY_UNUSED(unused) + + +/* + ***************************************************************************** + ** NUMERIC SORTS ** + ***************************************************************************** + */ + + +static NPY_INLINE void store_pivot(npy_intp pivot, npy_intp kth, + npy_intp * pivots, npy_intp * npiv) +{ + if (pivots == NULL) { + return; + } + + /* + * If pivot is the requested kth store it, overwriting other pivots if + * required. This must be done so iterative partition can work without + * manually shifting lower data offset by kth each time + */ + if (pivot == kth && *npiv == NPY_MAX_PIVOT_STACK) { + pivots[*npiv - 1] = pivot; + } + /* + * we only need pivots larger than current kth, smaller pivots are not + * useful as partitions on smaller kth would reorder the stored pivots + */ + else if (pivot >= kth && *npiv < NPY_MAX_PIVOT_STACK) { + pivots[*npiv] = pivot; + (*npiv) += 1; + } +} + +/**begin repeat + * + * #TYPE = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, + * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE, + * CFLOAT, CDOUBLE, CLONGDOUBLE# + * #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong, + * longlong, ulonglong, half, float, double, longdouble, + * cfloat, cdouble, clongdouble# + * #type = npy_bool, npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, + * npy_uint, npy_long, npy_ulong, npy_longlong, npy_ulonglong, + * npy_ushort, npy_float, npy_double, npy_longdouble, npy_cfloat, + * npy_cdouble, npy_clongdouble# + */ + +static npy_intp +amedian_of_median5_@suff@(@type@ *v, npy_intp* tosort, const npy_intp num, + npy_intp * pivots, + npy_intp * npiv); + +static npy_intp +median_of_median5_@suff@(@type@ *v, const npy_intp num, + npy_intp * pivots, + npy_intp * npiv); + +/**begin repeat1 + * 
#name = , a# + * #idx = , tosort# + * #arg = 0, 1# + */ +#if @arg@ +/* helper macros to avoid duplication of direct/indirect selection */ +#define IDX(x) tosort[x] +#define SORTEE(x) tosort[x] +#define SWAP INTP_SWAP +#define MEDIAN3_SWAP(v, tosort, low, mid, high) \ + amedian3_swap_@suff@(v, tosort, low, mid, high) +#define MEDIAN5(v, tosort, subleft) \ + amedian5_@suff@(v, tosort + subleft) +#define UNGUARDED_PARTITION(v, tosort, pivot, ll, hh) \ + aunguarded_partition_@suff@(v, tosort, pivot, ll, hh) +#define INTROSELECT(v, tosort, num, kth, pivots, npiv) \ + aintroselect_@suff@(v, tosort, nmed, nmed / 2, pivots, npiv, NULL) +#define DUMBSELECT(v, tosort, left, num, kth) \ + adumb_select_@suff@(v, tosort + left, num, kth) +#else +#define IDX(x) (x) +#define SORTEE(x) v[x] +#define SWAP @TYPE@_SWAP +#define MEDIAN3_SWAP(v, tosort, low, mid, high) \ + median3_swap_@suff@(v, low, mid, high) +#define MEDIAN5(v, tosort, subleft) \ + median5_@suff@(v + subleft) +#define UNGUARDED_PARTITION(v, tosort, pivot, ll, hh) \ + unguarded_partition_@suff@(v, pivot, ll, hh) +#define INTROSELECT(v, tosort, num, kth, pivots, npiv) \ + introselect_@suff@(v, nmed, nmed / 2, pivots, npiv, NULL) +#define DUMBSELECT(v, tosort, left, num, kth) \ + dumb_select_@suff@(v + left, num, kth) +#endif + + +/* + * median of 3 pivot strategy + * gets min and median and moves median to low and min to low + 1 + * for efficient partitioning, see unguarded_partition + */ +static NPY_INLINE void +@name@median3_swap_@suff@(@type@ * v, +#if @arg@ + npy_intp * tosort, +#endif + npy_intp low, npy_intp mid, npy_intp high) +{ + if (@TYPE@_LT(v[IDX(high)], v[IDX(mid)])) + SWAP(SORTEE(high), SORTEE(mid)); + if (@TYPE@_LT(v[IDX(high)], v[IDX(low)])) + SWAP(SORTEE(high), SORTEE(low)); + /* move pivot to low */ + if (@TYPE@_LT(v[IDX(low)], v[IDX(mid)])) + SWAP(SORTEE(low), SORTEE(mid)); + /* move 3-lowest element to low + 1 */ + SWAP(SORTEE(mid), SORTEE(low + 1)); +} + + +/* select index of median of five 
elements */ +static npy_intp @name@median5_@suff@( +#if @arg@ + const @type@ * v, npy_intp * tosort +#else + @type@ * v +#endif + ) +{ + /* could be optimized as we only need the index (no swaps) */ + if (@TYPE@_LT(v[IDX(1)], v[IDX(0)])) { + SWAP(SORTEE(1), SORTEE(0)); + } + if (@TYPE@_LT(v[IDX(4)], v[IDX(3)])) { + SWAP(SORTEE(4), SORTEE(3)); + } + if (@TYPE@_LT(v[IDX(3)], v[IDX(0)])) { + SWAP(SORTEE(3), SORTEE(0)); + } + if (@TYPE@_LT(v[IDX(4)], v[IDX(1)])) { + SWAP(SORTEE(4), SORTEE(1)); + } + if (@TYPE@_LT(v[IDX(2)], v[IDX(1)])) { + SWAP(SORTEE(2), SORTEE(1)); + } + if (@TYPE@_LT(v[IDX(3)], v[IDX(2)])) { + if (@TYPE@_LT(v[IDX(3)], v[IDX(1)])) { + return 1; + } + else { + return 3; + } + } + else { + if (@TYPE@_LT(v[IDX(2)], v[IDX(1)])) { + return 1; + } + else { + return 2; + } + } +} + + +/* + * partition and return the index were the pivot belongs + * the data must have following property to avoid bound checks: + * ll ... hh + * lower-than-pivot [x x x x] larger-than-pivot + */ +static NPY_INLINE void +@name@unguarded_partition_@suff@(@type@ * v, +#if @arg@ + npy_intp * tosort, +#endif + const @type@ pivot, + npy_intp * ll, npy_intp * hh) +{ + for (;;) { + do (*ll)++; while (@TYPE@_LT(v[IDX(*ll)], pivot)); + do (*hh)--; while (@TYPE@_LT(pivot, v[IDX(*hh)])); + + if (*hh < *ll) + break; + + SWAP(SORTEE(*ll), SORTEE(*hh)); + } +} + + +/* + * select median of median of blocks of 5 + * if used as partition pivot it splits the range into at least 30%/70% + * allowing linear time worstcase quickselect + */ +static npy_intp +@name@median_of_median5_@suff@(@type@ *v, +#if @arg@ + npy_intp* tosort, +#endif + const npy_intp num, + npy_intp * pivots, + npy_intp * npiv) +{ + npy_intp i, subleft; + npy_intp right = num - 1; + npy_intp nmed = (right + 1) / 5; + for (i = 0, subleft = 0; i < nmed; i++, subleft += 5) { + npy_intp m = MEDIAN5(v, tosort, subleft); + SWAP(SORTEE(subleft + m), SORTEE(i)); + } + + if (nmed > 2) + INTROSELECT(v, tosort, nmed, nmed / 2, pivots, 
npiv); + return nmed / 2; +} + + +/* + * N^2 selection, fast only for very small kth + * useful for close multiple partitions + * (e.g. even element median, interpolating percentile) + */ +static int +@name@dumb_select_@suff@(@type@ *v, +#if @arg@ + npy_intp * tosort, +#endif + npy_intp num, npy_intp kth) +{ + npy_intp i; + for (i = 0; i <= kth; i++) { + npy_intp minidx = i; + @type@ minval = v[IDX(i)]; + npy_intp k; + for (k = i + 1; k < num; k++) { + if (@TYPE@_LT(v[IDX(k)], minval)) { + minidx = k; + minval = v[IDX(k)]; + } + } + SWAP(SORTEE(i), SORTEE(minidx)); + } + + return 0; +} + + +/* + * iterative median of 3 quickselect with cutoff to median-of-medians-of5 + * receives stack of already computed pivots in v to minimize the + * partition size were kth is searched in + * + * area that needs partitioning in [...] + * kth 0: [8 7 6 5 4 3 2 1 0] -> med3 partitions elements [4, 2, 0] + * 0 1 2 3 4 8 7 5 6 -> pop requested kth -> stack [4, 2] + * kth 3: 0 1 2 [3] 4 8 7 5 6 -> stack [4] + * kth 5: 0 1 2 3 4 [8 7 5 6] -> stack [6] + * kth 8: 0 1 2 3 4 5 6 [8 7] -> stack [] + * + */ +int +@name@introselect_@suff@(@type@ *v, +#if @arg@ + npy_intp* tosort, +#endif + npy_intp num, npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED) +{ + npy_intp low = 0; + npy_intp high = num - 1; + npy_intp depth_limit; + + if (npiv == NULL) + pivots = NULL; + + while (pivots != NULL && *npiv > 0) { + if (pivots[*npiv - 1] > kth) { + /* pivot larger than kth set it as upper bound */ + high = pivots[*npiv - 1] - 1; + break; + } + else if (pivots[*npiv - 1] == kth) { + /* kth was already found in a previous iteration -> done */ + return 0; + } + + low = pivots[*npiv - 1] + 1; + + /* pop from stack */ + *npiv -= 1; + } + + /* + * use a faster O(n*kth) algorithm for very small kth + * e.g. 
for interpolating percentile + */ + if (kth - low < 3) { + DUMBSELECT(v, tosort, low, high - low + 1, kth - low); + store_pivot(kth, kth, pivots, npiv); + return 0; + } + + /* dumb integer msb, float npy_log2 too slow for small parititions */ + { + npy_uintp unum = num; + depth_limit = 0; + while (unum >>= 1) { + depth_limit++; + } + depth_limit *= 2; + } + + /* guarantee three elements */ + for (;low + 1 < high;) { + npy_intp ll = low + 1; + npy_intp hh = high; + + /* + * if we aren't making sufficient progress with median of 3 + * fall back to median-of-median5 pivot for linear worst case + * med3 for small sizes is required to do unguarded partition + */ + if (depth_limit > 0 || hh - ll < 5) { + const npy_intp mid = low + (high - low) / 2; + /* median of 3 pivot strategy, + * swapping for efficient partition */ + MEDIAN3_SWAP(v, tosort, low, mid, high); + } + else { + npy_intp mid; + /* FIXME: always use pivots to optimize this iterative partition */ +#if @arg@ + mid = ll + amedian_of_median5_@suff@(v, tosort + ll, hh - ll, NULL, NULL); +#else + mid = ll + median_of_median5_@suff@(v + ll, hh - ll, NULL, NULL); +#endif + SWAP(SORTEE(mid), SORTEE(low)); + /* adapt for the larger partition than med3 pivot */ + ll--; + hh++; + } + + depth_limit--; + + /* + * find place to put pivot (in low): + * previous swapping removes need for bound checks + * pivot 3-lowest [x x x] 3-highest + */ + UNGUARDED_PARTITION(v, tosort, v[IDX(low)], &ll, &hh); + + /* move pivot into position */ + SWAP(SORTEE(low), SORTEE(hh)); + + store_pivot(hh, kth, pivots, npiv); + + if (hh >= kth) + high = hh - 1; + if (hh <= kth) + low = ll; + } + + /* two elements */ + if (high == low + 1) { + if (@TYPE@_LT(v[IDX(high)], v[IDX(low)])) + SWAP(SORTEE(high), SORTEE(low)) + store_pivot(low, kth, pivots, npiv); + } + + return 0; +} + + +#undef IDX +#undef SWAP +#undef SORTEE +#undef MEDIAN3_SWAP +#undef MEDIAN5 +#undef UNGUARDED_PARTITION +#undef INTROSELECT +#undef DUMBSELECT +/**end repeat1**/ + 
+/**end repeat**/ + + +/* + ***************************************************************************** + ** STRING SORTS ** + ***************************************************************************** + */ + + +/**begin repeat + * + * #TYPE = STRING, UNICODE# + * #suff = string, unicode# + * #type = npy_char, npy_ucs4# + */ + +/* + * Selection stand-in for @suff@ data: kth is not used, the call simply + * forwards to quicksort_@suff@ and returns its result. + */ +int +introselect_@suff@(@type@ *start, npy_intp num, npy_intp kth, PyArrayObject *arr) +{ + return quicksort_@suff@(start, num, arr); +} + +/* + * Indirect (index array) counterpart: kth is not used, the call forwards + * to aquicksort_@suff@ and returns its result. + */ +int +aintroselect_@suff@(@type@ *v, npy_intp* tosort, npy_intp num, npy_intp kth, void *null) +{ + return aquicksort_@suff@(v, tosort, num, null); +} + +/**end repeat**/ + + +/* + ***************************************************************************** + ** GENERIC SORT ** + ***************************************************************************** + */ + + +/* + * Generic selection entry point: the signature is that of the generic sort + * (almost the libc qsort one) with an extra kth argument. kth is not used; + * the call forwards to npy_quicksort and returns its result, which supplies + * an error return for compatibility with the other generic sort kinds. + */ +int +npy_introselect(void *base, size_t num, size_t size, size_t kth, npy_comparator cmp) +{ + return npy_quicksort(base, num, size, cmp); +} diff --git a/numpy/core/src/private/npy_partition.h b/numpy/core/src/private/npy_partition.h new file mode 100644 index 000000000..9e1c8494a --- /dev/null +++ b/numpy/core/src/private/npy_partition.h @@ -0,0 +1,559 @@ + +/* + ***************************************************************************** + ** This file was autogenerated from a template DO NOT EDIT!!!! 
** + ** Changes should be made to the original source (.src) file ** + ***************************************************************************** + */ + +#line 1 +/* + ***************************************************************************** + ** IMPORTANT NOTE for npy_partition.h.src -> npy_partition.h ** + ***************************************************************************** + * The template file loops.h.src is not automatically converted into + * loops.h by the build system. If you edit this file, you must manually + * do the conversion using numpy/distutils/conv_template.py from the + * command line as follows: + * + * $ cd <NumPy source root directory> + * $ python numpy/distutils/conv_template.py numpy/core/src/private/npy_partition.h.src + * $ + */ + + +#ifndef __NPY_PARTITION_H__ +#define __NPY_PARTITION_H__ + + +#include "npy_sort.h" + +/* Python include is for future object sorts */ +#include <Python.h> +#include <numpy/npy_common.h> +#include <numpy/ndarraytypes.h> + +#define NPY_MAX_PIVOT_STACK 50 + + +#line 43 + +int introselect_bool(npy_bool *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_bool(npy_bool *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_byte(npy_byte *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_byte(npy_byte *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_ubyte(npy_ubyte *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_ubyte(npy_ubyte *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_short(npy_short *v, npy_intp num, + npy_intp kth, + npy_intp * 
pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_short(npy_short *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_ushort(npy_ushort *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_ushort(npy_ushort *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_int(npy_int *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_int(npy_int *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_uint(npy_uint *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_uint(npy_uint *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_long(npy_long *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_long(npy_long *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_ulong(npy_ulong *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_ulong(npy_ulong *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_longlong(npy_longlong *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_longlong(npy_longlong *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_ulonglong(npy_ulonglong *v, npy_intp 
num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_ulonglong(npy_ulonglong *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_half(npy_ushort *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_half(npy_ushort *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_float(npy_float *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_float(npy_float *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_double(npy_double *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_double(npy_double *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_longdouble(npy_longdouble *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_longdouble(npy_longdouble *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_cfloat(npy_cfloat *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_cfloat(npy_cfloat *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + +#line 43 + +int introselect_cdouble(npy_cdouble *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_cdouble(npy_cdouble *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void 
*NOT_USED); + + + +#line 43 + +int introselect_clongdouble(npy_clongdouble *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_clongdouble(npy_clongdouble *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + + + +int introselect_string(npy_char *vec, npy_intp cnt, npy_intp kth, PyArrayObject *arr); +int aintroselect_string(npy_char *vec, npy_intp *ind, npy_intp cnt, npy_intp kth, void *null); + + +int introselect_unicode(npy_ucs4 *vec, npy_intp cnt, npy_intp kth, PyArrayObject *arr); +int aintroselect_unicode(npy_ucs4 *vec, npy_intp *ind, npy_intp cnt, npy_intp kth, void *null); + +int npy_introselect(void *base, size_t num, size_t size, size_t kth, npy_comparator cmp); + +typedef struct { + enum NPY_TYPES typenum; + PyArray_PartitionFunc * part[NPY_NSELECTS]; + PyArray_ArgPartitionFunc * argpart[NPY_NSELECTS]; +} part_map; + +static part_map _part_map[] = { +#line 87 + { + NPY_BOOL, + { + (PyArray_PartitionFunc *)&introselect_bool, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_bool, + } + }, + +#line 87 + { + NPY_BYTE, + { + (PyArray_PartitionFunc *)&introselect_byte, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_byte, + } + }, + +#line 87 + { + NPY_UBYTE, + { + (PyArray_PartitionFunc *)&introselect_ubyte, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_ubyte, + } + }, + +#line 87 + { + NPY_SHORT, + { + (PyArray_PartitionFunc *)&introselect_short, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_short, + } + }, + +#line 87 + { + NPY_USHORT, + { + (PyArray_PartitionFunc *)&introselect_ushort, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_ushort, + } + }, + +#line 87 + { + NPY_INT, + { + (PyArray_PartitionFunc *)&introselect_int, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_int, + } + }, + +#line 87 + { + NPY_UINT, + { + (PyArray_PartitionFunc *)&introselect_uint, + }, + { + (PyArray_ArgPartitionFunc 
*)&aintroselect_uint, + } + }, + +#line 87 + { + NPY_LONG, + { + (PyArray_PartitionFunc *)&introselect_long, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_long, + } + }, + +#line 87 + { + NPY_ULONG, + { + (PyArray_PartitionFunc *)&introselect_ulong, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_ulong, + } + }, + +#line 87 + { + NPY_LONGLONG, + { + (PyArray_PartitionFunc *)&introselect_longlong, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_longlong, + } + }, + +#line 87 + { + NPY_ULONGLONG, + { + (PyArray_PartitionFunc *)&introselect_ulonglong, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_ulonglong, + } + }, + +#line 87 + { + NPY_HALF, + { + (PyArray_PartitionFunc *)&introselect_half, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_half, + } + }, + +#line 87 + { + NPY_FLOAT, + { + (PyArray_PartitionFunc *)&introselect_float, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_float, + } + }, + +#line 87 + { + NPY_DOUBLE, + { + (PyArray_PartitionFunc *)&introselect_double, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_double, + } + }, + +#line 87 + { + NPY_LONGDOUBLE, + { + (PyArray_PartitionFunc *)&introselect_longdouble, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_longdouble, + } + }, + +#line 87 + { + NPY_CFLOAT, + { + (PyArray_PartitionFunc *)&introselect_cfloat, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_cfloat, + } + }, + +#line 87 + { + NPY_CDOUBLE, + { + (PyArray_PartitionFunc *)&introselect_cdouble, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_cdouble, + } + }, + +#line 87 + { + NPY_CLONGDOUBLE, + { + (PyArray_PartitionFunc *)&introselect_clongdouble, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_clongdouble, + } + }, + +}; + + +static NPY_INLINE PyArray_PartitionFunc * +get_partition_func(int type, NPY_SELECTKIND which) +{ + npy_intp i; + if (which >= NPY_NSELECTS) { + return NULL; + } + for (i = 0; i < sizeof(_part_map)/sizeof(_part_map[0]); i++) { + if (type == 
_part_map[i].typenum) { + return _part_map[i].part[which]; + } + } + return NULL; +} + + +static NPY_INLINE PyArray_ArgPartitionFunc * +get_argpartition_func(int type, NPY_SELECTKIND which) +{ + npy_intp i; + if (which >= NPY_NSELECTS) { + return NULL; + } + for (i = 0; i < sizeof(_part_map)/sizeof(_part_map[0]); i++) { + if (type == _part_map[i].typenum) { + return _part_map[i].argpart[which]; + } + } + return NULL; +} + +#endif + diff --git a/numpy/core/src/private/npy_partition.h.src b/numpy/core/src/private/npy_partition.h.src new file mode 100644 index 000000000..c03ecb0de --- /dev/null +++ b/numpy/core/src/private/npy_partition.h.src @@ -0,0 +1,131 @@ +/* + ***************************************************************************** + ** IMPORTANT NOTE for npy_partition.h.src -> npy_partition.h ** + ***************************************************************************** + * The template file npy_partition.h.src is not automatically converted into + * npy_partition.h by the build system. If you edit this file, you must manually + * do the conversion using numpy/distutils/conv_template.py from the + * command line as follows: + * + * $ cd <NumPy source root directory> + * $ python numpy/distutils/conv_template.py numpy/core/src/private/npy_partition.h.src + * $ + */ + + +#ifndef __NPY_PARTITION_H__ +#define __NPY_PARTITION_H__ + + +#include "npy_sort.h" + +/* Python include is for future object sorts */ +#include <Python.h> +#include <numpy/npy_common.h> +#include <numpy/ndarraytypes.h> + +#define NPY_MAX_PIVOT_STACK 50 + + +/**begin repeat + * + * #TYPE = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, + * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE, + * CFLOAT, CDOUBLE, CLONGDOUBLE# + * #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong, + * longlong, ulonglong, half, float, double, longdouble, + * cfloat, cdouble, clongdouble# + * #type = npy_bool, npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, + * npy_uint, npy_long, npy_ulong, 
npy_longlong, npy_ulonglong, + * npy_ushort, npy_float, npy_double, npy_longdouble, npy_cfloat, + * npy_cdouble, npy_clongdouble# + */ + +int introselect_@suff@(@type@ *v, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); +int aintroselect_@suff@(@type@ *v, npy_intp* tosort, npy_intp num, + npy_intp kth, + npy_intp * pivots, + npy_intp * npiv, + void *NOT_USED); + + +/**end repeat**/ + +int introselect_string(npy_char *vec, npy_intp cnt, npy_intp kth, PyArrayObject *arr); +int aintroselect_string(npy_char *vec, npy_intp *ind, npy_intp cnt, npy_intp kth, void *null); + + +int introselect_unicode(npy_ucs4 *vec, npy_intp cnt, npy_intp kth, PyArrayObject *arr); +int aintroselect_unicode(npy_ucs4 *vec, npy_intp *ind, npy_intp cnt, npy_intp kth, void *null); + +int npy_introselect(void *base, size_t num, size_t size, size_t kth, npy_comparator cmp); + +typedef struct { + enum NPY_TYPES typenum; + PyArray_PartitionFunc * part[NPY_NSELECTS]; + PyArray_ArgPartitionFunc * argpart[NPY_NSELECTS]; +} part_map; + +static part_map _part_map[] = { +/**begin repeat + * + * #TYPE = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, + * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE, + * CFLOAT, CDOUBLE, CLONGDOUBLE# + * #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong, + * longlong, ulonglong, half, float, double, longdouble, + * cfloat, cdouble, clongdouble# + * #type = npy_bool, npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, + * npy_uint, npy_long, npy_ulong, npy_longlong, npy_ulonglong, + * npy_ushort, npy_float, npy_double, npy_longdouble, npy_cfloat, + * npy_cdouble, npy_clongdouble# + */ + { + NPY_@TYPE@, + { + (PyArray_PartitionFunc *)&introselect_@suff@, + }, + { + (PyArray_ArgPartitionFunc *)&aintroselect_@suff@, + } + }, +/**end repeat**/ +}; + + +static NPY_INLINE PyArray_PartitionFunc * +get_partition_func(int type, NPY_SELECTKIND which) +{ + npy_intp i; + if (which >= NPY_NSELECTS) { + return 
NULL; + } + for (i = 0; i < sizeof(_part_map)/sizeof(_part_map[0]); i++) { + if (type == _part_map[i].typenum) { + return _part_map[i].part[which]; + } + } + return NULL; +} + + +static NPY_INLINE PyArray_ArgPartitionFunc * +get_argpartition_func(int type, NPY_SELECTKIND which) +{ + npy_intp i; + if (which >= NPY_NSELECTS) { + return NULL; + } + for (i = 0; i < sizeof(_part_map)/sizeof(_part_map[0]); i++) { + if (type == _part_map[i].typenum) { + return _part_map[i].argpart[which]; + } + } + return NULL; +} + +#endif diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index c298b4f1d..19806bd76 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -18,7 +18,8 @@ from numpy.core.multiarray_tests import ( from numpy.testing import ( TestCase, run_module_suite, assert_, assert_raises, assert_equal, assert_almost_equal, assert_array_equal, - assert_array_almost_equal, assert_allclose, runstring, dec + assert_array_almost_equal, assert_allclose, + assert_array_less, runstring, dec ) # Need to test an object that does not fully implement math interface @@ -921,6 +922,293 @@ class TestMethods(TestCase): assert_equal(a.searchsorted(k, side='l', sorter=s), [0, 20, 40, 60, 80]) assert_equal(a.searchsorted(k, side='r', sorter=s), [20, 40, 60, 80, 100]) + + def test_partition(self): + d = np.arange(10) + assert_raises(TypeError, np.partition, d, 2, kind=1) + assert_raises(ValueError, np.partition, d, 2, kind="nonsense") + assert_raises(ValueError, np.argpartition, d, 2, kind="nonsense") + assert_raises(ValueError, d.partition, 2, axis=0, kind="nonsense") + assert_raises(ValueError, d.argpartition, 2, axis=0, kind="nonsense") + for k in ("introselect",): + d = np.array([]) + assert_array_equal(np.partition(d, 0, kind=k), d) + assert_array_equal(np.argpartition(d, 0, kind=k), d) + d = np.ones((1)) + assert_array_equal(np.partition(d, 0, kind=k)[0], d) + assert_array_equal(d[np.argpartition(d, 0, kind=k)], 
+ np.partition(d, 0, kind=k)) + + # kth not modified + kth = np.array([30, 15, 5]) + okth = kth.copy() + np.partition(np.arange(40), kth) + assert_array_equal(kth, okth) + + for r in ([2, 1], [1, 2], [1, 1]): + d = np.array(r) + tgt = np.sort(d) + assert_array_equal(np.partition(d, 0, kind=k)[0], tgt[0]) + assert_array_equal(np.partition(d, 1, kind=k)[1], tgt[1]) + assert_array_equal(d[np.argpartition(d, 0, kind=k)], + np.partition(d, 0, kind=k)) + assert_array_equal(d[np.argpartition(d, 1, kind=k)], + np.partition(d, 1, kind=k)) + for i in range(d.size): + d[i:].partition(0, kind=k) + assert_array_equal(d, tgt) + + for r in ([3, 2, 1], [1, 2, 3], [2, 1, 3], [2, 3, 1], + [1, 1, 1], [1, 2, 2], [2, 2, 1], [1, 2, 1]): + d = np.array(r) + tgt = np.sort(d) + assert_array_equal(np.partition(d, 0, kind=k)[0], tgt[0]) + assert_array_equal(np.partition(d, 1, kind=k)[1], tgt[1]) + assert_array_equal(np.partition(d, 2, kind=k)[2], tgt[2]) + assert_array_equal(d[np.argpartition(d, 0, kind=k)], + np.partition(d, 0, kind=k)) + assert_array_equal(d[np.argpartition(d, 1, kind=k)], + np.partition(d, 1, kind=k)) + assert_array_equal(d[np.argpartition(d, 2, kind=k)], + np.partition(d, 2, kind=k)) + for i in range(d.size): + d[i:].partition(0, kind=k) + assert_array_equal(d, tgt) + + d = np.ones((50)) + assert_array_equal(np.partition(d, 0, kind=k), d) + assert_array_equal(d[np.argpartition(d, 0, kind=k)], + np.partition(d, 0, kind=k)) + + # sorted + d = np.arange((49)) + self.assertEqual(np.partition(d, 5, kind=k)[5], 5) + self.assertEqual(np.partition(d, 15, kind=k)[15], 15) + assert_array_equal(d[np.argpartition(d, 5, kind=k)], + np.partition(d, 5, kind=k)) + assert_array_equal(d[np.argpartition(d, 15, kind=k)], + np.partition(d, 15, kind=k)) + + # rsorted + d = np.arange((47))[::-1] + self.assertEqual(np.partition(d, 6, kind=k)[6], 6) + self.assertEqual(np.partition(d, 16, kind=k)[16], 16) + assert_array_equal(d[np.argpartition(d, 6, kind=k)], + np.partition(d, 6, kind=k)) + 
assert_array_equal(d[np.argpartition(d, 16, kind=k)], + np.partition(d, 16, kind=k)) + + assert_array_equal(np.partition(d, -6, kind=k), + np.partition(d, 41, kind=k)) + assert_array_equal(np.partition(d, -16, kind=k), + np.partition(d, 31, kind=k)) + assert_array_equal(d[np.argpartition(d, -6, kind=k)], + np.partition(d, 41, kind=k)) + + # equal elements + d = np.arange((47)) % 7 + tgt = np.sort(np.arange((47)) % 7) + np.random.shuffle(d) + for i in range(d.size): + self.assertEqual(np.partition(d, i, kind=k)[i], tgt[i]) + assert_array_equal(d[np.argpartition(d, 6, kind=k)], + np.partition(d, 6, kind=k)) + assert_array_equal(d[np.argpartition(d, 16, kind=k)], + np.partition(d, 16, kind=k)) + for i in range(d.size): + d[i:].partition(0, kind=k) + assert_array_equal(d, tgt) + + d = np.array([2, 1]) + d.partition(0, kind=k) + assert_raises(ValueError, d.partition, 2) + assert_raises(ValueError, d.partition, 3, axis=1) + assert_raises(ValueError, np.partition, d, 2) + assert_raises(ValueError, np.partition, d, 2, axis=1) + assert_raises(ValueError, d.argpartition, 2) + assert_raises(ValueError, d.argpartition, 3, axis=1) + assert_raises(ValueError, np.argpartition, d, 2) + assert_raises(ValueError, np.argpartition, d, 2, axis=1) + d = np.arange(10).reshape((2, 5)) + d.partition(1, axis=0, kind=k) + d.partition(4, axis=1, kind=k) + np.partition(d, 1, axis=0, kind=k) + np.partition(d, 4, axis=1, kind=k) + np.partition(d, 1, axis=None, kind=k) + np.partition(d, 9, axis=None, kind=k) + d.argpartition(1, axis=0, kind=k) + d.argpartition(4, axis=1, kind=k) + np.argpartition(d, 1, axis=0, kind=k) + np.argpartition(d, 4, axis=1, kind=k) + np.argpartition(d, 1, axis=None, kind=k) + np.argpartition(d, 9, axis=None, kind=k) + assert_raises(ValueError, d.partition, 2, axis=0) + assert_raises(ValueError, d.partition, 11, axis=1) + assert_raises(TypeError, d.partition, 2, axis=None) + assert_raises(ValueError, np.partition, d, 9, axis=1) + assert_raises(ValueError, np.partition, d, 
11, axis=None) + assert_raises(ValueError, d.argpartition, 2, axis=0) + assert_raises(ValueError, d.argpartition, 11, axis=1) + assert_raises(ValueError, np.argpartition, d, 9, axis=1) + assert_raises(ValueError, np.argpartition, d, 11, axis=None) + + td = [(dt, s) for dt in [np.int32, np.float32, np.complex64] + for s in (9, 16)] + for dt, s in td: + aae = assert_array_equal + at = self.assertTrue + + d = np.arange(s, dtype=dt) + np.random.shuffle(d) + d1 = np.tile(np.arange(s, dtype=dt), (4, 1)) + map(np.random.shuffle, d1) + d0 = np.transpose(d1) + for i in range(d.size): + p = np.partition(d, i, kind=k) + self.assertEqual(p[i], i) + # all before are smaller + assert_array_less(p[:i], p[i]) + # all after are larger + assert_array_less(p[i], p[i + 1:]) + aae(p, d[np.argpartition(d, i, kind=k)]) + + p = np.partition(d1, i, axis=1, kind=k) + aae(p[:,i], np.array([i] * d1.shape[0], dtype=dt)) + # array_less does not seem to work right + at((p[:, :i].T <= p[:, i]).all(), + msg="%d: %r <= %r" % (i, p[:, i], p[:, :i].T)) + at((p[:, i + 1:].T > p[:, i]).all(), + msg="%d: %r < %r" % (i, p[:, i], p[:, i + 1:].T)) + aae(p, d1[np.arange(d1.shape[0])[:,None], + np.argpartition(d1, i, axis=1, kind=k)]) + + p = np.partition(d0, i, axis=0, kind=k) + aae(p[i, :], np.array([i] * d1.shape[0], + dtype=dt)) + # array_less does not seem to work right + at((p[:i, :] <= p[i, :]).all(), + msg="%d: %r <= %r" % (i, p[i, :], p[:i, :])) + at((p[i + 1:, :] > p[i, :]).all(), + msg="%d: %r < %r" % (i, p[i, :], p[:, i + 1:])) + aae(p, d0[np.argpartition(d0, i, axis=0, kind=k), + np.arange(d0.shape[1])[None,:]]) + + # check inplace + dc = d.copy() + dc.partition(i, kind=k) + assert_equal(dc, np.partition(d, i, kind=k)) + dc = d0.copy() + dc.partition(i, axis=0, kind=k) + assert_equal(dc, np.partition(d0, i, axis=0, kind=k)) + dc = d1.copy() + dc.partition(i, axis=1, kind=k) + assert_equal(dc, np.partition(d1, i, axis=1, kind=k)) + + + def assert_partitioned(self, d, kth): + prev = 0 + for k in 
np.sort(kth): + assert_array_less(d[prev:k], d[k], err_msg='kth %d' % k) + assert_((d[k:] >= d[k]).all(), + msg="kth %d, %r not greater equal %d" % (k, d[k:], d[k])) + prev = k + 1 + + + def test_partition_iterative(self): + d = np.arange(17) + kth = (0, 1, 2, 429, 231) + assert_raises(ValueError, d.partition, kth) + assert_raises(ValueError, d.argpartition, kth) + d = np.arange(10).reshape((2, 5)) + assert_raises(ValueError, d.partition, kth, axis=0) + assert_raises(ValueError, d.partition, kth, axis=1) + assert_raises(ValueError, np.partition, d, kth, axis=1) + assert_raises(ValueError, np.partition, d, kth, axis=None) + + d = np.array([3, 4, 2, 1]) + p = np.partition(d, (0, 3)) + self.assert_partitioned(p, (0, 3)) + self.assert_partitioned(d[np.argpartition(d, (0, 3))], (0, 3)) + + assert_array_equal(p, np.partition(d, (-3, -1))) + assert_array_equal(p, d[np.argpartition(d, (-3, -1))]) + + d = np.arange(17) + np.random.shuffle(d) + d.partition(range(d.size)) + assert_array_equal(np.arange(17), d) + np.random.shuffle(d) + assert_array_equal(np.arange(17), d[d.argpartition(range(d.size))]) + + # test unsorted kth + d = np.arange(17) + np.random.shuffle(d) + keys = np.array([1, 3, 8, -2]) + np.random.shuffle(d) + p = np.partition(d, keys) + self.assert_partitioned(p, keys) + p = d[np.argpartition(d, keys)] + self.assert_partitioned(p, keys) + np.random.shuffle(keys) + assert_array_equal(np.partition(d, keys), p) + assert_array_equal(d[np.argpartition(d, keys)], p) + + # equal kth + d = np.arange(20)[::-1] + self.assert_partitioned(np.partition(d, [5]*4), [5]) + self.assert_partitioned(np.partition(d, [5]*4 + [6, 13]), + [5]*4 + [6, 13]) + self.assert_partitioned(d[np.argpartition(d, [5]*4)], [5]) + self.assert_partitioned(d[np.argpartition(d, [5]*4 + [6, 13])], + [5]*4 + [6, 13]) + + d = np.arange(12) + np.random.shuffle(d) + d1 = np.tile(np.arange(12), (4, 1)) + map(np.random.shuffle, d1) + d0 = np.transpose(d1) + + kth = (1, 6, 7, -1) + p = np.partition(d1, kth, 
axis=1) + pa = d1[np.arange(d1.shape[0])[:,None], + d1.argpartition(kth, axis=1)] + assert_array_equal(p, pa) + for i in range(d1.shape[0]): + self.assert_partitioned(p[i, :], kth) + p = np.partition(d0, kth, axis=0) + pa = d0[np.argpartition(d0, kth, axis=0), + np.arange(d0.shape[1])[None,:]] + assert_array_equal(p, pa) + for i in range(d0.shape[1]): + self.assert_partitioned(p[:, i], kth) + + + def test_partition_cdtype(self): + d = array([('Galahad', 1.7, 38), ('Arthur', 1.8, 41), + ('Lancelot', 1.9, 38)], + dtype=[('name', '|S10'), ('height', '<f8'), ('age', '<i4')]) + + tgt = np.sort(d, order=['age', 'height']) + assert_array_equal(np.partition(d, range(d.size), + order=['age', 'height']), + tgt) + assert_array_equal(d[np.argpartition(d, range(d.size), + order=['age', 'height'])], + tgt) + for k in range(d.size): + assert_equal(np.partition(d, k, order=['age', 'height'])[k], + tgt[k]) + assert_equal(d[np.argpartition(d, k, order=['age', 'height'])][k], + tgt[k]) + + d = array(['Galahad', 'Arthur', 'zebra', 'Lancelot']) + tgt = np.sort(d) + assert_array_equal(np.partition(d, range(d.size)), tgt) + for k in range(d.size): + assert_equal(np.partition(d, k)[k], tgt[k]) + assert_equal(d[np.argpartition(d, k)][k], tgt[k]) + + def test_flatten(self): x0 = np.array([[1,2,3],[4,5,6]], np.int32) x1 = np.array([[[1,2],[3,4]],[[5,6],[7,8]]], np.int32) |