summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@climate.com>2015-02-25 01:49:26 -0800
committerStephan Hoyer <shoyer@climate.com>2015-05-11 21:18:24 -0700
commit93d3b8dedc5cd602c867a234f07188fe5bd5479b (patch)
treecd79af4bf4e90af702d724aeaa51c1484741219c
parent2e016ac65aceab4e08217794d6be7b365793976a (diff)
downloadnumpy-93d3b8dedc5cd602c867a234f07188fe5bd5479b.tar.gz
ENH: add np.stack
The motivation here is to present a uniform and N-dimensional interface for joining arrays along a new axis, similarly to how `concatenate` provides a uniform and N-dimensional interface for joining arrays along an existing axis. Background ~~~~~~~~~~ Currently, users can choose between `hstack`, `vstack`, `column_stack` and `dstack`, but none of these functions handle N-dimensional input. In my opinion, it's also difficult to keep track of the differences between these methods and to predict how they will handle input with different dimensions. In the past, my preferred approach has been to either construct the result array explicitly and use indexing for assignment, or to use `np.array` to stack along the first dimension and then use `transpose` (or a similar method) to reorder dimensions if necessary. This is pretty awkward. I brought this proposal up a few weeks ago on the numpy-discussion list: http://mail.scipy.org/pipermail/numpy-discussion/2015-February/072199.html I also received positive feedback on Twitter: https://twitter.com/shoyer/status/565937244599377920 Implementation notes ~~~~~~~~~~~~~~~~~~~~ The one-line summaries for `concatenate` and `stack` have been (re)written to mirror each other, and to make clear that the distinction between these functions is whether they join over an existing or new axis. In general, I've tweaked the documentation and docstrings with an eye toward pointing users to `concatenate`/`stack`/`split` as a fundamental set of basic array manipulation routines, and away from `array_split`/`{h,v,d}split`/`{h,v,d,column_}stack`. I put this implementation in `numpy.core.shape_base` alongside `hstack`/`vstack`, but it appears that there is also a `numpy.lib.shape_base` module that contains another larger set of functions, including `dstack`. I'm not really sure where this belongs (or if it even matters). Finally, it might be a good idea to write a masked array version of `stack`. 
But I don't use masked arrays, so I'm not well motivated to do that.
-rw-r--r--doc/release/1.10.0-notes.rst3
-rw-r--r--doc/source/reference/routines.array-manipulation.rst5
-rw-r--r--numpy/add_newdocs.py3
-rw-r--r--numpy/core/shape_base.py79
-rw-r--r--numpy/core/tests/test_shape_base.py55
-rw-r--r--numpy/lib/function_base.py2
-rw-r--r--numpy/lib/index_tricks.py2
-rw-r--r--numpy/lib/shape_base.py6
8 files changed, 143 insertions, 12 deletions
diff --git a/doc/release/1.10.0-notes.rst b/doc/release/1.10.0-notes.rst
index f9d202ec3..c29aced4a 100644
--- a/doc/release/1.10.0-notes.rst
+++ b/doc/release/1.10.0-notes.rst
@@ -11,6 +11,9 @@ Highlights
* Addition of *np.linalg.multi_dot*: compute the dot product of two or more
arrays in a single function call, while automatically selecting the fastest
evaluation order.
+* The new function `np.stack` provides a general interface for joining a
+ sequence of arrays along a new axis, complementing `np.concatenate` for
+ joining along an existing axis.
* Addition of `nanprod` to the set of nanfunctions.
diff --git a/doc/source/reference/routines.array-manipulation.rst b/doc/source/reference/routines.array-manipulation.rst
index 81af0a315..37b82d7db 100644
--- a/doc/source/reference/routines.array-manipulation.rst
+++ b/doc/source/reference/routines.array-manipulation.rst
@@ -64,8 +64,9 @@ Joining arrays
.. autosummary::
:toctree: generated/
- column_stack
concatenate
+ stack
+ column_stack
dstack
hstack
vstack
@@ -75,10 +76,10 @@ Splitting arrays
.. autosummary::
:toctree: generated/
+ split
array_split
dsplit
hsplit
- split
vsplit
Tiling arrays
diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py
index 7dd8c5649..0333dd5a4 100644
--- a/numpy/add_newdocs.py
+++ b/numpy/add_newdocs.py
@@ -1142,7 +1142,7 @@ add_newdoc('numpy.core.multiarray', 'concatenate',
"""
concatenate((a1, a2, ...), axis=0)
- Join a sequence of arrays together.
+ Join a sequence of arrays along an existing axis.
Parameters
----------
@@ -1166,6 +1166,7 @@ add_newdoc('numpy.core.multiarray', 'concatenate',
hsplit : Split array into multiple sub-arrays horizontally (column wise)
vsplit : Split array into multiple sub-arrays vertically (row wise)
dsplit : Split array into multiple sub-arrays along the 3rd axis (depth).
+ stack : Stack a sequence of arrays along a new axis.
hstack : Stack arrays in sequence horizontally (column wise)
vstack : Stack arrays in sequence vertically (row wise)
dstack : Stack arrays in sequence depth wise (along third dimension)
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py
index ae684fb42..3259f3b1d 100644
--- a/numpy/core/shape_base.py
+++ b/numpy/core/shape_base.py
@@ -1,6 +1,7 @@
from __future__ import division, absolute_import, print_function
-__all__ = ['atleast_1d', 'atleast_2d', 'atleast_3d', 'vstack', 'hstack']
+__all__ = ['atleast_1d', 'atleast_2d', 'atleast_3d', 'vstack', 'hstack',
+ 'stack']
from . import numeric as _nx
from .numeric import array, asanyarray, newaxis
@@ -196,9 +197,10 @@ def vstack(tup):
See Also
--------
+ stack : Join a sequence of arrays along a new axis.
hstack : Stack arrays in sequence horizontally (column wise).
dstack : Stack arrays in sequence depth wise (along third dimension).
- concatenate : Join a sequence of arrays together.
+ concatenate : Join a sequence of arrays along an existing axis.
vsplit : Split array into a list of multiple sub-arrays vertically.
Notes
@@ -246,9 +248,10 @@ def hstack(tup):
See Also
--------
+ stack : Join a sequence of arrays along a new axis.
vstack : Stack arrays in sequence vertically (row wise).
dstack : Stack arrays in sequence depth wise (along third axis).
- concatenate : Join a sequence of arrays together.
+ concatenate : Join a sequence of arrays along an existing axis.
hsplit : Split array along second axis.
Notes
@@ -275,3 +278,73 @@ def hstack(tup):
return _nx.concatenate(arrs, 0)
else:
return _nx.concatenate(arrs, 1)
+
+def stack(arrays, axis=0):
+ """
+ Join a sequence of arrays along a new axis.
+
+ The `axis` parameter specifies the index of the new axis in the dimensions
+ of the result. For example, if ``axis=0`` it will be the first dimension
+ and if ``axis=-1`` it will be the last dimension.
+
+ .. versionadded:: 1.10.0
+
+ Parameters
+ ----------
+ arrays : sequence of array_like
+ Each array must have the same shape.
+ axis : int, optional
+ The axis in the result array along which the input arrays are stacked.
+ Negative values count back from the last dimension of the result.
+ Default is 0.
+
+ Returns
+ -------
+ stacked : ndarray
+ The stacked array has one more dimension than the input arrays.
+
+ Raises
+ ------
+ ValueError
+ If `arrays` is empty, or if the input arrays do not all have the
+ same shape.
+ IndexError
+ If `axis` is outside ``[-(ndim + 1), ndim + 1)``, where ``ndim`` is
+ the number of dimensions of the input arrays.
+
+ See Also
+ --------
+ concatenate : Join a sequence of arrays along an existing axis.
+ split : Split array into a list of multiple sub-arrays of equal size.
+
+ Examples
+ --------
+ >>> arrays = [np.random.randn(3, 4) for _ in range(10)]
+ >>> np.stack(arrays, axis=0).shape
+ (10, 3, 4)
+
+ >>> np.stack(arrays, axis=1).shape
+ (3, 10, 4)
+
+ >>> np.stack(arrays, axis=2).shape
+ (3, 4, 10)
+
+ >>> a = np.array([1, 2, 3])
+ >>> b = np.array([2, 3, 4])
+ >>> np.stack((a, b))
+ array([[1, 2, 3],
+ [2, 3, 4]])
+
+ >>> np.stack((a, b), axis=-1)
+ array([[1, 2],
+ [2, 3],
+ [3, 4]])
+
+ """
+ # Convert every input first: the checks and indexing below rely on
+ # ``.shape``, ``.ndim`` and array indexing being available.
+ arrays = [asanyarray(arr) for arr in arrays]
+ if not arrays:
+ raise ValueError('need at least one array to stack')
+
+ # Collecting the shapes in a set leaves exactly one element only when
+ # all inputs agree on their shape.
+ shapes = set(arr.shape for arr in arrays)
+ if len(shapes) != 1:
+ raise ValueError('all input arrays must have the same shape')
+
+ # The result has one more dimension than the inputs, so the valid range
+ # for `axis` is [-result_ndim, result_ndim).
+ result_ndim = arrays[0].ndim + 1
+ if not -result_ndim <= axis < result_ndim:
+ msg = 'axis {0} out of bounds [-{1}, {1})'.format(axis, result_ndim)
+ raise IndexError(msg)
+ # Map a negative axis onto its non-negative position in the result.
+ if axis < 0:
+ axis += result_ndim
+
+ # `axis` full slices followed by ``newaxis`` insert a length-1 dimension
+ # at position `axis` in each array; concatenating along that dimension
+ # then joins the arrays on the new axis.
+ sl = (slice(None),) * axis + (_nx.newaxis,)
+ expanded_arrays = [arr[sl] for arr in arrays]
+ return _nx.concatenate(expanded_arrays, axis=axis)
diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py
index f1f5311c9..c6399bb07 100644
--- a/numpy/core/tests/test_shape_base.py
+++ b/numpy/core/tests/test_shape_base.py
@@ -3,9 +3,9 @@ from __future__ import division, absolute_import, print_function
import warnings
import numpy as np
from numpy.testing import (TestCase, assert_, assert_raises, assert_array_equal,
- assert_equal, run_module_suite)
+ assert_equal, run_module_suite, assert_raises_regex)
from numpy.core import (array, arange, atleast_1d, atleast_2d, atleast_3d,
- vstack, hstack, newaxis, concatenate)
+ vstack, hstack, newaxis, concatenate, stack)
from numpy.compat import long
class TestAtleast1d(TestCase):
@@ -246,5 +246,56 @@ def test_concatenate_sloppy0():
assert_raises(DeprecationWarning, concatenate, (r4, r3), 10)
+def test_stack():
+ # Exercises `stack` on 0d, 1d and 2d inputs, on positive and negative
+ # axes, on empty arrays, and on the error paths.
+ # 0d input
+ for input_ in [(1, 2, 3),
+ [np.int32(1), np.int32(2), np.int32(3)],
+ [np.array(1), np.array(2), np.array(3)]]:
+ assert_array_equal(stack(input_), [1, 2, 3])
+ # 1d input examples
+ a = np.array([1, 2, 3])
+ b = np.array([4, 5, 6])
+ r1 = array([[1, 2, 3], [4, 5, 6]])
+ assert_array_equal(np.stack((a, b)), r1)
+ assert_array_equal(np.stack((a, b), axis=1), r1.T)
+ # all input types
+ assert_array_equal(np.stack(list([a, b])), r1)
+ assert_array_equal(np.stack(array([a, b])), r1)
+ # all shapes for 1d input
+ arrays = [np.random.randn(3) for _ in range(10)]
+ axes = [0, 1, -1, -2]
+ expected_shapes = [(10, 3), (3, 10), (3, 10), (10, 3)]
+ for axis, expected_shape in zip(axes, expected_shapes):
+ assert_equal(np.stack(arrays, axis).shape, expected_shape)
+ # 1d inputs give a 2d result, so axis=2 and axis=-3 fall just outside
+ # the valid range on either side
+ assert_raises_regex(IndexError, 'out of bounds', stack, arrays, axis=2)
+ assert_raises_regex(IndexError, 'out of bounds', stack, arrays, axis=-3)
+ # all shapes for 2d input
+ arrays = [np.random.randn(3, 4) for _ in range(10)]
+ axes = [0, 1, 2, -1, -2, -3]
+ expected_shapes = [(10, 3, 4), (3, 10, 4), (3, 4, 10),
+ (3, 4, 10), (3, 10, 4), (10, 3, 4)]
+ for axis, expected_shape in zip(axes, expected_shapes):
+ assert_equal(np.stack(arrays, axis).shape, expected_shape)
+ # empty arrays
+ assert stack([[], [], []]).shape == (3, 0)
+ assert stack([[], [], []], axis=1).shape == (0, 3)
+ # edge cases
+ assert_raises_regex(ValueError, 'need at least one array', stack, [])
+ assert_raises_regex(ValueError, 'must have the same shape',
+ stack, [1, np.arange(3)])
+ assert_raises_regex(ValueError, 'must have the same shape',
+ stack, [np.arange(3), 1])
+ assert_raises_regex(ValueError, 'must have the same shape',
+ stack, [np.arange(3), 1], axis=1)
+ assert_raises_regex(ValueError, 'must have the same shape',
+ stack, [np.zeros((3, 3)), np.zeros(3)], axis=1)
+ assert_raises_regex(ValueError, 'must have the same shape',
+ stack, [np.arange(2), np.arange(3)])
+ # np.matrix
+ # stacking two 2x2 matrices is expected to raise the matrix shape error
+ # instead of returning a result
+ m = np.matrix([[1, 2], [3, 4]])
+ assert_raises_regex(ValueError, 'shape too large to be a matrix',
+ stack, [m, m])
+
+
if __name__ == "__main__":
run_module_suite()
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 2baf83830..5e1128761 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3706,7 +3706,7 @@ def insert(arr, obj, values, axis=None):
See Also
--------
append : Append elements at the end of an array.
- concatenate : Join a sequence of arrays together.
+ concatenate : Join a sequence of arrays along an existing axis.
delete : Delete elements from an array.
Notes
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index eb9aad6ad..113e10d90 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -404,7 +404,7 @@ class RClass(AxisConcatenator):
See Also
--------
- concatenate : Join a sequence of arrays together.
+ concatenate : Join a sequence of arrays along an existing axis.
c_ : Translates slice objects to concatenation along the second axis.
Examples
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 2d18c5bc8..47ac07eea 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -338,9 +338,10 @@ def dstack(tup):
See Also
--------
+ stack : Join a sequence of arrays along a new axis.
vstack : Stack along first axis.
hstack : Stack along second axis.
- concatenate : Join arrays.
+ concatenate : Join a sequence of arrays along an existing axis.
dsplit : Split array along third axis.
Notes
@@ -477,7 +478,8 @@ def split(ary,indices_or_sections,axis=0):
hsplit : Split array into multiple sub-arrays horizontally (column-wise).
vsplit : Split array into multiple sub-arrays vertically (row wise).
dsplit : Split array into multiple sub-arrays along the 3rd axis (depth).
- concatenate : Join arrays together.
+ concatenate : Join a sequence of arrays along an existing axis.
+ stack : Join a sequence of arrays along a new axis.
hstack : Stack arrays in sequence horizontally (column wise).
vstack : Stack arrays in sequence vertically (row wise).
dstack : Stack arrays in sequence depth wise (along third dimension).