summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2014-03-12 20:44:14 +0100
committerJulian Taylor <jtaylor.debian@googlemail.com>2014-03-13 21:10:31 +0100
commit7d53c812d3e68ff28320ee7e32bc9816937b4142 (patch)
tree8906807748186d9d5217a15656cc01059dacf3bc /numpy/lib/function_base.py
parenteea1a9c49024c18fda3ad9782dee3956492cfa1a (diff)
downloadnumpy-7d53c812d3e68ff28320ee7e32bc9816937b4142.tar.gz
MAINT: revert back to separate median implementation
Merging median and percentile would break astropy and quantities, as we would no longer call mean. These packages rely on overriding mean to add their own median behavior.
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py178
1 files changed, 112 insertions, 66 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 5742da2bc..edce15776 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -2695,6 +2695,66 @@ def msort(a):
return b
+def _ureduce(a, func, **kwargs):
+ """
+ Internal Function.
+ Call `func` with `a` as first argument swapping the axes to use extended
+ axis on functions that don't support it natively.
+
+ Returns result and a.shape with axis dims set to 1.
+
+ Parameters
+ ----------
+ a : array_like
+ Input array or object that can be converted to an array.
+ func : callable
+ Reduction function capable of receiving an axis argument.
+ It is called with `a` as first argument followed by `kwargs`.
+ kwargs : keyword arguments
+ additional keyword arguments to pass to `func`.
+
+ Returns
+ -------
+ result : tuple
+ Result of func(a, **kwargs) and a.shape with axis dims set to 1
+ which can be used to reshape the result to the same shape a ufunc with
+ keepdims=True would produce.
+
+ """
+ a = np.asanyarray(a)
+ axis = kwargs.get('axis', None)
+ if axis is not None:
+ keepdim = list(a.shape)
+ nd = a.ndim
+ try:
+ axis = operator.index(axis)
+ if axis >= nd or axis < -nd:
+ raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
+ keepdim[axis] = 1
+ except TypeError:
+ sax = set()
+ for x in axis:
+ if x >= nd or x < -nd:
+ raise IndexError("axis %d out of bounds (%d)" % (x, nd))
+ if x in sax:
+ raise ValueError("duplicate value in axis")
+ sax.add(x % nd)
+ keepdim[x] = 1
+ keep = sax.symmetric_difference(frozenset(range(nd)))
+ nkeep = len(keep)
+ # swap axis that should not be reduced to front
+ for i, s in enumerate(sorted(keep)):
+ a = a.swapaxes(i, s)
+ # merge reduced axis
+ a = a.reshape(a.shape[:nkeep] + (-1,))
+ kwargs['axis'] = -1
+ else:
+ keepdim = [1] * a.ndim
+
+ r = func(a, **kwargs)
+ return r, keepdim
+
+
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
"""
Compute the median along the specified axis.
@@ -2777,76 +2837,62 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
>>> assert not np.all(a==b)
"""
- a = np.asanyarray(a)
- if a.size % 2 == 0:
- return percentile(a, q=50., axis=axis, out=out,
- overwrite_input=overwrite_input,
- interpolation="linear", keepdims=keepdims)
+ r, k = _ureduce(a, func=_median, axis=axis, out=out,
+ overwrite_input=overwrite_input)
+ if keepdims:
+ return r.reshape(k)
else:
- # avoid slower weighting path, relevant for small arrays
- return percentile(a, q=50., axis=axis, out=out,
- overwrite_input=overwrite_input,
- interpolation="nearest", keepdims=keepdims)
-
-
-def _ureduce(a, func, **kwargs):
- """
- Internal Function.
- Call `func` with `a` as first argument swapping the axes to use extended
- axis on functions that don't support it natively.
-
- Returns result and a.shape with axis dims set to 1.
-
- Parameters
- ----------
- a : array_like
- Input array or object that can be converted to an array.
- func : callable
- Reduction function Kapable of receiving an axis argument.
- It is is called with `a` as first argument followed by `kwargs`.
- kwargs : keyword arguments
- additional keyword arguments to pass to `func`.
-
- Returns
- -------
- result : tuple
- Result of func(a, **kwargs) and a.shape with axis dims set to 1
- suiteable to use to archive the same result as the ufunc keepdims
- argument.
+ return r
- """
+def _median(a, axis=None, out=None, overwrite_input=False):
+ # can't reasonably be implemented in terms of percentile as we have to
+ # call mean to not break astropy
a = np.asanyarray(a)
- axis = kwargs.get('axis', None)
- if axis is not None:
- keepdim = list(a.shape)
- nd = a.ndim
- try:
- axis = operator.index(axis)
- if axis >= nd or axis < -nd:
- raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
- keepdim[axis] = 1
- except TypeError:
- sax = set()
- for x in axis:
- if x >= nd or x < -nd:
- raise IndexError("axis %d out of bounds (%d)" % (x, nd))
- if x in sax:
- raise ValueError("duplicate value in axis")
- sax.add(x % nd)
- keepdim[x] = 1
- keep = sax.symmetric_difference(frozenset(range(nd)))
- nkeep = len(keep)
- # swap axis that should not be reduced to front
- for i, s in enumerate(sorted(keep)):
- a = a.swapaxes(i, s)
- # merge reduced axis
- a = a.reshape(a.shape[:nkeep] + (-1,))
- kwargs['axis'] = -1
- else:
- keepdim = [1] * a.ndim
+ if axis is not None and axis >= a.ndim:
+ raise IndexError(
+ "axis %d out of bounds (%d)" % (axis, a.ndim))
- r = func(a, **kwargs)
- return r, keepdim
+ if overwrite_input:
+ if axis is None:
+ part = a.ravel()
+ sz = part.size
+ if sz % 2 == 0:
+ szh = sz // 2
+ part.partition((szh - 1, szh))
+ else:
+ part.partition((sz - 1) // 2)
+ else:
+ sz = a.shape[axis]
+ if sz % 2 == 0:
+ szh = sz // 2
+ a.partition((szh - 1, szh), axis=axis)
+ else:
+ a.partition((sz - 1) // 2, axis=axis)
+ part = a
+ else:
+ if axis is None:
+ sz = a.size
+ else:
+ sz = a.shape[axis]
+ if sz % 2 == 0:
+ part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis)
+ else:
+ part = partition(a, (sz - 1) // 2, axis=axis)
+ if part.shape == ():
+ # make 0-D arrays work
+ return part.item()
+ if axis is None:
+ axis = 0
+ indexer = [slice(None)] * part.ndim
+ index = part.shape[axis] // 2
+ if part.shape[axis] % 2 == 1:
+ # index with slice to allow mean (below) to work
+ indexer[axis] = slice(index, index+1)
+ else:
+ indexer[axis] = slice(index-1, index+1)
+ # Use mean in odd and even case to coerce data type
+ # and check, use out array.
+ return mean(part[indexer], axis=axis, out=out)
def percentile(a, q, axis=None, out=None,