summaryrefslogtreecommitdiff
path: root/numpy/core/fromnumeric.py
diff options
context:
space:
mode:
authormproszewska <38814059+mproszewska@users.noreply.github.com>2019-11-04 23:18:45 +0100
committerMatti Picus <matti.picus@gmail.com>2019-11-04 17:18:45 -0500
commit7b8513fff4c3673d8a967a1d72a73feab2072e8f (patch)
treeb89130fbe9ed9072a42e1b40e6873416183a28f0 /numpy/core/fromnumeric.py
parentcadb066b40342a4c4e6607aa5628ab1825ed87cf (diff)
downloadnumpy-7b8513fff4c3673d8a967a1d72a73feab2072e8f.tar.gz
DOC: Add take_along_axis to the see also section in argmin, argmax etc. (#14799)
* Add take_along_axis to the see also section in argmin, argmax, argsort and argpartition and add examples
Diffstat (limited to 'numpy/core/fromnumeric.py')
-rw-r--r--numpy/core/fromnumeric.py38
1 files changed, 37 insertions, 1 deletions
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py
index 5f7716455..6e5f3dabf 100644
--- a/numpy/core/fromnumeric.py
+++ b/numpy/core/fromnumeric.py
@@ -796,7 +796,9 @@ def argpartition(a, kth, axis=-1, kind='introselect', order=None):
--------
partition : Describes partition algorithms used.
ndarray.partition : Inplace partition.
- argsort : Full indirect sort
+ argsort : Full indirect sort.
+ take_along_axis : Apply ``index_array`` from argpartition
+ to an array as if by calling partition.
Notes
-----
@@ -816,6 +818,14 @@ def argpartition(a, kth, axis=-1, kind='introselect', order=None):
>>> np.array(x)[np.argpartition(x, 3)]
array([2, 1, 3, 4])
+ Multi-dimensional array:
+
+ >>> x = np.array([[3, 4, 2], [1, 3, 1]])
+ >>> index_array = np.argpartition(x, kth=1, axis=-1)
+ >>> np.take_along_axis(x, index_array, axis=-1) # same as np.partition(x, kth=1)
+ array([[2, 3, 4],
+ [1, 1, 3]])
+
"""
return _wrapfunc(a, 'argpartition', kth, axis=axis, kind=kind, order=order)
@@ -1025,6 +1035,8 @@ def argsort(a, axis=-1, kind=None, order=None):
lexsort : Indirect stable sort with multiple keys.
ndarray.sort : Inplace sort.
argpartition : Indirect partial sort.
+ take_along_axis : Apply ``index_array`` from argsort
+ to an array as if by calling sort.
Notes
-----
@@ -1120,6 +1132,8 @@ def argmax(a, axis=None, out=None):
ndarray.argmax, argmin
amax : The maximum value along a given axis.
unravel_index : Convert a flat index into an index tuple.
+ take_along_axis : Apply ``np.expand_dims(index_array, axis)``
+ from argmax to an array as if by calling max.
Notes
-----
@@ -1154,6 +1168,16 @@ def argmax(a, axis=None, out=None):
>>> np.argmax(b) # Only the first occurrence is returned.
1
+ >>> x = np.array([[4,2,3], [1,0,3]])
+ >>> index_array = np.argmax(x, axis=-1)
+ >>> # Same as np.max(x, axis=-1, keepdims=True)
+ >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
+ array([[4],
+ [3]])
+ >>> # Same as np.max(x, axis=-1)
+ >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
+ array([4, 3])
+
"""
return _wrapfunc(a, 'argmax', axis=axis, out=out)
@@ -1189,6 +1213,8 @@ def argmin(a, axis=None, out=None):
ndarray.argmin, argmax
amin : The minimum value along a given axis.
unravel_index : Convert a flat index into an index tuple.
+ take_along_axis : Apply ``np.expand_dims(index_array, axis)``
+ from argmin to an array as if by calling min.
Notes
-----
@@ -1223,6 +1249,16 @@ def argmin(a, axis=None, out=None):
>>> np.argmin(b) # Only the first occurrence is returned.
0
+ >>> x = np.array([[4,2,3], [1,0,3]])
+ >>> index_array = np.argmin(x, axis=-1)
+ >>> # Same as np.min(x, axis=-1, keepdims=True)
+ >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1)
+ array([[2],
+ [0]])
+ >>> # Same as np.max(x, axis=-1)
+ >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
+ array([2, 0])
+
"""
return _wrapfunc(a, 'argmin', axis=axis, out=out)