summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2013-10-13 12:34:43 +0200
committerJulian Taylor <jtaylor.debian@googlemail.com>2014-03-13 19:06:20 +0100
commiteea1a9c49024c18fda3ad9782dee3956492cfa1a (patch)
tree559e94b82e8c4ea05eef34f6421d56ab5d46ad97 /numpy/lib
parent9c4602f98096abed5632d5fd1f12549a2c5b360c (diff)
downloadnumpy-eea1a9c49024c18fda3ad9782dee3956492cfa1a.tar.gz
ENH: add extended axis and keepdims support to median and percentile
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py165
-rw-r--r--numpy/lib/tests/test_function_base.py135
2 files changed, 231 insertions, 69 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 7b2d077c3..5742da2bc 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -13,6 +13,7 @@ __all__ = [
import warnings
import sys
import collections
+import operator
import numpy as np
import numpy.core.numeric as _nx
@@ -2694,7 +2695,7 @@ def msort(a):
return b
-def median(a, axis=None, out=None, overwrite_input=False):
+def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
"""
Compute the median along the specified axis.
@@ -2704,9 +2705,10 @@ def median(a, axis=None, out=None, overwrite_input=False):
----------
a : array_like
Input array or object that can be converted to an array.
- axis : int, optional
+ axis : int or sequence of int, optional
Axis along which the medians are computed. The default (axis=None)
is to compute the median along a flattened version of the array.
+ A sequence of axes is supported since version 1.9.0.
out : ndarray, optional
Alternative output array in which to place the result. It must have
the same shape and buffer length as the expected output, but the
@@ -2719,6 +2721,13 @@ def median(a, axis=None, out=None, overwrite_input=False):
will probably be fully or partially sorted. Default is False. Note
that, if `overwrite_input` is True and the input is not already an
ndarray, an error will be raised.
+ keepdims : bool, optional
+ If this is set to True, the axes which are reduced are left
+ in the result as dimensions with size one. With this option,
+ the result will broadcast correctly against the original `arr`.
+
+ .. versionadded:: 1.9.0
+
Returns
-------
@@ -2769,55 +2778,79 @@ def median(a, axis=None, out=None, overwrite_input=False):
"""
a = np.asanyarray(a)
- if axis is not None and axis >= a.ndim:
- raise IndexError(
- "axis %d out of bounds (%d)" % (axis, a.ndim))
-
- if overwrite_input:
- if axis is None:
- part = a.ravel()
- sz = part.size
- if sz % 2 == 0:
- szh = sz // 2
- part.partition((szh - 1, szh))
- else:
- part.partition((sz - 1) // 2)
- else:
- sz = a.shape[axis]
- if sz % 2 == 0:
- szh = sz // 2
- a.partition((szh - 1, szh), axis=axis)
- else:
- a.partition((sz - 1) // 2, axis=axis)
- part = a
+ if a.size % 2 == 0:
+ return percentile(a, q=50., axis=axis, out=out,
+ overwrite_input=overwrite_input,
+ interpolation="linear", keepdims=keepdims)
else:
- if axis is None:
- sz = a.size
- else:
- sz = a.shape[axis]
- if sz % 2 == 0:
- part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis)
- else:
- part = partition(a, (sz - 1) // 2, axis=axis)
- if part.shape == ():
- # make 0-D arrays work
- return part.item()
- if axis is None:
- axis = 0
- indexer = [slice(None)] * part.ndim
- index = part.shape[axis] // 2
- if part.shape[axis] % 2 == 1:
- # index with slice to allow mean (below) to work
- indexer[axis] = slice(index, index+1)
+ # avoid slower weighting path, relevant for small arrays
+ return percentile(a, q=50., axis=axis, out=out,
+ overwrite_input=overwrite_input,
+ interpolation="nearest", keepdims=keepdims)
+
+
+def _ureduce(a, func, **kwargs):
+ """
+ Internal Function.
+ Call `func` with `a` as first argument swapping the axes to use extended
+ axis on functions that don't support it natively.
+
+ Returns result and a.shape with axis dims set to 1.
+
+ Parameters
+ ----------
+ a : array_like
+ Input array or object that can be converted to an array.
+ func : callable
+ Reduction function Kapable of receiving an axis argument.
+ It is is called with `a` as first argument followed by `kwargs`.
+ kwargs : keyword arguments
+ additional keyword arguments to pass to `func`.
+
+ Returns
+ -------
+ result : tuple
+ Result of func(a, **kwargs) and a.shape with axis dims set to 1
+ suiteable to use to archive the same result as the ufunc keepdims
+ argument.
+
+ """
+ a = np.asanyarray(a)
+ axis = kwargs.get('axis', None)
+ if axis is not None:
+ keepdim = list(a.shape)
+ nd = a.ndim
+ try:
+ axis = operator.index(axis)
+ if axis >= nd or axis < -nd:
+ raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
+ keepdim[axis] = 1
+ except TypeError:
+ sax = set()
+ for x in axis:
+ if x >= nd or x < -nd:
+ raise IndexError("axis %d out of bounds (%d)" % (x, nd))
+ if x in sax:
+ raise ValueError("duplicate value in axis")
+ sax.add(x % nd)
+ keepdim[x] = 1
+ keep = sax.symmetric_difference(frozenset(range(nd)))
+ nkeep = len(keep)
+ # swap axis that should not be reduced to front
+ for i, s in enumerate(sorted(keep)):
+ a = a.swapaxes(i, s)
+ # merge reduced axis
+ a = a.reshape(a.shape[:nkeep] + (-1,))
+ kwargs['axis'] = -1
else:
- indexer[axis] = slice(index-1, index+1)
- # Use mean in odd and even case to coerce data type
- # and check, use out array.
- return mean(part[indexer], axis=axis, out=out)
+ keepdim = [1] * a.ndim
+
+ r = func(a, **kwargs)
+ return r, keepdim
def percentile(a, q, axis=None, out=None,
- overwrite_input=False, interpolation='linear'):
+ overwrite_input=False, interpolation='linear', keepdims=False):
"""
Compute the qth percentile of the data along the specified axis.
@@ -2829,9 +2862,10 @@ def percentile(a, q, axis=None, out=None,
Input array or object that can be converted to an array.
q : float in range of [0,100] (or sequence of floats)
Percentile to compute which must be between 0 and 100 inclusive.
- axis : int, optional
+ axis : int or sequence of int, optional
Axis along which the percentiles are computed. The default (None)
is to compute the percentiles along a flattened version of the array.
+ A sequence of axes is supported since version 1.9.0.
out : ndarray, optional
Alternative output array in which to place the result. It must
have the same shape and buffer length as the expected output,
@@ -2857,6 +2891,12 @@ def percentile(a, q, axis=None, out=None,
* midpoint: (`i` + `j`) / 2.
.. versionadded:: 1.9.0
+ keepdims : bool, optional
+ If this is set to True, the axes which are reduced are left
+ in the result as dimensions with size one. With this option,
+ the result will broadcast correctly against the original `arr`.
+
+ .. versionadded:: 1.9.0
Returns
-------
@@ -2913,8 +2953,22 @@ def percentile(a, q, axis=None, out=None,
array([ 3.5])
"""
+ q = asarray(q, dtype=np.float64)
+ r, k = _ureduce(a, func=_percentile, q=q, axis=axis, out=out,
+ overwrite_input=overwrite_input,
+ interpolation=interpolation)
+ if keepdims:
+ if q.ndim == 0:
+ return r.reshape(k)
+ else:
+ return r.reshape([len(q)] + k)
+ else:
+ return r
+
+
+def _percentile(a, q, axis=None, out=None,
+ overwrite_input=False, interpolation='linear', keepdims=False):
a = asarray(a)
- q = asarray(q)
if q.ndim == 0:
# Do not allow 0-d arrays because following code fails for scalar
zerod = True
@@ -2922,10 +2976,17 @@ def percentile(a, q, axis=None, out=None,
else:
zerod = False
- q = q / 100.0
- if (q < 0).any() or (q > 1).any():
- raise ValueError(
- "Percentiles must be in the range [0,100]")
+ # avoid expensive reductions, relevant for arrays with < O(1000) elements
+ if q.size < 10:
+ for i in range(q.size):
+ if q[i] < 0. or q[i] > 100.:
+ raise ValueError("Percentiles must be in the range [0,100]")
+ q[i] /= 100.
+ else:
+ # faster than any()
+ if np.count_nonzero(q < 0.) or np.count_nonzero(q > 100.):
+ raise ValueError("Percentiles must be in the range [0,100]")
+ q /= 100.
# prepare a for partioning
if overwrite_input:
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 6c11b0385..27e7302ce 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1688,6 +1688,8 @@ class TestScoreatpercentile(TestCase):
interpolation='foobar')
assert_raises(ValueError, np.percentile, [1], 101)
assert_raises(ValueError, np.percentile, [1], -1)
+ assert_raises(ValueError, np.percentile, [1], list(range(50)) + [101])
+ assert_raises(ValueError, np.percentile, [1], list(range(50)) + [-0.1])
def test_percentile_list(self):
assert_equal(np.percentile([1, 2, 3], 0), 1)
@@ -1779,6 +1781,65 @@ class TestScoreatpercentile(TestCase):
b = np.percentile([2, 3, 4, 1], [50], overwrite_input=True)
assert_equal(b, np.array([2.5]))
+ def test_extended_axis(self):
+ o = np.random.normal(size=(71, 23))
+ x = np.dstack([o] * 10)
+ assert_equal(np.percentile(x, 30, axis=(0, 1)), np.percentile(o, 30))
+ x = np.rollaxis(x, -1, 0)
+ assert_equal(np.percentile(x, 30, axis=(-2, -1)), np.percentile(o, 30))
+ x = x.swapaxes(0, 1).copy()
+ assert_equal(np.percentile(x, 30, axis=(0, -1)), np.percentile(o, 30))
+ x = x.swapaxes(0, 1).copy()
+
+ assert_equal(np.percentile(x, [25, 60], axis=(0, 1, 2)),
+ np.percentile(x, [25, 60], axis=None))
+ assert_equal(np.percentile(x, [25, 60], axis=(0,)),
+ np.percentile(x, [25, 60], axis=0))
+
+ d = np.arange(3 * 5 * 7 * 11).reshape(3, 5, 7, 11)
+ np.random.shuffle(d)
+ assert_equal(np.percentile(d, 25, axis=(0, 1, 2))[0],
+ np.percentile(d[:, :, :, 0].flatten(), 25))
+ assert_equal(np.percentile(d, [10, 90], axis=(0, 1, 3))[:, 1],
+ np.percentile(d[:, :, 1, :].flatten(), [10, 90]))
+ assert_equal(np.percentile(d, 25, axis=(3, 1, -4))[2],
+ np.percentile(d[:, :, 2, :].flatten(), 25))
+ assert_equal(np.percentile(d, 25, axis=(3, 1, 2))[2],
+ np.percentile(d[2, :, :, :].flatten(), 25))
+ assert_equal(np.percentile(d, 25, axis=(3, 2))[2, 1],
+ np.percentile(d[2, 1, :, :].flatten(), 25))
+ assert_equal(np.percentile(d, 25, axis=(1, -2))[2, 1],
+ np.percentile(d[2, :, :, 1].flatten(), 25))
+ assert_equal(np.percentile(d, 25, axis=(1, 3))[2, 2],
+ np.percentile(d[2, :, 2, :].flatten(), 25))
+
+ def test_extended_axis_invalid(self):
+ d = np.ones((3, 5, 7, 11))
+ assert_raises(IndexError, np.percentile, d, axis=-5, q=25)
+ assert_raises(IndexError, np.percentile, d, axis=(0, -5), q=25)
+ assert_raises(IndexError, np.percentile, d, axis=4, q=25)
+ assert_raises(IndexError, np.percentile, d, axis=(0, 4), q=25)
+ assert_raises(ValueError, np.percentile, d, axis=(1, 1), q=25)
+
+ def test_keepdims(self):
+ d = np.ones((3, 5, 7, 11))
+ assert_equal(np.percentile(d, 7, axis=None, keepdims=True).shape,
+ (1, 1, 1, 1))
+ assert_equal(np.percentile(d, 7, axis=(0, 1), keepdims=True).shape,
+ (1, 1, 7, 11))
+ assert_equal(np.percentile(d, 7, axis=(0, 3), keepdims=True).shape,
+ (1, 5, 7, 1))
+ assert_equal(np.percentile(d, 7, axis=(1,), keepdims=True).shape,
+ (3, 1, 7, 11))
+ assert_equal(np.percentile(d, 7, (0, 1, 2, 3), keepdims=True).shape,
+ (1, 1, 1, 1))
+ assert_equal(np.percentile(d, 7, axis=(0, 1, 3), keepdims=True).shape,
+ (1, 1, 7, 1))
+
+ assert_equal(np.percentile(d, [1, 7], axis=(0, 1, 3),
+ keepdims=True).shape, (2, 1, 1, 7, 1))
+ assert_equal(np.percentile(d, [1, 7], axis=(0, 3),
+ keepdims=True).shape, (2, 1, 5, 7, 1))
class TestMedian(TestCase):
@@ -1786,19 +1847,19 @@ class TestMedian(TestCase):
a0 = np.array(1)
a1 = np.arange(2)
a2 = np.arange(6).reshape(2, 3)
- assert_allclose(np.median(a0), 1)
+ assert_equal(np.median(a0), 1)
assert_allclose(np.median(a1), 0.5)
assert_allclose(np.median(a2), 2.5)
assert_allclose(np.median(a2, axis=0), [1.5, 2.5, 3.5])
- assert_allclose(np.median(a2, axis=1), [1, 4])
+ assert_equal(np.median(a2, axis=1), [1, 4])
assert_allclose(np.median(a2, axis=None), 2.5)
a = np.array([0.0444502, 0.0463301, 0.141249, 0.0606775])
assert_almost_equal((a[1] + a[3]) / 2., np.median(a))
a = np.array([0.0463301, 0.0444502, 0.141249])
- assert_almost_equal(a[0], np.median(a))
+ assert_equal(a[0], np.median(a))
a = np.array([0.0444502, 0.141249, 0.0463301])
- assert_almost_equal(a[-1], np.median(a))
+ assert_equal(a[-1], np.median(a))
def test_axis_keyword(self):
a3 = np.array([[2, 3],
@@ -1858,19 +1919,59 @@ class TestMedian(TestCase):
assert_almost_equal(np.median(x2), 2)
assert_allclose(np.median(x2, axis=0), x)
- def test_subclass(self):
- # gh-3846
- class MySubClass(np.ndarray):
- def __new__(cls, input_array, info=None):
- obj = np.asarray(input_array).view(cls)
- obj.info = info
- return obj
-
- def mean(self, axis=None, dtype=None, out=None):
- return -7
-
- a = MySubClass([1,2,3])
- assert_equal(np.median(a), -7)
+ def test_extended_axis(self):
+ o = np.random.normal(size=(71, 23))
+ x = np.dstack([o] * 10)
+ assert_equal(np.median(x, axis=(0, 1)), np.median(o))
+ x = np.rollaxis(x, -1, 0)
+ assert_equal(np.median(x, axis=(-2, -1)), np.median(o))
+ x = x.swapaxes(0, 1).copy()
+ assert_equal(np.median(x, axis=(0, -1)), np.median(o))
+
+ assert_equal(np.median(x, axis=(0, 1, 2)), np.median(x, axis=None))
+ assert_equal(np.median(x, axis=(0, )), np.median(x, axis=0))
+ assert_equal(np.median(x, axis=(-1, )), np.median(x, axis=-1))
+
+ d = np.arange(3 * 5 * 7 * 11).reshape(3, 5, 7, 11)
+ np.random.shuffle(d)
+ assert_equal(np.median(d, axis=(0, 1, 2))[0],
+ np.median(d[:, :, :, 0].flatten()))
+ assert_equal(np.median(d, axis=(0, 1, 3))[1],
+ np.median(d[:, :, 1, :].flatten()))
+ assert_equal(np.median(d, axis=(3, 1, -4))[2],
+ np.median(d[:, :, 2, :].flatten()))
+ assert_equal(np.median(d, axis=(3, 1, 2))[2],
+ np.median(d[2, :, :, :].flatten()))
+ assert_equal(np.median(d, axis=(3, 2))[2, 1],
+ np.median(d[2, 1, :, :].flatten()))
+ assert_equal(np.median(d, axis=(1, -2))[2, 1],
+ np.median(d[2, :, :, 1].flatten()))
+ assert_equal(np.median(d, axis=(1, 3))[2, 2],
+ np.median(d[2, :, 2, :].flatten()))
+
+ def test_extended_axis_invalid(self):
+ d = np.ones((3, 5, 7, 11))
+ assert_raises(IndexError, np.median, d, axis=-5)
+ assert_raises(IndexError, np.median, d, axis=(0, -5))
+ assert_raises(IndexError, np.median, d, axis=4)
+ assert_raises(IndexError, np.median, d, axis=(0, 4))
+ assert_raises(ValueError, np.median, d, axis=(1, 1))
+
+ def test_keepdims(self):
+ d = np.ones((3, 5, 7, 11))
+ assert_equal(np.median(d, axis=None, keepdims=True).shape,
+ (1, 1, 1, 1))
+ assert_equal(np.median(d, axis=(0, 1), keepdims=True).shape,
+ (1, 1, 7, 11))
+ assert_equal(np.median(d, axis=(0, 3), keepdims=True).shape,
+ (1, 5, 7, 1))
+ assert_equal(np.median(d, axis=(1,), keepdims=True).shape,
+ (3, 1, 7, 11))
+ assert_equal(np.median(d, axis=(0, 1, 2, 3), keepdims=True).shape,
+ (1, 1, 1, 1))
+ assert_equal(np.median(d, axis=(0, 1, 3), keepdims=True).shape,
+ (1, 1, 7, 1))
+
class TestAdd_newdoc_ufunc(TestCase):