diff options
author | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-02-18 00:23:56 +0100 |
---|---|---|
committer | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-02-22 11:09:07 +0100 |
commit | a676f4eb6c0935686254af62d4e32e88361727a4 (patch) | |
tree | 178852a6ae1a0658575b4d88160e29363bb3d000 | |
parent | dd857489fcaa745ea569c523249fe37537bfe280 (diff) | |
download | numpy-a676f4eb6c0935686254af62d4e32e88361727a4.tar.gz |
ENH: add max fastpath to partition for nan detection
Allows low overhead check for NaN via isnan(partition(d, (x, -1))[-1])
-rw-r--r-- | numpy/core/src/npysort/selection.c.src | 15 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 10 |
2 files changed, 25 insertions, 0 deletions
diff --git a/numpy/core/src/npysort/selection.c.src b/numpy/core/src/npysort/selection.c.src index 073b5847f..920c07ec6 100644 --- a/numpy/core/src/npysort/selection.c.src +++ b/numpy/core/src/npysort/selection.c.src @@ -69,6 +69,7 @@ static NPY_INLINE void store_pivot(npy_intp pivot, npy_intp kth, * npy_uint, npy_long, npy_ulong, npy_longlong, npy_ulonglong, * npy_ushort, npy_float, npy_double, npy_longdouble, npy_cfloat, * npy_cdouble, npy_clongdouble# + * #inexact = 0*11, 1*7# */ static npy_intp @@ -322,6 +323,20 @@ int store_pivot(kth, kth, pivots, npiv); return 0; } + else if (@inexact@ && kth == num - 1) { + /* useful to check if NaN present via partition(d, (x, -1)) */ + npy_intp k; + npy_intp maxidx = low; + @type@ maxval = v[IDX(low)]; + for (k = low + 1; k < num; k++) { + if (!@TYPE@_LT(v[IDX(k)], maxval)) { + maxidx = k; + maxval = v[IDX(k)]; + } + } + SWAP(SORTEE(kth), SORTEE(maxidx)); + return 0; + } /* dumb integer msb, float npy_log2 too slow for small parititions */ { diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index a73b3350d..655ac3748 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1299,6 +1299,16 @@ class TestMethods(TestCase): mid = x.size // 2 + 1 assert_equal(np.partition(x, mid)[mid], mid) + # max + d = np.ones(10); d[1] = 4; + assert_equal(np.partition(d, (2, -1))[-1], 4) + assert_equal(np.partition(d, (2, -1))[2], 1) + assert_equal(d[np.argpartition(d, (2, -1))][-1], 4) + assert_equal(d[np.argpartition(d, (2, -1))][2], 1) + d[1] = np.nan + assert_(np.isnan(d[np.argpartition(d, (2, -1))][-1])) + assert_(np.isnan(np.partition(d, (2, -1))[-1])) + # equal elements d = np.arange((47)) % 7 tgt = np.sort(np.arange((47)) % 7) |