summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py44
-rw-r--r--numpy/lib/tests/test_function_base.py47
2 files changed, 79 insertions, 12 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 17f99a065..b3b5ef735 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -22,7 +22,7 @@ from numpy.core.numeric import ScalarType, dot, where, newaxis, intp, \
integer, isscalar
from numpy.core.umath import pi, multiply, add, arctan2, \
frompyfunc, isnan, cos, less_equal, sqrt, sin, mod, exp, log10
-from numpy.core.fromnumeric import ravel, nonzero, choose, sort, mean
+from numpy.core.fromnumeric import ravel, nonzero, choose, sort, partition, mean
from numpy.core.numerictypes import typecodes, number
from numpy.core import atleast_1d, atleast_2d
from numpy.lib.twodim_base import diag
@@ -2996,30 +2996,50 @@ def median(a, axis=None, out=None, overwrite_input=False):
>>> assert not np.all(a==b)
"""
+ if axis is not None and axis >= a.ndim:
+ raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
+
if overwrite_input:
if axis is None:
- sorted = a.ravel()
- sorted.sort()
+ part = a.ravel()
+ sz = part.size
+ if sz % 2 == 0:
+ szh = sz // 2
+ part.partition((szh - 1, szh))
+ else:
+ part.partition((sz - 1) // 2)
else:
- a.sort(axis=axis)
- sorted = a
+ sz = a.shape[axis]
+ if sz % 2 == 0:
+ szh = sz // 2
+ a.partition((szh - 1, szh), axis=axis)
+ else:
+ a.partition((sz - 1) // 2, axis=axis)
+ part = a
else:
- sorted = sort(a, axis=axis)
- if sorted.shape == ():
+ if axis is None:
+ sz = a.size
+ else:
+ sz = a.shape[axis]
+ if sz % 2 == 0:
+ part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis)
+ else:
+ part = partition(a, (sz - 1) // 2, axis=axis)
+ if part.shape == ():
# make 0-D arrays work
- return sorted.item()
+ return part.item()
if axis is None:
axis = 0
- indexer = [slice(None)] * sorted.ndim
- index = int(sorted.shape[axis]/2)
- if sorted.shape[axis] % 2 == 1:
+ indexer = [slice(None)] * part.ndim
+ index = part.shape[axis] // 2
+ if part.shape[axis] % 2 == 1:
# index with slice to allow mean (below) to work
indexer[axis] = slice(index, index+1)
else:
indexer[axis] = slice(index-1, index+1)
# Use mean in odd and even case to coerce data type
# and check, use out array.
- return mean(sorted[indexer], axis=axis, out=out)
+ return mean(part[indexer], axis=axis, out=out)
def percentile(a, q, axis=None, out=None, overwrite_input=False):
"""
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index de561e55a..814743442 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1537,6 +1537,53 @@ def test_median():
assert_allclose(np.median(a2), 2.5)
assert_allclose(np.median(a2, axis=0), [1.5, 2.5, 3.5])
assert_allclose(np.median(a2, axis=1), [1, 4])
+ assert_allclose(np.median(a2, axis=None), 2.5)
+ a3 = np.array([[2, 3],
+ [0, 1],
+ [6, 7],
+ [4, 5]])
+ #check no overwrite
+ for a in [a3, np.random.randint(0, 100, size=(2, 3, 4))]:
+ orig = a.copy()
+ np.median(a, axis=None)
+ for ax in range(a.ndim):
+ np.median(a, axis=ax)
+ assert_array_equal(a, orig)
+
+ assert_allclose(np.median(a3, axis=0), [3, 4])
+ assert_allclose(np.median(a3.T, axis=1), [3, 4])
+ assert_allclose(np.median(a3), 3.5)
+ assert_allclose(np.median(a3, axis=None), 3.5)
+ assert_allclose(np.median(a3.T), 3.5)
+
+ a = np.array([0.0444502, 0.0463301, 0.141249, 0.0606775])
+ assert_almost_equal((a[1] + a[3]) / 2., np.median(a))
+ a = np.array([0.0463301, 0.0444502, 0.141249])
+ assert_almost_equal(a[0], np.median(a))
+ a = np.array([0.0444502, 0.141249, 0.0463301])
+ assert_almost_equal(a[-1], np.median(a))
+
+ assert_allclose(np.median(a0.copy(), overwrite_input=True), 1)
+ assert_allclose(np.median(a1.copy(), overwrite_input=True), 0.5)
+ assert_allclose(np.median(a2.copy(), overwrite_input=True), 2.5)
+ assert_allclose(np.median(a2.copy(), overwrite_input=True, axis=0),
+ [1.5, 2.5, 3.5])
+ assert_allclose(np.median(a2.copy(), overwrite_input=True, axis=1), [1, 4])
+ assert_allclose(np.median(a2.copy(), overwrite_input=True, axis=None), 2.5)
+ assert_allclose(np.median(a3.copy(), overwrite_input=True, axis=0), [3, 4])
+ assert_allclose(np.median(a3.T.copy(), overwrite_input=True, axis=1),
+ [3, 4])
+
+ a4 = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
+ map(np.random.shuffle, a4)
+ assert_allclose(np.median(a4, axis=None),
+ np.median(a4.copy(), axis=None, overwrite_input=True))
+ assert_allclose(np.median(a4, axis=0),
+ np.median(a4.copy(), axis=0, overwrite_input=True))
+ assert_allclose(np.median(a4, axis=1),
+ np.median(a4.copy(), axis=1, overwrite_input=True))
+ assert_allclose(np.median(a4, axis=2),
+ np.median(a4.copy(), axis=2, overwrite_input=True))
class TestAdd_newdoc_ufunc(TestCase):