diff options
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r-- | numpy/lib/shape_base.py | 73 |
1 files changed, 43 insertions, 30 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 92d52109e..b7f1f16f2 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -1,7 +1,4 @@ -from __future__ import division, absolute_import, print_function - import functools -import warnings import numpy.core.numeric as _nx from numpy.core.numeric import ( @@ -11,6 +8,7 @@ from numpy.core.fromnumeric import reshape, transpose from numpy.core.multiarray import normalize_axis_index from numpy.core import overrides from numpy.core import vstack, atleast_3d +from numpy.core.numeric import normalize_axis_tuple from numpy.core.shape_base import _arrays_for_stack_dispatcher from numpy.lib.index_tricks import ndindex from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells @@ -29,7 +27,7 @@ array_function_dispatch = functools.partial( def _make_along_axis_idx(arr_shape, indices, axis): - # compute dimensions to iterate over + # compute dimensions to iterate over if not _nx.issubdtype(indices.dtype, _nx.integer): raise IndexError('`indices` must be an integer array') if len(arr_shape) != indices.ndim: @@ -271,8 +269,8 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): """ Apply a function to 1-D slices along the given axis. - Execute `func1d(a, *args)` where `func1d` operates on 1-D arrays and `a` - is a 1-D slice of `arr` along `axis`. + Execute `func1d(a, *args, **kwargs)` where `func1d` operates on 1-D arrays + and `a` is a 1-D slice of `arr` along `axis`. This is equivalent to (but faster than) the following use of `ndindex` and `s_`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of indices:: @@ -517,22 +515,26 @@ def expand_dims(a, axis): Insert a new axis that will appear at the `axis` position in the expanded array shape. - .. note:: Previous to NumPy 1.13.0, neither ``axis < -a.ndim - 1`` nor - ``axis > a.ndim`` raised errors or put the new axis where documented. - Those axis values are now deprecated and will raise an AxisError in the - future. - Parameters ---------- a : array_like Input array. - axis : int - Position in the expanded axes where the new axis is placed. + axis : int or tuple of ints + Position in the expanded axes where the new axis (or axes) is placed. + + .. deprecated:: 1.13.0 + Passing an axis where ``axis > a.ndim`` will be treated as + ``axis == a.ndim``, and passing ``axis < -a.ndim - 1`` will + be treated as ``axis == 0``. This behavior is deprecated. + + .. versionchanged:: 1.18.0 + A tuple of axes is now supported. Out of range axes as + described above are now forbidden and raise an `AxisError`. Returns ------- - res : ndarray - View of `a` with the number of dimensions increased by one. + result : ndarray + View of `a` with the number of dimensions increased. See Also -------- @@ -542,11 +544,11 @@ def expand_dims(a, axis): Examples -------- - >>> x = np.array([1,2]) + >>> x = np.array([1, 2]) >>> x.shape (2,) - The following is equivalent to ``x[np.newaxis,:]`` or ``x[np.newaxis]``: + The following is equivalent to ``x[np.newaxis, :]`` or ``x[np.newaxis]``: >>> y = np.expand_dims(x, axis=0) >>> y @@ -554,13 +556,26 @@ def expand_dims(a, axis): >>> y.shape (1, 2) - >>> y = np.expand_dims(x, axis=1) # Equivalent to x[:,np.newaxis] + The following is equivalent to ``x[:, np.newaxis]``: + + >>> y = np.expand_dims(x, axis=1) >>> y array([[1], [2]]) >>> y.shape (2, 1) + ``axis`` may also be a tuple: + + >>> y = np.expand_dims(x, axis=(0, 1)) + >>> y + array([[[1, 2]]]) + + >>> y = np.expand_dims(x, axis=(2, 0)) + >>> y + array([[[1], + [2]]]) + Note that some examples may use ``None`` instead of ``np.newaxis``. These are the same objects: @@ -573,18 +588,16 @@ def expand_dims(a, axis): else: a = asanyarray(a) - shape = a.shape - if axis > a.ndim or axis < -a.ndim - 1: - # 2017-05-17, 1.13.0 - warnings.warn("Both axis > a.ndim and axis < -a.ndim - 1 are " - "deprecated and will raise an AxisError in the future.", - DeprecationWarning, stacklevel=3) - # When the deprecation period expires, delete this if block, - if axis < 0: - axis = axis + a.ndim + 1 - # and uncomment the following line. - # axis = normalize_axis_index(axis, a.ndim + 1) - return a.reshape(shape[:axis] + (1,) + shape[axis:]) + if type(axis) not in (tuple, list): + axis = (axis,) + + out_ndim = len(axis) + a.ndim + axis = normalize_axis_tuple(axis, out_ndim) + + shape_it = iter(a.shape) + shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)] + + return a.reshape(shape) row_stack = vstack |