diff options
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 44 |
1 files changed, 32 insertions, 12 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 17f99a065..b3b5ef735 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -22,7 +22,7 @@ from numpy.core.numeric import ScalarType, dot, where, newaxis, intp, \ integer, isscalar from numpy.core.umath import pi, multiply, add, arctan2, \ frompyfunc, isnan, cos, less_equal, sqrt, sin, mod, exp, log10 -from numpy.core.fromnumeric import ravel, nonzero, choose, sort, mean +from numpy.core.fromnumeric import ravel, nonzero, choose, sort, partition, mean from numpy.core.numerictypes import typecodes, number from numpy.core import atleast_1d, atleast_2d from numpy.lib.twodim_base import diag @@ -2996,30 +2996,50 @@ def median(a, axis=None, out=None, overwrite_input=False): >>> assert not np.all(a==b) """ + if axis is not None and axis >= a.ndim: + raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim)) + if overwrite_input: if axis is None: - sorted = a.ravel() - sorted.sort() + part = a.ravel() + sz = part.size + if sz % 2 == 0: + szh = sz // 2 + part.partition((szh - 1, szh)) + else: + part.partition((sz - 1) // 2) else: - a.sort(axis=axis) - sorted = a + sz = a.shape[axis] + if sz % 2 == 0: + szh = sz // 2 + a.partition((szh - 1, szh), axis=axis) + else: + a.partition((sz - 1) // 2, axis=axis) + part = a else: - sorted = sort(a, axis=axis) - if sorted.shape == (): + if axis is None: + sz = a.size + else: + sz = a.shape[axis] + if sz % 2 == 0: + part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis) + else: + part = partition(a, (sz - 1) // 2, axis=axis) + if part.shape == (): # make 0-D arrays work - return sorted.item() + return part.item() if axis is None: axis = 0 - indexer = [slice(None)] * sorted.ndim - index = int(sorted.shape[axis]/2) - if sorted.shape[axis] % 2 == 1: + indexer = [slice(None)] * part.ndim + index = part.shape[axis] // 2 + if part.shape[axis] % 2 == 1: # index with slice to allow mean (below) to work indexer[axis] = slice(index, index+1) else: indexer[axis] = slice(index-1, index+1) # Use mean in odd and even case to coerce data type # and check, use out array. - return mean(sorted[indexer], axis=axis, out=out) + return mean(part[indexer], axis=axis, out=out) def percentile(a, q, axis=None, out=None, overwrite_input=False): """ |