diff options
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r-- | numpy/lib/shape_base.py | 25 |
1 files changed, 19 insertions, 6 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 830943e72..53578e0e4 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -85,11 +85,9 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): array([[[1, 0, 0], [0, 2, 0], [0, 0, 3]], - [[4, 0, 0], [0, 5, 0], [0, 0, 6]], - [[7, 0, 0], [0, 8, 0], [0, 0, 9]]]) @@ -240,14 +238,20 @@ def expand_dims(a, axis): """ Expand the shape of an array. - Insert a new axis, corresponding to a given position in the array shape. + Insert a new axis that will appear at the `axis` position in the expanded + array shape. + + .. note:: Previous to NumPy 1.13.0, neither ``axis < -a.ndim - 1`` nor + ``axis > a.ndim`` raised errors or put the new axis where documented. + Those axis values are now deprecated and will raise an AxisError in the + future. Parameters ---------- a : array_like Input array. axis : int - Position (amongst axes) where new axis is to be inserted. + Position in the expanded axes where the new axis is placed. Returns ------- @@ -291,7 +295,16 @@ def expand_dims(a, axis): """ a = asarray(a) shape = a.shape - axis = normalize_axis_index(axis, a.ndim + 1) + if axis > a.ndim or axis < -a.ndim - 1: + # 2017-05-17, 1.13.0 + warnings.warn("Both axis > a.ndim and axis < -a.ndim - 1 are " + "deprecated and will raise an AxisError in the future.", + DeprecationWarning, stacklevel=2) + # When the deprecation period expires, delete this if block, + if axis < 0: + axis = axis + a.ndim + 1 + # and uncomment the following line. + # axis = normalize_axis_index(axis, a.ndim + 1) return a.reshape(shape[:axis] + (1,) + shape[axis:]) row_stack = vstack @@ -317,7 +330,7 @@ def column_stack(tup): See Also -------- - hstack, vstack, concatenate + stack, hstack, vstack, concatenate Examples -------- |