diff options
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r-- | numpy/lib/shape_base.py | 32 |
1 files changed, 13 insertions, 19 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index ab91423d9..5d8a41bfe 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -1,9 +1,7 @@ import functools import numpy.core.numeric as _nx -from numpy.core.numeric import ( - asarray, zeros, outer, concatenate, array, asanyarray - ) +from numpy.core.numeric import asarray, zeros, array, asanyarray from numpy.core.fromnumeric import reshape, transpose from numpy.core.multiarray import normalize_axis_index from numpy.core import overrides @@ -124,19 +122,21 @@ def take_along_axis(arr, indices, axis): >>> np.sort(a, axis=1) array([[10, 20, 30], [40, 50, 60]]) - >>> ai = np.argsort(a, axis=1); ai + >>> ai = np.argsort(a, axis=1) + >>> ai array([[0, 2, 1], [1, 2, 0]]) >>> np.take_along_axis(a, ai, axis=1) array([[10, 20, 30], [40, 50, 60]]) - The same works for max and min, if you expand the dimensions: + The same works for max and min, if you maintain the trivial dimension + with ``keepdims``: - >>> np.expand_dims(np.max(a, axis=1), axis=1) + >>> np.max(a, axis=1, keepdims=True) array([[30], [60]]) - >>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1) + >>> ai = np.argmax(a, axis=1, keepdims=True) >>> ai array([[1], [0]]) @@ -147,8 +147,8 @@ def take_along_axis(arr, indices, axis): If we want to get the max and min at the same time, we can stack the indices first - >>> ai_min = np.expand_dims(np.argmin(a, axis=1), axis=1) - >>> ai_max = np.expand_dims(np.argmax(a, axis=1), axis=1) + >>> ai_min = np.argmin(a, axis=1, keepdims=True) + >>> ai_max = np.argmax(a, axis=1, keepdims=True) >>> ai = np.concatenate([ai_min, ai_max], axis=1) >>> ai array([[0, 1], @@ -237,7 +237,7 @@ def put_along_axis(arr, indices, values, axis): We can replace the maximum values with: - >>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1) + >>> ai = np.argmax(a, axis=1, keepdims=True) >>> ai array([[1], [0]]) @@ -372,7 +372,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): # invoke the function on the first item try: ind0 = next(inds) - except StopIteration as e: + except StopIteration: raise ValueError( 'Cannot apply_along_axis when any iteration dimensions are 0' ) from None @@ -643,10 +643,6 @@ def column_stack(tup): [3, 4]]) """ - if not overrides.ARRAY_FUNCTION_ENABLED: - # raise warning if necessary - _arrays_for_stack_dispatcher(tup, stacklevel=2) - arrays = [] for v in tup: arr = asanyarray(v) @@ -713,10 +709,6 @@ def dstack(tup): [[3, 4]]]) """ - if not overrides.ARRAY_FUNCTION_ENABLED: - # raise warning if necessary - _arrays_for_stack_dispatcher(tup, stacklevel=2) - arrs = atleast_3d(*tup) if not isinstance(arrs, list): arrs = [arrs] @@ -1041,6 +1033,7 @@ def dsplit(ary, indices_or_sections): raise ValueError('dsplit only works on arrays of 3 or more dimensions') return split(ary, indices_or_sections, 2) + def get_array_prepare(*args): """Find the wrapper for the array with the highest priority. @@ -1053,6 +1046,7 @@ def get_array_prepare(*args): return wrappers[-1][-1] return None + def get_array_wrap(*args): """Find the wrapper for the array with the highest priority. |