diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2013-08-12 06:28:41 -0700 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2013-08-12 06:28:41 -0700 |
commit | f8efcb6f191621e7c59f5385c85aeaa83be3669d (patch) | |
tree | d0029c507a53d4faedc3be104c8f8200eb21e157 /numpy/lib/function_base.py | |
parent | 028007e05fd62d37fc98daf91ae0aafc1ea95cb8 (diff) | |
parent | 4d9cd695486fa095c6bff3238341a85cbdb47d0e (diff) | |
download | numpy-f8efcb6f191621e7c59f5385c85aeaa83be3669d.tar.gz |
Merge pull request #3360 from juliantaylor/selection-algo
add quickselect algorithm and expose it via partition
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 44 |
1 files changed, 32 insertions, 12 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 17f99a065..b3b5ef735 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -22,7 +22,7 @@ from numpy.core.numeric import ScalarType, dot, where, newaxis, intp, \ integer, isscalar from numpy.core.umath import pi, multiply, add, arctan2, \ frompyfunc, isnan, cos, less_equal, sqrt, sin, mod, exp, log10 -from numpy.core.fromnumeric import ravel, nonzero, choose, sort, mean +from numpy.core.fromnumeric import ravel, nonzero, choose, sort, partition, mean from numpy.core.numerictypes import typecodes, number from numpy.core import atleast_1d, atleast_2d from numpy.lib.twodim_base import diag @@ -2996,30 +2996,50 @@ def median(a, axis=None, out=None, overwrite_input=False): >>> assert not np.all(a==b) """ + if axis is not None and axis >= a.ndim: + raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim)) + if overwrite_input: if axis is None: - sorted = a.ravel() - sorted.sort() + part = a.ravel() + sz = part.size + if sz % 2 == 0: + szh = sz // 2 + part.partition((szh - 1, szh)) + else: + part.partition((sz - 1) // 2) else: - a.sort(axis=axis) - sorted = a + sz = a.shape[axis] + if sz % 2 == 0: + szh = sz // 2 + a.partition((szh - 1, szh), axis=axis) + else: + a.partition((sz - 1) // 2, axis=axis) + part = a else: - sorted = sort(a, axis=axis) - if sorted.shape == (): + if axis is None: + sz = a.size + else: + sz = a.shape[axis] + if sz % 2 == 0: + part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis) + else: + part = partition(a, (sz - 1) // 2, axis=axis) + if part.shape == (): # make 0-D arrays work - return sorted.item() + return part.item() if axis is None: axis = 0 - indexer = [slice(None)] * sorted.ndim - index = int(sorted.shape[axis]/2) - if sorted.shape[axis] % 2 == 1: + indexer = [slice(None)] * part.ndim + index = part.shape[axis] // 2 + if part.shape[axis] % 2 == 1: # index with slice to allow mean (below) to work indexer[axis] = slice(index, index+1) else: indexer[axis] = slice(index-1, index+1) # Use mean in odd and even case to coerce data type # and check, use out array. - return mean(sorted[indexer], axis=axis, out=out) + return mean(part[indexer], axis=axis, out=out) def percentile(a, q, axis=None, out=None, overwrite_input=False): """ |