diff options
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 34 |
1 files changed, 29 insertions, 5 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index d611dd225..b8ae9a470 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -388,12 +388,14 @@ def iterable(y): return True -def _average_dispatcher(a, axis=None, weights=None, returned=None): +def _average_dispatcher(a, axis=None, weights=None, returned=None, *, + keepdims=None): return (a, weights) @array_function_dispatch(_average_dispatcher) -def average(a, axis=None, weights=None, returned=False): +def average(a, axis=None, weights=None, returned=False, *, + keepdims=np._NoValue): """ Compute the weighted average along the specified axis. @@ -428,6 +430,14 @@ def average(a, axis=None, weights=None, returned=False): is returned, otherwise only the average is returned. If `weights=None`, `sum_of_weights` is equivalent to the number of elements over which the average is taken. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `a`. + *Note:* `keepdims` will not work with instances of `numpy.matrix` + or other classes whose methods do not support `keepdims`. + + .. versionadded:: 1.23.0 Returns ------- @@ -471,7 +481,7 @@ def average(a, axis=None, weights=None, returned=False): >>> np.average(np.arange(1, 11), weights=np.arange(10, 0, -1)) 4.0 - >>> data = np.arange(6).reshape((3,2)) + >>> data = np.arange(6).reshape((3, 2)) >>> data array([[0, 1], [2, 3], @@ -488,11 +498,24 @@ def average(a, axis=None, weights=None, returned=False): >>> avg = np.average(a, weights=w) >>> print(avg.dtype) complex256 + + With ``keepdims=True``, the following result has shape (3, 1). + + >>> np.average(data, axis=1, keepdims=True) + array([[0.5], + [2.5], + [4.5]]) """ a = np.asanyarray(a) + if keepdims is np._NoValue: + # Don't pass on the keepdims argument if one wasn't given. + keepdims_kw = {} + else: + keepdims_kw = {'keepdims': keepdims} + if weights is None: - avg = a.mean(axis) + avg = a.mean(axis, **keepdims_kw) scl = avg.dtype.type(a.size/avg.size) else: wgt = np.asanyarray(weights) @@ -524,7 +547,8 @@ def average(a, axis=None, weights=None, returned=False): raise ZeroDivisionError( "Weights sum to zero, can't be normalized") - avg = np.multiply(a, wgt, dtype=result_dtype).sum(axis)/scl + avg = np.multiply(a, wgt, + dtype=result_dtype).sum(axis, **keepdims_kw) / scl if returned: if scl.shape != avg.shape: |