summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2013-10-13 12:34:43 +0200
committerJulian Taylor <jtaylor.debian@googlemail.com>2014-03-13 19:06:20 +0100
commiteea1a9c49024c18fda3ad9782dee3956492cfa1a (patch)
tree559e94b82e8c4ea05eef34f6421d56ab5d46ad97 /numpy/lib/function_base.py
parent9c4602f98096abed5632d5fd1f12549a2c5b360c (diff)
downloadnumpy-eea1a9c49024c18fda3ad9782dee3956492cfa1a.tar.gz
ENH: add extended axis and keepdims support to median and percentile
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py165
1 files changed, 113 insertions, 52 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 7b2d077c3..5742da2bc 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -13,6 +13,7 @@ __all__ = [
import warnings
import sys
import collections
+import operator
import numpy as np
import numpy.core.numeric as _nx
@@ -2694,7 +2695,7 @@ def msort(a):
return b
-def median(a, axis=None, out=None, overwrite_input=False):
+def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
"""
Compute the median along the specified axis.
@@ -2704,9 +2705,10 @@ def median(a, axis=None, out=None, overwrite_input=False):
----------
a : array_like
Input array or object that can be converted to an array.
- axis : int, optional
+ axis : int or sequence of int, optional
Axis along which the medians are computed. The default (axis=None)
is to compute the median along a flattened version of the array.
+ A sequence of axes is supported since version 1.9.0.
out : ndarray, optional
Alternative output array in which to place the result. It must have
the same shape and buffer length as the expected output, but the
@@ -2719,6 +2721,13 @@ def median(a, axis=None, out=None, overwrite_input=False):
will probably be fully or partially sorted. Default is False. Note
that, if `overwrite_input` is True and the input is not already an
ndarray, an error will be raised.
+ keepdims : bool, optional
+ If this is set to True, the axes which are reduced are left
+ in the result as dimensions with size one. With this option,
+ the result will broadcast correctly against the original `arr`.
+
+ .. versionadded:: 1.9.0
+
Returns
-------
@@ -2769,55 +2778,79 @@ def median(a, axis=None, out=None, overwrite_input=False):
"""
a = np.asanyarray(a)
- if axis is not None and axis >= a.ndim:
- raise IndexError(
- "axis %d out of bounds (%d)" % (axis, a.ndim))
-
- if overwrite_input:
- if axis is None:
- part = a.ravel()
- sz = part.size
- if sz % 2 == 0:
- szh = sz // 2
- part.partition((szh - 1, szh))
- else:
- part.partition((sz - 1) // 2)
- else:
- sz = a.shape[axis]
- if sz % 2 == 0:
- szh = sz // 2
- a.partition((szh - 1, szh), axis=axis)
- else:
- a.partition((sz - 1) // 2, axis=axis)
- part = a
+ if a.size % 2 == 0:
+ return percentile(a, q=50., axis=axis, out=out,
+ overwrite_input=overwrite_input,
+ interpolation="linear", keepdims=keepdims)
else:
- if axis is None:
- sz = a.size
- else:
- sz = a.shape[axis]
- if sz % 2 == 0:
- part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis)
- else:
- part = partition(a, (sz - 1) // 2, axis=axis)
- if part.shape == ():
- # make 0-D arrays work
- return part.item()
- if axis is None:
- axis = 0
- indexer = [slice(None)] * part.ndim
- index = part.shape[axis] // 2
- if part.shape[axis] % 2 == 1:
- # index with slice to allow mean (below) to work
- indexer[axis] = slice(index, index+1)
+ # avoid slower weighting path, relevant for small arrays
+ return percentile(a, q=50., axis=axis, out=out,
+ overwrite_input=overwrite_input,
+ interpolation="nearest", keepdims=keepdims)
+
+
+def _ureduce(a, func, **kwargs):
+ """
+ Internal Function.
+ Call `func` with `a` as first argument swapping the axes to use extended
+ axis on functions that don't support it natively.
+
+ Returns result and a.shape with axis dims set to 1.
+
+ Parameters
+ ----------
+ a : array_like
+ Input array or object that can be converted to an array.
+ func : callable
+ Reduction function Kapable of receiving an axis argument.
+ It is is called with `a` as first argument followed by `kwargs`.
+ kwargs : keyword arguments
+ additional keyword arguments to pass to `func`.
+
+ Returns
+ -------
+ result : tuple
+ Result of func(a, **kwargs) and a.shape with axis dims set to 1
+ suiteable to use to archive the same result as the ufunc keepdims
+ argument.
+
+ """
+ a = np.asanyarray(a)
+ axis = kwargs.get('axis', None)
+ if axis is not None:
+ keepdim = list(a.shape)
+ nd = a.ndim
+ try:
+ axis = operator.index(axis)
+ if axis >= nd or axis < -nd:
+ raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
+ keepdim[axis] = 1
+ except TypeError:
+ sax = set()
+ for x in axis:
+ if x >= nd or x < -nd:
+ raise IndexError("axis %d out of bounds (%d)" % (x, nd))
+ if x in sax:
+ raise ValueError("duplicate value in axis")
+ sax.add(x % nd)
+ keepdim[x] = 1
+ keep = sax.symmetric_difference(frozenset(range(nd)))
+ nkeep = len(keep)
+ # swap axis that should not be reduced to front
+ for i, s in enumerate(sorted(keep)):
+ a = a.swapaxes(i, s)
+ # merge reduced axis
+ a = a.reshape(a.shape[:nkeep] + (-1,))
+ kwargs['axis'] = -1
else:
- indexer[axis] = slice(index-1, index+1)
- # Use mean in odd and even case to coerce data type
- # and check, use out array.
- return mean(part[indexer], axis=axis, out=out)
+ keepdim = [1] * a.ndim
+
+ r = func(a, **kwargs)
+ return r, keepdim
def percentile(a, q, axis=None, out=None,
- overwrite_input=False, interpolation='linear'):
+ overwrite_input=False, interpolation='linear', keepdims=False):
"""
Compute the qth percentile of the data along the specified axis.
@@ -2829,9 +2862,10 @@ def percentile(a, q, axis=None, out=None,
Input array or object that can be converted to an array.
q : float in range of [0,100] (or sequence of floats)
Percentile to compute which must be between 0 and 100 inclusive.
- axis : int, optional
+ axis : int or sequence of int, optional
Axis along which the percentiles are computed. The default (None)
is to compute the percentiles along a flattened version of the array.
+ A sequence of axes is supported since version 1.9.0.
out : ndarray, optional
Alternative output array in which to place the result. It must
have the same shape and buffer length as the expected output,
@@ -2857,6 +2891,12 @@ def percentile(a, q, axis=None, out=None,
* midpoint: (`i` + `j`) / 2.
.. versionadded:: 1.9.0
+ keepdims : bool, optional
+ If this is set to True, the axes which are reduced are left
+ in the result as dimensions with size one. With this option,
+ the result will broadcast correctly against the original `arr`.
+
+ .. versionadded:: 1.9.0
Returns
-------
@@ -2913,8 +2953,22 @@ def percentile(a, q, axis=None, out=None,
array([ 3.5])
"""
+ q = asarray(q, dtype=np.float64)
+ r, k = _ureduce(a, func=_percentile, q=q, axis=axis, out=out,
+ overwrite_input=overwrite_input,
+ interpolation=interpolation)
+ if keepdims:
+ if q.ndim == 0:
+ return r.reshape(k)
+ else:
+ return r.reshape([len(q)] + k)
+ else:
+ return r
+
+
+def _percentile(a, q, axis=None, out=None,
+ overwrite_input=False, interpolation='linear', keepdims=False):
a = asarray(a)
- q = asarray(q)
if q.ndim == 0:
# Do not allow 0-d arrays because following code fails for scalar
zerod = True
@@ -2922,10 +2976,17 @@ def percentile(a, q, axis=None, out=None,
else:
zerod = False
- q = q / 100.0
- if (q < 0).any() or (q > 1).any():
- raise ValueError(
- "Percentiles must be in the range [0,100]")
+ # avoid expensive reductions, relevant for arrays with < O(1000) elements
+ if q.size < 10:
+ for i in range(q.size):
+ if q[i] < 0. or q[i] > 100.:
+ raise ValueError("Percentiles must be in the range [0,100]")
+ q[i] /= 100.
+ else:
+ # faster than any()
+ if np.count_nonzero(q < 0.) or np.count_nonzero(q > 100.):
+ raise ValueError("Percentiles must be in the range [0,100]")
+ q /= 100.
# prepare a for partioning
if overwrite_input: