summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorLarry Bradley <larry.bradley@gmail.com>2019-12-02 17:06:42 -0500
committerSebastian Berg <sebastian@sipsolutions.net>2019-12-02 16:06:42 -0600
commit03d489735e863e27f3e6ce39b8a85eca440c0231 (patch)
tree89bda3ea557fa7db40c0b5f9c0543e2ab93151e5 /numpy/lib
parent5992098524c9f36288093ef3298d44343735842e (diff)
downloadnumpy-03d489735e863e27f3e6ce39b8a85eca440c0231.tar.gz
ENH,DEP: Allow multiple axes in expand_dims (#14051)
This PR allows the axis keyword in expand_dims to be a tuple of ints. Previously, axis could only be an int. This issue was previously discussed in gh-12290 and the changes are based on gh-12290 (comment). This PR also removes the deprecation added in v1.13 (2017-05-17), where previously axis could be outside of the range (-a.ndim - 1) <= axis <= a.ndim. Such an axis value will now raise an AxisError. Please let me know if it's too soon to remove this deprecation (I could not find any dev docs stating the length of the numpy deprecation cycle). Closes gh-12290.
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/shape_base.py67
-rw-r--r--numpy/lib/tests/test_shape_base.py24
2 files changed, 59 insertions, 32 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 92d52109e..dbb61c225 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -1,7 +1,6 @@
from __future__ import division, absolute_import, print_function
import functools
-import warnings
import numpy.core.numeric as _nx
from numpy.core.numeric import (
@@ -11,6 +10,7 @@ from numpy.core.fromnumeric import reshape, transpose
from numpy.core.multiarray import normalize_axis_index
from numpy.core import overrides
from numpy.core import vstack, atleast_3d
+from numpy.core.numeric import normalize_axis_tuple
from numpy.core.shape_base import _arrays_for_stack_dispatcher
from numpy.lib.index_tricks import ndindex
from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells
@@ -29,7 +29,7 @@ array_function_dispatch = functools.partial(
def _make_along_axis_idx(arr_shape, indices, axis):
- # compute dimensions to iterate over
+ # compute dimensions to iterate over
if not _nx.issubdtype(indices.dtype, _nx.integer):
raise IndexError('`indices` must be an integer array')
if len(arr_shape) != indices.ndim:
@@ -517,22 +517,26 @@ def expand_dims(a, axis):
Insert a new axis that will appear at the `axis` position in the expanded
array shape.
- .. note:: Previous to NumPy 1.13.0, neither ``axis < -a.ndim - 1`` nor
- ``axis > a.ndim`` raised errors or put the new axis where documented.
- Those axis values are now deprecated and will raise an AxisError in the
- future.
-
Parameters
----------
a : array_like
Input array.
- axis : int
- Position in the expanded axes where the new axis is placed.
+ axis : int or tuple of ints
+ Position in the expanded axes where the new axis (or axes) is placed.
+
+ .. deprecated:: 1.13.0
+ Passing an axis where ``axis > a.ndim`` will be treated as
+ ``axis == a.ndim``, and passing ``axis < -a.ndim - 1`` will
+ be treated as ``axis == 0``. This behavior is deprecated.
+
+ .. versionchanged:: 1.18.0
+ A tuple of axes is now supported. Out of range axes as
+ described above are now forbidden and raise an `AxisError`.
Returns
-------
- res : ndarray
- View of `a` with the number of dimensions increased by one.
+ result : ndarray
+ View of `a` with the number of dimensions increased.
See Also
--------
@@ -542,11 +546,11 @@ def expand_dims(a, axis):
Examples
--------
- >>> x = np.array([1,2])
+ >>> x = np.array([1, 2])
>>> x.shape
(2,)
- The following is equivalent to ``x[np.newaxis,:]`` or ``x[np.newaxis]``:
+ The following is equivalent to ``x[np.newaxis, :]`` or ``x[np.newaxis]``:
>>> y = np.expand_dims(x, axis=0)
>>> y
@@ -554,13 +558,26 @@ def expand_dims(a, axis):
>>> y.shape
(1, 2)
- >>> y = np.expand_dims(x, axis=1) # Equivalent to x[:,np.newaxis]
+ The following is equivalent to ``x[:, np.newaxis]``:
+
+ >>> y = np.expand_dims(x, axis=1)
>>> y
array([[1],
[2]])
>>> y.shape
(2, 1)
+ ``axis`` may also be a tuple:
+
+ >>> y = np.expand_dims(x, axis=(0, 1))
+ >>> y
+ array([[[1, 2]]])
+
+ >>> y = np.expand_dims(x, axis=(2, 0))
+ >>> y
+ array([[[1],
+ [2]]])
+
Note that some examples may use ``None`` instead of ``np.newaxis``. These
are the same objects:
@@ -573,18 +590,16 @@ def expand_dims(a, axis):
else:
a = asanyarray(a)
- shape = a.shape
- if axis > a.ndim or axis < -a.ndim - 1:
- # 2017-05-17, 1.13.0
- warnings.warn("Both axis > a.ndim and axis < -a.ndim - 1 are "
- "deprecated and will raise an AxisError in the future.",
- DeprecationWarning, stacklevel=3)
- # When the deprecation period expires, delete this if block,
- if axis < 0:
- axis = axis + a.ndim + 1
- # and uncomment the following line.
- # axis = normalize_axis_index(axis, a.ndim + 1)
- return a.reshape(shape[:axis] + (1,) + shape[axis:])
+ if type(axis) not in (tuple, list):
+ axis = (axis,)
+
+ out_ndim = len(axis) + a.ndim
+ axis = normalize_axis_tuple(axis, out_ndim)
+
+ shape_it = iter(a.shape)
+ shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)]
+
+ return a.reshape(shape)
row_stack = vstack
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index 01ea028bb..be1604a75 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -289,14 +289,26 @@ class TestExpandDims(object):
assert_(b.shape[axis] == 1)
assert_(np.squeeze(b).shape == s)
- def test_deprecations(self):
- # 2017-05-17, 1.13.0
+ def test_axis_tuple(self):
+ a = np.empty((3, 3, 3))
+ assert np.expand_dims(a, axis=(0, 1, 2)).shape == (1, 1, 1, 3, 3, 3)
+ assert np.expand_dims(a, axis=(0, -1, -2)).shape == (1, 3, 3, 3, 1, 1)
+ assert np.expand_dims(a, axis=(0, 3, 5)).shape == (1, 3, 3, 1, 3, 1)
+ assert np.expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3)
+
+ def test_axis_out_of_range(self):
s = (2, 3, 4, 5)
a = np.empty(s)
- with warnings.catch_warnings():
- warnings.simplefilter("always")
- assert_warns(DeprecationWarning, expand_dims, a, -6)
- assert_warns(DeprecationWarning, expand_dims, a, 5)
+ assert_raises(np.AxisError, expand_dims, a, -6)
+ assert_raises(np.AxisError, expand_dims, a, 5)
+
+ a = np.empty((3, 3, 3))
+ assert_raises(np.AxisError, expand_dims, a, (0, -6))
+ assert_raises(np.AxisError, expand_dims, a, (0, 5))
+
+ def test_repeated_axis(self):
+ a = np.empty((3, 3, 3))
+ assert_raises(ValueError, expand_dims, a, axis=(1, 1))
def test_subclasses(self):
a = np.arange(10).reshape((2, 5))