summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2014-02-22 08:54:47 -0700
committerCharles Harris <charlesr.harris@gmail.com>2014-02-22 08:54:47 -0700
commit2882cc96639787335d135bf1fc10e1174c83831f (patch)
treeaff2fe56b4f48b991bd41150ae76b75a50d11d4b
parentbf5fd54026a54358dc1602446af8b406bb51d99c (diff)
parenta676f4eb6c0935686254af62d4e32e88361727a4 (diff)
downloadnumpy-2882cc96639787335d135bf1fc10e1174c83831f.tar.gz
Merge pull request #4347 from juliantaylor/partition-max
ENH: add max fastpath to partition for nan detection
-rw-r--r--numpy/core/src/npysort/selection.c.src15
-rw-r--r--numpy/core/tests/test_multiarray.py10
2 files changed, 25 insertions, 0 deletions
diff --git a/numpy/core/src/npysort/selection.c.src b/numpy/core/src/npysort/selection.c.src
index 073b5847f..920c07ec6 100644
--- a/numpy/core/src/npysort/selection.c.src
+++ b/numpy/core/src/npysort/selection.c.src
@@ -69,6 +69,7 @@ static NPY_INLINE void store_pivot(npy_intp pivot, npy_intp kth,
* npy_uint, npy_long, npy_ulong, npy_longlong, npy_ulonglong,
* npy_ushort, npy_float, npy_double, npy_longdouble, npy_cfloat,
* npy_cdouble, npy_clongdouble#
+ * #inexact = 0*11, 1*7#
*/
static npy_intp
@@ -322,6 +323,20 @@ int
store_pivot(kth, kth, pivots, npiv);
return 0;
}
+ else if (@inexact@ && kth == num - 1) {
+ /* useful to check if NaN present via partition(d, (x, -1)) */
+ npy_intp k;
+ npy_intp maxidx = low;
+ @type@ maxval = v[IDX(low)];
+ for (k = low + 1; k < num; k++) {
+ if (!@TYPE@_LT(v[IDX(k)], maxval)) {
+ maxidx = k;
+ maxval = v[IDX(k)];
+ }
+ }
+ SWAP(SORTEE(kth), SORTEE(maxidx));
+ return 0;
+ }
/* dumb integer msb, float npy_log2 too slow for small parititions */
{
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index a73b3350d..655ac3748 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1299,6 +1299,16 @@ class TestMethods(TestCase):
mid = x.size // 2 + 1
assert_equal(np.partition(x, mid)[mid], mid)
+ # max
+ d = np.ones(10); d[1] = 4;
+ assert_equal(np.partition(d, (2, -1))[-1], 4)
+ assert_equal(np.partition(d, (2, -1))[2], 1)
+ assert_equal(d[np.argpartition(d, (2, -1))][-1], 4)
+ assert_equal(d[np.argpartition(d, (2, -1))][2], 1)
+ d[1] = np.nan
+ assert_(np.isnan(d[np.argpartition(d, (2, -1))][-1]))
+ assert_(np.isnan(np.partition(d, (2, -1))[-1]))
+
# equal elements
d = np.arange((47)) % 7
tgt = np.sort(np.arange((47)) % 7)