diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/function_base.py | 44 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 47 |
2 files changed, 79 insertions, 12 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 17f99a065..b3b5ef735 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -22,7 +22,7 @@ from numpy.core.numeric import ScalarType, dot, where, newaxis, intp, \ integer, isscalar from numpy.core.umath import pi, multiply, add, arctan2, \ frompyfunc, isnan, cos, less_equal, sqrt, sin, mod, exp, log10 -from numpy.core.fromnumeric import ravel, nonzero, choose, sort, mean +from numpy.core.fromnumeric import ravel, nonzero, choose, sort, partition, mean from numpy.core.numerictypes import typecodes, number from numpy.core import atleast_1d, atleast_2d from numpy.lib.twodim_base import diag @@ -2996,30 +2996,50 @@ def median(a, axis=None, out=None, overwrite_input=False): >>> assert not np.all(a==b) """ + if axis is not None and axis >= a.ndim: + raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim)) + if overwrite_input: if axis is None: - sorted = a.ravel() - sorted.sort() + part = a.ravel() + sz = part.size + if sz % 2 == 0: + szh = sz // 2 + part.partition((szh - 1, szh)) + else: + part.partition((sz - 1) // 2) else: - a.sort(axis=axis) - sorted = a + sz = a.shape[axis] + if sz % 2 == 0: + szh = sz // 2 + a.partition((szh - 1, szh), axis=axis) + else: + a.partition((sz - 1) // 2, axis=axis) + part = a else: - sorted = sort(a, axis=axis) - if sorted.shape == (): + if axis is None: + sz = a.size + else: + sz = a.shape[axis] + if sz % 2 == 0: + part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis) + else: + part = partition(a, (sz - 1) // 2, axis=axis) + if part.shape == (): # make 0-D arrays work - return sorted.item() + return part.item() if axis is None: axis = 0 - indexer = [slice(None)] * sorted.ndim - index = int(sorted.shape[axis]/2) - if sorted.shape[axis] % 2 == 1: + indexer = [slice(None)] * part.ndim + index = part.shape[axis] // 2 + if part.shape[axis] % 2 == 1: # index with slice to allow mean (below) to work indexer[axis] = slice(index, index+1) else: indexer[axis] = slice(index-1, index+1) # Use mean in odd and even case to coerce data type # and check, use out array. - return mean(sorted[indexer], axis=axis, out=out) + return mean(part[indexer], axis=axis, out=out) def percentile(a, q, axis=None, out=None, overwrite_input=False): """ diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index de561e55a..814743442 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -1537,6 +1537,53 @@ def test_median(): assert_allclose(np.median(a2), 2.5) assert_allclose(np.median(a2, axis=0), [1.5, 2.5, 3.5]) assert_allclose(np.median(a2, axis=1), [1, 4]) + assert_allclose(np.median(a2, axis=None), 2.5) + a3 = np.array([[2, 3], + [0, 1], + [6, 7], + [4, 5]]) + #check no overwrite + for a in [a3, np.random.randint(0, 100, size=(2, 3, 4))]: + orig = a.copy() + np.median(a, axis=None) + for ax in range(a.ndim): + np.median(a, axis=ax) + assert_array_equal(a, orig) + + assert_allclose(np.median(a3, axis=0), [3, 4]) + assert_allclose(np.median(a3.T, axis=1), [3, 4]) + assert_allclose(np.median(a3), 3.5) + assert_allclose(np.median(a3, axis=None), 3.5) + assert_allclose(np.median(a3.T), 3.5) + + a = np.array([0.0444502, 0.0463301, 0.141249, 0.0606775]) + assert_almost_equal((a[1] + a[3]) / 2., np.median(a)) + a = np.array([0.0463301, 0.0444502, 0.141249]) + assert_almost_equal(a[0], np.median(a)) + a = np.array([0.0444502, 0.141249, 0.0463301]) + assert_almost_equal(a[-1], np.median(a)) + + assert_allclose(np.median(a0.copy(), overwrite_input=True), 1) + assert_allclose(np.median(a1.copy(), overwrite_input=True), 0.5) + assert_allclose(np.median(a2.copy(), overwrite_input=True), 2.5) + assert_allclose(np.median(a2.copy(), overwrite_input=True, axis=0), + [1.5, 2.5, 3.5]) + assert_allclose(np.median(a2.copy(), overwrite_input=True, axis=1), [1, 4]) + assert_allclose(np.median(a2.copy(), overwrite_input=True, axis=None), 2.5) + assert_allclose(np.median(a3.copy(), overwrite_input=True, axis=0), [3, 4]) + assert_allclose(np.median(a3.T.copy(), overwrite_input=True, axis=1), + [3, 4]) + + a4 = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + map(np.random.shuffle, a4) + assert_allclose(np.median(a4, axis=None), + np.median(a4.copy(), axis=None, overwrite_input=True)) + assert_allclose(np.median(a4, axis=0), + np.median(a4.copy(), axis=0, overwrite_input=True)) + assert_allclose(np.median(a4, axis=1), + np.median(a4.copy(), axis=1, overwrite_input=True)) + assert_allclose(np.median(a4, axis=2), + np.median(a4.copy(), axis=2, overwrite_input=True)) class TestAdd_newdoc_ufunc(TestCase): |