diff options
author | Julian Taylor <jtaylor.debian@googlemail.com> | 2013-10-13 12:34:43 +0200 |
---|---|---|
committer | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-03-13 19:06:20 +0100 |
commit | eea1a9c49024c18fda3ad9782dee3956492cfa1a (patch) | |
tree | 559e94b82e8c4ea05eef34f6421d56ab5d46ad97 /numpy/lib/function_base.py | |
parent | 9c4602f98096abed5632d5fd1f12549a2c5b360c (diff) | |
download | numpy-eea1a9c49024c18fda3ad9782dee3956492cfa1a.tar.gz |
ENH: add extended axis and keepdims support to median and percentile
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 165 |
1 files changed, 113 insertions, 52 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 7b2d077c3..5742da2bc 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -13,6 +13,7 @@ __all__ = [ import warnings import sys import collections +import operator import numpy as np import numpy.core.numeric as _nx @@ -2694,7 +2695,7 @@ def msort(a): return b -def median(a, axis=None, out=None, overwrite_input=False): +def median(a, axis=None, out=None, overwrite_input=False, keepdims=False): """ Compute the median along the specified axis. @@ -2704,9 +2705,10 @@ def median(a, axis=None, out=None, overwrite_input=False): ---------- a : array_like Input array or object that can be converted to an array. - axis : int, optional + axis : int or sequence of int, optional Axis along which the medians are computed. The default (axis=None) is to compute the median along a flattened version of the array. + A sequence of axes is supported since version 1.9.0. out : ndarray, optional Alternative output array in which to place the result. It must have the same shape and buffer length as the expected output, but the @@ -2719,6 +2721,13 @@ def median(a, axis=None, out=None, overwrite_input=False): will probably be fully or partially sorted. Default is False. Note that, if `overwrite_input` is True and the input is not already an ndarray, an error will be raised. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `arr`. + + .. versionadded:: 1.9.0 + Returns ------- @@ -2769,55 +2778,79 @@ def median(a, axis=None, out=None, overwrite_input=False): """ a = np.asanyarray(a) - if axis is not None and axis >= a.ndim: - raise IndexError( - "axis %d out of bounds (%d)" % (axis, a.ndim)) - - if overwrite_input: - if axis is None: - part = a.ravel() - sz = part.size - if sz % 2 == 0: - szh = sz // 2 - part.partition((szh - 1, szh)) - else: - part.partition((sz - 1) // 2) - else: - sz = a.shape[axis] - if sz % 2 == 0: - szh = sz // 2 - a.partition((szh - 1, szh), axis=axis) - else: - a.partition((sz - 1) // 2, axis=axis) - part = a + if a.size % 2 == 0: + return percentile(a, q=50., axis=axis, out=out, + overwrite_input=overwrite_input, + interpolation="linear", keepdims=keepdims) else: - if axis is None: - sz = a.size - else: - sz = a.shape[axis] - if sz % 2 == 0: - part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis) - else: - part = partition(a, (sz - 1) // 2, axis=axis) - if part.shape == (): - # make 0-D arrays work - return part.item() - if axis is None: - axis = 0 - indexer = [slice(None)] * part.ndim - index = part.shape[axis] // 2 - if part.shape[axis] % 2 == 1: - # index with slice to allow mean (below) to work - indexer[axis] = slice(index, index+1) + # avoid slower weighting path, relevant for small arrays + return percentile(a, q=50., axis=axis, out=out, + overwrite_input=overwrite_input, + interpolation="nearest", keepdims=keepdims) + + +def _ureduce(a, func, **kwargs): + """ + Internal Function. + Call `func` with `a` as first argument swapping the axes to use extended + axis on functions that don't support it natively. + + Returns result and a.shape with axis dims set to 1. + + Parameters + ---------- + a : array_like + Input array or object that can be converted to an array. + func : callable + Reduction function Kapable of receiving an axis argument. + It is is called with `a` as first argument followed by `kwargs`. + kwargs : keyword arguments + additional keyword arguments to pass to `func`. + + Returns + ------- + result : tuple + Result of func(a, **kwargs) and a.shape with axis dims set to 1 + suiteable to use to archive the same result as the ufunc keepdims + argument. + + """ + a = np.asanyarray(a) + axis = kwargs.get('axis', None) + if axis is not None: + keepdim = list(a.shape) + nd = a.ndim + try: + axis = operator.index(axis) + if axis >= nd or axis < -nd: + raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim)) + keepdim[axis] = 1 + except TypeError: + sax = set() + for x in axis: + if x >= nd or x < -nd: + raise IndexError("axis %d out of bounds (%d)" % (x, nd)) + if x in sax: + raise ValueError("duplicate value in axis") + sax.add(x % nd) + keepdim[x] = 1 + keep = sax.symmetric_difference(frozenset(range(nd))) + nkeep = len(keep) + # swap axis that should not be reduced to front + for i, s in enumerate(sorted(keep)): + a = a.swapaxes(i, s) + # merge reduced axis + a = a.reshape(a.shape[:nkeep] + (-1,)) + kwargs['axis'] = -1 else: - indexer[axis] = slice(index-1, index+1) - # Use mean in odd and even case to coerce data type - # and check, use out array. - return mean(part[indexer], axis=axis, out=out) + keepdim = [1] * a.ndim + + r = func(a, **kwargs) + return r, keepdim def percentile(a, q, axis=None, out=None, - overwrite_input=False, interpolation='linear'): + overwrite_input=False, interpolation='linear', keepdims=False): """ Compute the qth percentile of the data along the specified axis. @@ -2829,9 +2862,10 @@ def percentile(a, q, axis=None, out=None, Input array or object that can be converted to an array. q : float in range of [0,100] (or sequence of floats) Percentile to compute which must be between 0 and 100 inclusive. - axis : int, optional + axis : int or sequence of int, optional Axis along which the percentiles are computed. The default (None) is to compute the percentiles along a flattened version of the array. + A sequence of axes is supported since version 1.9.0. out : ndarray, optional Alternative output array in which to place the result. It must have the same shape and buffer length as the expected output, @@ -2857,6 +2891,12 @@ def percentile(a, q, axis=None, out=None, * midpoint: (`i` + `j`) / 2. .. versionadded:: 1.9.0 + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `arr`. + + .. versionadded:: 1.9.0 Returns ------- @@ -2913,8 +2953,22 @@ def percentile(a, q, axis=None, out=None, array([ 3.5]) """ + q = asarray(q, dtype=np.float64) + r, k = _ureduce(a, func=_percentile, q=q, axis=axis, out=out, + overwrite_input=overwrite_input, + interpolation=interpolation) + if keepdims: + if q.ndim == 0: + return r.reshape(k) + else: + return r.reshape([len(q)] + k) + else: + return r + + +def _percentile(a, q, axis=None, out=None, + overwrite_input=False, interpolation='linear', keepdims=False): a = asarray(a) - q = asarray(q) if q.ndim == 0: # Do not allow 0-d arrays because following code fails for scalar zerod = True @@ -2922,10 +2976,17 @@ def percentile(a, q, axis=None, out=None, else: zerod = False - q = q / 100.0 - if (q < 0).any() or (q > 1).any(): - raise ValueError( - "Percentiles must be in the range [0,100]") + # avoid expensive reductions, relevant for arrays with < O(1000) elements + if q.size < 10: + for i in range(q.size): + if q[i] < 0. or q[i] > 100.: + raise ValueError("Percentiles must be in the range [0,100]") + q[i] /= 100. + else: + # faster than any() + if np.count_nonzero(q < 0.) or np.count_nonzero(q > 100.): + raise ValueError("Percentiles must be in the range [0,100]") + q /= 100. # prepare a for partioning if overwrite_input: |