diff options
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 178 |
1 files changed, 112 insertions, 66 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 5742da2bc..edce15776 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -2695,6 +2695,66 @@ def msort(a): return b +def _ureduce(a, func, **kwargs): + """ + Internal Function. + Call `func` with `a` as first argument swapping the axes to use extended + axis on functions that don't support it natively. + + Returns result and a.shape with axis dims set to 1. + + Parameters + ---------- + a : array_like + Input array or object that can be converted to an array. + func : callable + Reduction function Kapable of receiving an axis argument. + It is is called with `a` as first argument followed by `kwargs`. + kwargs : keyword arguments + additional keyword arguments to pass to `func`. + + Returns + ------- + result : tuple + Result of func(a, **kwargs) and a.shape with axis dims set to 1 + which can be used to reshape the result to the same shape a ufunc with + keepdims=True would produce. + + """ + a = np.asanyarray(a) + axis = kwargs.get('axis', None) + if axis is not None: + keepdim = list(a.shape) + nd = a.ndim + try: + axis = operator.index(axis) + if axis >= nd or axis < -nd: + raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim)) + keepdim[axis] = 1 + except TypeError: + sax = set() + for x in axis: + if x >= nd or x < -nd: + raise IndexError("axis %d out of bounds (%d)" % (x, nd)) + if x in sax: + raise ValueError("duplicate value in axis") + sax.add(x % nd) + keepdim[x] = 1 + keep = sax.symmetric_difference(frozenset(range(nd))) + nkeep = len(keep) + # swap axis that should not be reduced to front + for i, s in enumerate(sorted(keep)): + a = a.swapaxes(i, s) + # merge reduced axis + a = a.reshape(a.shape[:nkeep] + (-1,)) + kwargs['axis'] = -1 + else: + keepdim = [1] * a.ndim + + r = func(a, **kwargs) + return r, keepdim + + def median(a, axis=None, out=None, overwrite_input=False, keepdims=False): """ Compute the median along the specified axis. @@ -2777,76 +2837,62 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False): >>> assert not np.all(a==b) """ - a = np.asanyarray(a) - if a.size % 2 == 0: - return percentile(a, q=50., axis=axis, out=out, - overwrite_input=overwrite_input, - interpolation="linear", keepdims=keepdims) + r, k = _ureduce(a, func=_median, axis=axis, out=out, + overwrite_input=overwrite_input) + if keepdims: + return r.reshape(k) else: - # avoid slower weighting path, relevant for small arrays - return percentile(a, q=50., axis=axis, out=out, - overwrite_input=overwrite_input, - interpolation="nearest", keepdims=keepdims) - - -def _ureduce(a, func, **kwargs): - """ - Internal Function. - Call `func` with `a` as first argument swapping the axes to use extended - axis on functions that don't support it natively. - - Returns result and a.shape with axis dims set to 1. - - Parameters - ---------- - a : array_like - Input array or object that can be converted to an array. - func : callable - Reduction function Kapable of receiving an axis argument. - It is is called with `a` as first argument followed by `kwargs`. - kwargs : keyword arguments - additional keyword arguments to pass to `func`. - - Returns - ------- - result : tuple - Result of func(a, **kwargs) and a.shape with axis dims set to 1 - suiteable to use to archive the same result as the ufunc keepdims - argument. + return r - """ +def _median(a, axis=None, out=None, overwrite_input=False): + # can't be reasonably be implemented in terms of percentile as we have to + # call mean to not break astropy a = np.asanyarray(a) - axis = kwargs.get('axis', None) - if axis is not None: - keepdim = list(a.shape) - nd = a.ndim - try: - axis = operator.index(axis) - if axis >= nd or axis < -nd: - raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim)) - keepdim[axis] = 1 - except TypeError: - sax = set() - for x in axis: - if x >= nd or x < -nd: - raise IndexError("axis %d out of bounds (%d)" % (x, nd)) - if x in sax: - raise ValueError("duplicate value in axis") - sax.add(x % nd) - keepdim[x] = 1 - keep = sax.symmetric_difference(frozenset(range(nd))) - nkeep = len(keep) - # swap axis that should not be reduced to front - for i, s in enumerate(sorted(keep)): - a = a.swapaxes(i, s) - # merge reduced axis - a = a.reshape(a.shape[:nkeep] + (-1,)) - kwargs['axis'] = -1 - else: - keepdim = [1] * a.ndim + if axis is not None and axis >= a.ndim: + raise IndexError( + "axis %d out of bounds (%d)" % (axis, a.ndim)) - r = func(a, **kwargs) - return r, keepdim + if overwrite_input: + if axis is None: + part = a.ravel() + sz = part.size + if sz % 2 == 0: + szh = sz // 2 + part.partition((szh - 1, szh)) + else: + part.partition((sz - 1) // 2) + else: + sz = a.shape[axis] + if sz % 2 == 0: + szh = sz // 2 + a.partition((szh - 1, szh), axis=axis) + else: + a.partition((sz - 1) // 2, axis=axis) + part = a + else: + if axis is None: + sz = a.size + else: + sz = a.shape[axis] + if sz % 2 == 0: + part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis) + else: + part = partition(a, (sz - 1) // 2, axis=axis) + if part.shape == (): + # make 0-D arrays work + return part.item() + if axis is None: + axis = 0 + indexer = [slice(None)] * part.ndim + index = part.shape[axis] // 2 + if part.shape[axis] % 2 == 1: + # index with slice to allow mean (below) to work + indexer[axis] = slice(index, index+1) + else: + indexer[axis] = slice(index-1, index+1) + # Use mean in odd and even case to coerce data type + # and check, use out array. + return mean(part[indexer], axis=axis, out=out) def percentile(a, q, axis=None, out=None, |