summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py34
1 files changed, 29 insertions, 5 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index d611dd225..b8ae9a470 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -388,12 +388,14 @@ def iterable(y):
return True
-def _average_dispatcher(a, axis=None, weights=None, returned=None):
+def _average_dispatcher(a, axis=None, weights=None, returned=None, *,
+ keepdims=None):
return (a, weights)
@array_function_dispatch(_average_dispatcher)
-def average(a, axis=None, weights=None, returned=False):
+def average(a, axis=None, weights=None, returned=False, *,
+ keepdims=np._NoValue):
"""
Compute the weighted average along the specified axis.
@@ -428,6 +430,14 @@ def average(a, axis=None, weights=None, returned=False):
is returned, otherwise only the average is returned.
If `weights=None`, `sum_of_weights` is equivalent to the number of
elements over which the average is taken.
+ keepdims : bool, optional
+ If this is set to True, the axes which are reduced are left
+ in the result as dimensions with size one. With this option,
+ the result will broadcast correctly against the original `a`.
+ *Note:* `keepdims` will not work with instances of `numpy.matrix`
+ or other classes whose methods do not support `keepdims`.
+
+ .. versionadded:: 1.23.0
Returns
-------
@@ -471,7 +481,7 @@ def average(a, axis=None, weights=None, returned=False):
>>> np.average(np.arange(1, 11), weights=np.arange(10, 0, -1))
4.0
- >>> data = np.arange(6).reshape((3,2))
+ >>> data = np.arange(6).reshape((3, 2))
>>> data
array([[0, 1],
[2, 3],
@@ -488,11 +498,24 @@ def average(a, axis=None, weights=None, returned=False):
>>> avg = np.average(a, weights=w)
>>> print(avg.dtype)
complex256
+
+ With ``keepdims=True``, the following result has shape (3, 1).
+
+ >>> np.average(data, axis=1, keepdims=True)
+ array([[0.5],
+ [2.5],
+ [4.5]])
"""
a = np.asanyarray(a)
+ if keepdims is np._NoValue:
+ # Don't pass on the keepdims argument if one wasn't given.
+ keepdims_kw = {}
+ else:
+ keepdims_kw = {'keepdims': keepdims}
+
if weights is None:
- avg = a.mean(axis)
+ avg = a.mean(axis, **keepdims_kw)
scl = avg.dtype.type(a.size/avg.size)
else:
wgt = np.asanyarray(weights)
@@ -524,7 +547,8 @@ def average(a, axis=None, weights=None, returned=False):
raise ZeroDivisionError(
"Weights sum to zero, can't be normalized")
- avg = np.multiply(a, wgt, dtype=result_dtype).sum(axis)/scl
+ avg = np.multiply(a, wgt,
+ dtype=result_dtype).sum(axis, **keepdims_kw) / scl
if returned:
if scl.shape != avg.shape: