diff options
-rw-r--r-- | doc/release/upcoming_changes/14051.expired.rst | 2 | ||||
-rw-r--r-- | doc/release/upcoming_changes/14051.new_feature.rst | 4 | ||||
-rw-r--r-- | numpy/lib/shape_base.py | 67 | ||||
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 24 |
4 files changed, 65 insertions, 32 deletions
diff --git a/doc/release/upcoming_changes/14051.expired.rst b/doc/release/upcoming_changes/14051.expired.rst new file mode 100644 index 000000000..8e00ae575 --- /dev/null +++ b/doc/release/upcoming_changes/14051.expired.rst @@ -0,0 +1,2 @@ +* The deprecation of ``expand_dims`` out-of-range axes in 1.13.0 has + expired. diff --git a/doc/release/upcoming_changes/14051.new_feature.rst b/doc/release/upcoming_changes/14051.new_feature.rst new file mode 100644 index 000000000..617e06482 --- /dev/null +++ b/doc/release/upcoming_changes/14051.new_feature.rst @@ -0,0 +1,4 @@ +A tuple of axes can now be input to ``expand_dims`` +--------------------------------------------------- +The `numpy.expand_dims` ``axis`` keyword can now accept a tuple of +axes. Previously, ``axis`` was required to be an integer. diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 92d52109e..dbb61c225 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -1,7 +1,6 @@ from __future__ import division, absolute_import, print_function import functools -import warnings import numpy.core.numeric as _nx from numpy.core.numeric import ( @@ -11,6 +10,7 @@ from numpy.core.fromnumeric import reshape, transpose from numpy.core.multiarray import normalize_axis_index from numpy.core import overrides from numpy.core import vstack, atleast_3d +from numpy.core.numeric import normalize_axis_tuple from numpy.core.shape_base import _arrays_for_stack_dispatcher from numpy.lib.index_tricks import ndindex from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells @@ -29,7 +29,7 @@ array_function_dispatch = functools.partial( def _make_along_axis_idx(arr_shape, indices, axis): - # compute dimensions to iterate over + # compute dimensions to iterate over if not _nx.issubdtype(indices.dtype, _nx.integer): raise IndexError('`indices` must be an integer array') if len(arr_shape) != indices.ndim: @@ -517,22 +517,26 @@ def expand_dims(a, axis): Insert a new axis that will appear at the `axis` position in the expanded array shape. - .. note:: Previous to NumPy 1.13.0, neither ``axis < -a.ndim - 1`` nor - ``axis > a.ndim`` raised errors or put the new axis where documented. - Those axis values are now deprecated and will raise an AxisError in the - future. - Parameters ---------- a : array_like Input array. - axis : int - Position in the expanded axes where the new axis is placed. + axis : int or tuple of ints + Position in the expanded axes where the new axis (or axes) is placed. + + .. deprecated:: 1.13.0 + Passing an axis where ``axis > a.ndim`` will be treated as + ``axis == a.ndim``, and passing ``axis < -a.ndim - 1`` will + be treated as ``axis == 0``. This behavior is deprecated. + + .. versionchanged:: 1.18.0 + A tuple of axes is now supported. Out of range axes as + described above are now forbidden and raise an `AxisError`. Returns ------- - res : ndarray - View of `a` with the number of dimensions increased by one. + result : ndarray + View of `a` with the number of dimensions increased. See Also -------- @@ -542,11 +546,11 @@ def expand_dims(a, axis): Examples -------- - >>> x = np.array([1,2]) + >>> x = np.array([1, 2]) >>> x.shape (2,) - The following is equivalent to ``x[np.newaxis,:]`` or ``x[np.newaxis]``: + The following is equivalent to ``x[np.newaxis, :]`` or ``x[np.newaxis]``: >>> y = np.expand_dims(x, axis=0) >>> y @@ -554,13 +558,26 @@ def expand_dims(a, axis): >>> y.shape (1, 2) - >>> y = np.expand_dims(x, axis=1) # Equivalent to x[:,np.newaxis] + The following is equivalent to ``x[:, np.newaxis]``: + + >>> y = np.expand_dims(x, axis=1) >>> y array([[1], [2]]) >>> y.shape (2, 1) + ``axis`` may also be a tuple: + + >>> y = np.expand_dims(x, axis=(0, 1)) + >>> y + array([[[1, 2]]]) + + >>> y = np.expand_dims(x, axis=(2, 0)) + >>> y + array([[[1], + [2]]]) + Note that some examples may use ``None`` instead of ``np.newaxis``. These are the same objects: @@ -573,18 +590,16 @@ def expand_dims(a, axis): else: a = asanyarray(a) - shape = a.shape - if axis > a.ndim or axis < -a.ndim - 1: - # 2017-05-17, 1.13.0 - warnings.warn("Both axis > a.ndim and axis < -a.ndim - 1 are " - "deprecated and will raise an AxisError in the future.", - DeprecationWarning, stacklevel=3) - # When the deprecation period expires, delete this if block, - if axis < 0: - axis = axis + a.ndim + 1 - # and uncomment the following line. - # axis = normalize_axis_index(axis, a.ndim + 1) - return a.reshape(shape[:axis] + (1,) + shape[axis:]) + if type(axis) not in (tuple, list): + axis = (axis,) + + out_ndim = len(axis) + a.ndim + axis = normalize_axis_tuple(axis, out_ndim) + + shape_it = iter(a.shape) + shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)] + + return a.reshape(shape) row_stack = vstack diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index 01ea028bb..be1604a75 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -289,14 +289,26 @@ class TestExpandDims(object): assert_(b.shape[axis] == 1) assert_(np.squeeze(b).shape == s) - def test_deprecations(self): - # 2017-05-17, 1.13.0 + def test_axis_tuple(self): + a = np.empty((3, 3, 3)) + assert np.expand_dims(a, axis=(0, 1, 2)).shape == (1, 1, 1, 3, 3, 3) + assert np.expand_dims(a, axis=(0, -1, -2)).shape == (1, 3, 3, 3, 1, 1) + assert np.expand_dims(a, axis=(0, 3, 5)).shape == (1, 3, 3, 1, 3, 1) + assert np.expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3) + + def test_axis_out_of_range(self): s = (2, 3, 4, 5) a = np.empty(s) - with warnings.catch_warnings(): - warnings.simplefilter("always") - assert_warns(DeprecationWarning, expand_dims, a, -6) - assert_warns(DeprecationWarning, expand_dims, a, 5) + assert_raises(np.AxisError, expand_dims, a, -6) + assert_raises(np.AxisError, expand_dims, a, 5) + + a = np.empty((3, 3, 3)) + assert_raises(np.AxisError, expand_dims, a, (0, -6)) + assert_raises(np.AxisError, expand_dims, a, (0, 5)) + + def test_repeated_axis(self): + a = np.empty((3, 3, 3)) + assert_raises(ValueError, expand_dims, a, axis=(1, 1)) def test_subclasses(self): a = np.arange(10).reshape((2, 5)) |