summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py178
1 files changed, 112 insertions, 66 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 5742da2bc..edce15776 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -2695,6 +2695,66 @@ def msort(a):
return b
+def _ureduce(a, func, **kwargs):
+ """
+ Internal Function.
+ Call `func` with `a` as first argument swapping the axes to use extended
+ axis on functions that don't support it natively.
+
+ Returns result and a.shape with axis dims set to 1.
+
+ Parameters
+ ----------
+ a : array_like
+ Input array or object that can be converted to an array.
+ func : callable
+ Reduction function Kapable of receiving an axis argument.
+ It is is called with `a` as first argument followed by `kwargs`.
+ kwargs : keyword arguments
+ additional keyword arguments to pass to `func`.
+
+ Returns
+ -------
+ result : tuple
+ Result of func(a, **kwargs) and a.shape with axis dims set to 1
+ which can be used to reshape the result to the same shape a ufunc with
+ keepdims=True would produce.
+
+ """
+ a = np.asanyarray(a)
+ axis = kwargs.get('axis', None)
+ if axis is not None:
+ keepdim = list(a.shape)
+ nd = a.ndim
+ try:
+ axis = operator.index(axis)
+ if axis >= nd or axis < -nd:
+ raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
+ keepdim[axis] = 1
+ except TypeError:
+ sax = set()
+ for x in axis:
+ if x >= nd or x < -nd:
+ raise IndexError("axis %d out of bounds (%d)" % (x, nd))
+ if x in sax:
+ raise ValueError("duplicate value in axis")
+ sax.add(x % nd)
+ keepdim[x] = 1
+ keep = sax.symmetric_difference(frozenset(range(nd)))
+ nkeep = len(keep)
+ # swap axis that should not be reduced to front
+ for i, s in enumerate(sorted(keep)):
+ a = a.swapaxes(i, s)
+ # merge reduced axis
+ a = a.reshape(a.shape[:nkeep] + (-1,))
+ kwargs['axis'] = -1
+ else:
+ keepdim = [1] * a.ndim
+
+ r = func(a, **kwargs)
+ return r, keepdim
+
+
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
"""
Compute the median along the specified axis.
@@ -2777,76 +2837,62 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
>>> assert not np.all(a==b)
"""
- a = np.asanyarray(a)
- if a.size % 2 == 0:
- return percentile(a, q=50., axis=axis, out=out,
- overwrite_input=overwrite_input,
- interpolation="linear", keepdims=keepdims)
+ r, k = _ureduce(a, func=_median, axis=axis, out=out,
+ overwrite_input=overwrite_input)
+ if keepdims:
+ return r.reshape(k)
else:
- # avoid slower weighting path, relevant for small arrays
- return percentile(a, q=50., axis=axis, out=out,
- overwrite_input=overwrite_input,
- interpolation="nearest", keepdims=keepdims)
-
-
-def _ureduce(a, func, **kwargs):
- """
- Internal Function.
- Call `func` with `a` as first argument swapping the axes to use extended
- axis on functions that don't support it natively.
-
- Returns result and a.shape with axis dims set to 1.
-
- Parameters
- ----------
- a : array_like
- Input array or object that can be converted to an array.
- func : callable
- Reduction function Kapable of receiving an axis argument.
- It is is called with `a` as first argument followed by `kwargs`.
- kwargs : keyword arguments
- additional keyword arguments to pass to `func`.
-
- Returns
- -------
- result : tuple
- Result of func(a, **kwargs) and a.shape with axis dims set to 1
- suiteable to use to archive the same result as the ufunc keepdims
- argument.
+ return r
- """
+def _median(a, axis=None, out=None, overwrite_input=False):
+ # can't be reasonably be implemented in terms of percentile as we have to
+ # call mean to not break astropy
a = np.asanyarray(a)
- axis = kwargs.get('axis', None)
- if axis is not None:
- keepdim = list(a.shape)
- nd = a.ndim
- try:
- axis = operator.index(axis)
- if axis >= nd or axis < -nd:
- raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
- keepdim[axis] = 1
- except TypeError:
- sax = set()
- for x in axis:
- if x >= nd or x < -nd:
- raise IndexError("axis %d out of bounds (%d)" % (x, nd))
- if x in sax:
- raise ValueError("duplicate value in axis")
- sax.add(x % nd)
- keepdim[x] = 1
- keep = sax.symmetric_difference(frozenset(range(nd)))
- nkeep = len(keep)
- # swap axis that should not be reduced to front
- for i, s in enumerate(sorted(keep)):
- a = a.swapaxes(i, s)
- # merge reduced axis
- a = a.reshape(a.shape[:nkeep] + (-1,))
- kwargs['axis'] = -1
- else:
- keepdim = [1] * a.ndim
+ if axis is not None and axis >= a.ndim:
+ raise IndexError(
+ "axis %d out of bounds (%d)" % (axis, a.ndim))
- r = func(a, **kwargs)
- return r, keepdim
+ if overwrite_input:
+ if axis is None:
+ part = a.ravel()
+ sz = part.size
+ if sz % 2 == 0:
+ szh = sz // 2
+ part.partition((szh - 1, szh))
+ else:
+ part.partition((sz - 1) // 2)
+ else:
+ sz = a.shape[axis]
+ if sz % 2 == 0:
+ szh = sz // 2
+ a.partition((szh - 1, szh), axis=axis)
+ else:
+ a.partition((sz - 1) // 2, axis=axis)
+ part = a
+ else:
+ if axis is None:
+ sz = a.size
+ else:
+ sz = a.shape[axis]
+ if sz % 2 == 0:
+ part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis)
+ else:
+ part = partition(a, (sz - 1) // 2, axis=axis)
+ if part.shape == ():
+ # make 0-D arrays work
+ return part.item()
+ if axis is None:
+ axis = 0
+ indexer = [slice(None)] * part.ndim
+ index = part.shape[axis] // 2
+ if part.shape[axis] % 2 == 1:
+ # index with slice to allow mean (below) to work
+ indexer[axis] = slice(index, index+1)
+ else:
+ indexer[axis] = slice(index-1, index+1)
+ # Use mean in odd and even case to coerce data type
+ # and check, use out array.
+ return mean(part[indexer], axis=axis, out=out)
def percentile(a, q, axis=None, out=None,