summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py44
1 files changed, 32 insertions, 12 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 17f99a065..b3b5ef735 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -22,7 +22,7 @@ from numpy.core.numeric import ScalarType, dot, where, newaxis, intp, \
integer, isscalar
from numpy.core.umath import pi, multiply, add, arctan2, \
frompyfunc, isnan, cos, less_equal, sqrt, sin, mod, exp, log10
-from numpy.core.fromnumeric import ravel, nonzero, choose, sort, mean
+from numpy.core.fromnumeric import ravel, nonzero, choose, sort, partition, mean
from numpy.core.numerictypes import typecodes, number
from numpy.core import atleast_1d, atleast_2d
from numpy.lib.twodim_base import diag
@@ -2996,30 +2996,50 @@ def median(a, axis=None, out=None, overwrite_input=False):
>>> assert not np.all(a==b)
"""
+ if axis is not None and axis >= a.ndim:
+ raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
+
if overwrite_input:
if axis is None:
- sorted = a.ravel()
- sorted.sort()
+ part = a.ravel()
+ sz = part.size
+ if sz % 2 == 0:
+ szh = sz // 2
+ part.partition((szh - 1, szh))
+ else:
+ part.partition((sz - 1) // 2)
else:
- a.sort(axis=axis)
- sorted = a
+ sz = a.shape[axis]
+ if sz % 2 == 0:
+ szh = sz // 2
+ a.partition((szh - 1, szh), axis=axis)
+ else:
+ a.partition((sz - 1) // 2, axis=axis)
+ part = a
else:
- sorted = sort(a, axis=axis)
- if sorted.shape == ():
+ if axis is None:
+ sz = a.size
+ else:
+ sz = a.shape[axis]
+ if sz % 2 == 0:
+ part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis)
+ else:
+ part = partition(a, (sz - 1) // 2, axis=axis)
+ if part.shape == ():
# make 0-D arrays work
- return sorted.item()
+ return part.item()
if axis is None:
axis = 0
- indexer = [slice(None)] * sorted.ndim
- index = int(sorted.shape[axis]/2)
- if sorted.shape[axis] % 2 == 1:
+ indexer = [slice(None)] * part.ndim
+ index = part.shape[axis] // 2
+ if part.shape[axis] % 2 == 1:
# index with slice to allow mean (below) to work
indexer[axis] = slice(index, index+1)
else:
indexer[axis] = slice(index-1, index+1)
# Use mean in odd and even case to coerce data type
# and check, use out array.
- return mean(sorted[indexer], axis=axis, out=out)
+ return mean(part[indexer], axis=axis, out=out)
def percentile(a, q, axis=None, out=None, overwrite_input=False):
"""