summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorwarren <warren.weckesser@gmail.com>2022-11-24 13:07:26 -0500
committerwarren <warren.weckesser@gmail.com>2022-11-24 13:07:26 -0500
commit09154cfa6dea50a6ac24ae1062095c9e98026bbc (patch)
treebbd4fae95df38e6320c75de756ece68c1cb8c87a /numpy/lib
parenta872fd73e6e94727c7acf281b03789bd42cda086 (diff)
downloadnumpy-09154cfa6dea50a6ac24ae1062095c9e98026bbc.tar.gz
DOC: lib: Use keepdims in a couple docstrings.
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/shape_base.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index ab91423d9..b8b3ebaad 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -124,19 +124,21 @@ def take_along_axis(arr, indices, axis):
>>> np.sort(a, axis=1)
array([[10, 20, 30],
[40, 50, 60]])
- >>> ai = np.argsort(a, axis=1); ai
+ >>> ai = np.argsort(a, axis=1)
+ >>> ai
array([[0, 2, 1],
[1, 2, 0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[10, 20, 30],
[40, 50, 60]])
- The same works for max and min, if you expand the dimensions:
+ The same works for max and min, if you maintain the trivial dimension
+ with ``keepdims``:
- >>> np.expand_dims(np.max(a, axis=1), axis=1)
+ >>> np.max(a, axis=1, keepdims=True)
array([[30],
[60]])
- >>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1)
+ >>> ai = np.argmax(a, axis=1, keepdims=True)
>>> ai
array([[1],
[0]])
@@ -147,8 +149,8 @@ def take_along_axis(arr, indices, axis):
If we want to get the max and min at the same time, we can stack the
indices first
- >>> ai_min = np.expand_dims(np.argmin(a, axis=1), axis=1)
- >>> ai_max = np.expand_dims(np.argmax(a, axis=1), axis=1)
+ >>> ai_min = np.argmin(a, axis=1, keepdims=True)
+ >>> ai_max = np.argmax(a, axis=1, keepdims=True)
>>> ai = np.concatenate([ai_min, ai_max], axis=1)
>>> ai
array([[0, 1],
@@ -237,7 +239,7 @@ def put_along_axis(arr, indices, values, axis):
We can replace the maximum values with:
- >>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1)
+ >>> ai = np.argmax(a, axis=1, keepdims=True)
>>> ai
array([[1],
[0]])