summaryrefslogtreecommitdiff
path: root/numpy/core/fromnumeric.py
diff options
context:
space:
mode:
authorslepton <slepton@posteo.de>2021-07-07 21:25:53 +0200
committerslepton <slepton@posteo.de>2021-07-07 21:25:53 +0200
commitd19584547a7934bfce9dee445e65a2c1994cecdc (patch)
tree2013779442bff7d9f47464fd3a1b6ba755b0070d /numpy/core/fromnumeric.py
parent5ac75798362cea6ecbd602a46b80c297b0f6712a (diff)
parentde245cd133699f8c23f97ec07ec29703e37a5923 (diff)
downloadnumpy-d19584547a7934bfce9dee445e65a2c1994cecdc.tar.gz
Merge remote-tracking branch 'origin/main' into main
Diffstat (limited to 'numpy/core/fromnumeric.py')
-rw-r--r--numpy/core/fromnumeric.py42
1 files changed, 34 insertions, 8 deletions
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py
index ee93da901..7164a2c28 100644
--- a/numpy/core/fromnumeric.py
+++ b/numpy/core/fromnumeric.py
@@ -1114,12 +1114,12 @@ def argsort(a, axis=-1, kind=None, order=None):
return _wrapfunc(a, 'argsort', axis=axis, kind=kind, order=order)
-def _argmax_dispatcher(a, axis=None, out=None):
+def _argmax_dispatcher(a, axis=None, out=None, *, keepdims=np._NoValue):
return (a, out)
@array_function_dispatch(_argmax_dispatcher)
-def argmax(a, axis=None, out=None):
+def argmax(a, axis=None, out=None, *, keepdims=np._NoValue):
"""
Returns the indices of the maximum values along an axis.
@@ -1133,12 +1133,18 @@ def argmax(a, axis=None, out=None):
out : array, optional
If provided, the result will be inserted into this array. It should
be of the appropriate shape and dtype.
+ keepdims : bool, optional
+ If this is set to True, the axes which are reduced are left
+ in the result as dimensions with size one. With this option,
+ the result will broadcast correctly against the array.
Returns
-------
index_array : ndarray of ints
Array of indices into the array. It has the same shape as `a.shape`
- with the dimension along `axis` removed.
+ with the dimension along `axis` removed. If `keepdims` is set to True,
+ then the size of `axis` will be 1 with the resulting array having same
+ shape as `a.shape`.
See Also
--------
@@ -1191,16 +1197,23 @@ def argmax(a, axis=None, out=None):
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
array([4, 3])
+ Setting `keepdims` to `True`,
+
+ >>> x = np.arange(24).reshape((2, 3, 4))
+ >>> res = np.argmax(x, axis=1, keepdims=True)
+ >>> res.shape
+ (2, 1, 4)
"""
- return _wrapfunc(a, 'argmax', axis=axis, out=out)
+ kwds = {'keepdims': keepdims} if keepdims is not np._NoValue else {}
+ return _wrapfunc(a, 'argmax', axis=axis, out=out, **kwds)
-def _argmin_dispatcher(a, axis=None, out=None):
+def _argmin_dispatcher(a, axis=None, out=None, *, keepdims=np._NoValue):
return (a, out)
@array_function_dispatch(_argmin_dispatcher)
-def argmin(a, axis=None, out=None):
+def argmin(a, axis=None, out=None, *, keepdims=np._NoValue):
"""
Returns the indices of the minimum values along an axis.
@@ -1214,12 +1227,18 @@ def argmin(a, axis=None, out=None):
out : array, optional
If provided, the result will be inserted into this array. It should
be of the appropriate shape and dtype.
+ keepdims : bool, optional
+ If this is set to True, the axes which are reduced are left
+ in the result as dimensions with size one. With this option,
+ the result will broadcast correctly against the array.
Returns
-------
index_array : ndarray of ints
Array of indices into the array. It has the same shape as `a.shape`
- with the dimension along `axis` removed.
+ with the dimension along `axis` removed. If `keepdims` is set to True,
+ then the size of `axis` will be 1 with the resulting array having same
+ shape as `a.shape`.
See Also
--------
@@ -1272,8 +1291,15 @@ def argmin(a, axis=None, out=None):
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
array([2, 0])
+ Setting `keepdims` to `True`,
+
+ >>> x = np.arange(24).reshape((2, 3, 4))
+ >>> res = np.argmin(x, axis=1, keepdims=True)
+ >>> res.shape
+ (2, 1, 4)
"""
- return _wrapfunc(a, 'argmin', axis=axis, out=out)
+ kwds = {'keepdims': keepdims} if keepdims is not np._NoValue else {}
+ return _wrapfunc(a, 'argmin', axis=axis, out=out, **kwds)
def _searchsorted_dispatcher(a, v, side=None, sorter=None):