summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2013-08-12 06:28:41 -0700
committerCharles Harris <charlesr.harris@gmail.com>2013-08-12 06:28:41 -0700
commitf8efcb6f191621e7c59f5385c85aeaa83be3669d (patch)
treed0029c507a53d4faedc3be104c8f8200eb21e157 /numpy/lib/function_base.py
parent028007e05fd62d37fc98daf91ae0aafc1ea95cb8 (diff)
parent4d9cd695486fa095c6bff3238341a85cbdb47d0e (diff)
downloadnumpy-f8efcb6f191621e7c59f5385c85aeaa83be3669d.tar.gz
Merge pull request #3360 from juliantaylor/selection-algo
add quickselect algorithm and expose it via partition
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py44
1 files changed, 32 insertions, 12 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 17f99a065..b3b5ef735 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -22,7 +22,7 @@ from numpy.core.numeric import ScalarType, dot, where, newaxis, intp, \
integer, isscalar
from numpy.core.umath import pi, multiply, add, arctan2, \
frompyfunc, isnan, cos, less_equal, sqrt, sin, mod, exp, log10
-from numpy.core.fromnumeric import ravel, nonzero, choose, sort, mean
+from numpy.core.fromnumeric import ravel, nonzero, choose, sort, partition, mean
from numpy.core.numerictypes import typecodes, number
from numpy.core import atleast_1d, atleast_2d
from numpy.lib.twodim_base import diag
@@ -2996,30 +2996,50 @@ def median(a, axis=None, out=None, overwrite_input=False):
>>> assert not np.all(a==b)
"""
+ if axis is not None and axis >= a.ndim:
+ raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
+
if overwrite_input:
if axis is None:
- sorted = a.ravel()
- sorted.sort()
+ part = a.ravel()
+ sz = part.size
+ if sz % 2 == 0:
+ szh = sz // 2
+ part.partition((szh - 1, szh))
+ else:
+ part.partition((sz - 1) // 2)
else:
- a.sort(axis=axis)
- sorted = a
+ sz = a.shape[axis]
+ if sz % 2 == 0:
+ szh = sz // 2
+ a.partition((szh - 1, szh), axis=axis)
+ else:
+ a.partition((sz - 1) // 2, axis=axis)
+ part = a
else:
- sorted = sort(a, axis=axis)
- if sorted.shape == ():
+ if axis is None:
+ sz = a.size
+ else:
+ sz = a.shape[axis]
+ if sz % 2 == 0:
+ part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis)
+ else:
+ part = partition(a, (sz - 1) // 2, axis=axis)
+ if part.shape == ():
# make 0-D arrays work
- return sorted.item()
+ return part.item()
if axis is None:
axis = 0
- indexer = [slice(None)] * sorted.ndim
- index = int(sorted.shape[axis]/2)
- if sorted.shape[axis] % 2 == 1:
+ indexer = [slice(None)] * part.ndim
+ index = part.shape[axis] // 2
+ if part.shape[axis] % 2 == 1:
# index with slice to allow mean (below) to work
indexer[axis] = slice(index, index+1)
else:
indexer[axis] = slice(index-1, index+1)
# Use mean in odd and even case to coerce data type
# and check, use out array.
- return mean(sorted[indexer], axis=axis, out=out)
+ return mean(part[indexer], axis=axis, out=out)
def percentile(a, q, axis=None, out=None, overwrite_input=False):
"""