diff options
author | Stephan Hoyer <shoyer@climate.com> | 2015-02-25 01:49:26 -0800 |
---|---|---|
committer | Stephan Hoyer <shoyer@climate.com> | 2015-05-11 21:18:24 -0700 |
commit | 93d3b8dedc5cd602c867a234f07188fe5bd5479b (patch) | |
tree | cd79af4bf4e90af702d724aeaa51c1484741219c | |
parent | 2e016ac65aceab4e08217794d6be7b365793976a (diff) | |
download | numpy-93d3b8dedc5cd602c867a234f07188fe5bd5479b.tar.gz |
ENH: add np.stack
The motivation here is to present a uniform and N-dimensional interface for
joining arrays along a new axis, similarly to how `concatenate` provides a
uniform and N-dimensional interface for joining arrays along an existing axis.
Background
~~~~~~~~~~
Currently, users can choose between `hstack`, `vstack`, `column_stack` and
`dstack`, but none of these functions handle N-dimensional input. In my
opinion, it's also difficult to keep track of the differences between these
methods and to predict how they will handle input with different
dimensions.
In the past, my preferred approach has been to either construct the result
array explicitly and use indexing for assignment, to or use `np.array` to
stack along the first dimension and then use `transpose` (or a similar method)
to reorder dimensions if necessary. This is pretty awkward.
I brought this proposal up a few weeks on the numpy-discussion list:
http://mail.scipy.org/pipermail/numpy-discussion/2015-February/072199.html
I also received positive feedback on Twitter:
https://twitter.com/shoyer/status/565937244599377920
Implementation notes
~~~~~~~~~~~~~~~~~~~~
The one line summaries for `concatenate` and `stack` have been (re)written to
mirror each other, and to make clear that the distinction between these functions
is whether they join over an existing or new axis.
In general, I've tweaked the documentation and docstrings with an eye toward
pointing users to `concatenate`/`stack`/`split` as a fundamental set of basic
array manipulation routines, and away from
`array_split`/`{h,v,d}split`/`{h,v,d,column_}stack`
I put this implementation in `numpy.core.shape_base` alongside `hstack`/`vstack`,
but it appears that there is also a `numpy.lib.shape_base` module that contains
another larger set of functions, including `dstack`. I'm not really sure where
this belongs (or if it even matters).
Finally, it might be a good idea to write a masked array version of `stack`.
But I don't use masked arrays, so I'm not well motivated to do that.
-rw-r--r-- | doc/release/1.10.0-notes.rst | 3 | ||||
-rw-r--r-- | doc/source/reference/routines.array-manipulation.rst | 5 | ||||
-rw-r--r-- | numpy/add_newdocs.py | 3 | ||||
-rw-r--r-- | numpy/core/shape_base.py | 79 | ||||
-rw-r--r-- | numpy/core/tests/test_shape_base.py | 55 | ||||
-rw-r--r-- | numpy/lib/function_base.py | 2 | ||||
-rw-r--r-- | numpy/lib/index_tricks.py | 2 | ||||
-rw-r--r-- | numpy/lib/shape_base.py | 6 |
8 files changed, 143 insertions, 12 deletions
diff --git a/doc/release/1.10.0-notes.rst b/doc/release/1.10.0-notes.rst index f9d202ec3..c29aced4a 100644 --- a/doc/release/1.10.0-notes.rst +++ b/doc/release/1.10.0-notes.rst @@ -11,6 +11,9 @@ Highlights * Addition of *np.linalg.multi_dot*: compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order. +* The new function `np.stack` provides a general interface for joining a + sequence of arrays along a new axis, complementing `np.concatenate` for + joining along an existing axis. * Addition of `nanprod` to the set of nanfunctions. diff --git a/doc/source/reference/routines.array-manipulation.rst b/doc/source/reference/routines.array-manipulation.rst index 81af0a315..37b82d7db 100644 --- a/doc/source/reference/routines.array-manipulation.rst +++ b/doc/source/reference/routines.array-manipulation.rst @@ -64,8 +64,9 @@ Joining arrays .. autosummary:: :toctree: generated/ - column_stack concatenate + stack + column_stack dstack hstack vstack @@ -75,10 +76,10 @@ Splitting arrays .. autosummary:: :toctree: generated/ + split array_split dsplit hsplit - split vsplit Tiling arrays diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index 7dd8c5649..0333dd5a4 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -1142,7 +1142,7 @@ add_newdoc('numpy.core.multiarray', 'concatenate', """ concatenate((a1, a2, ...), axis=0) - Join a sequence of arrays together. + Join a sequence of arrays along an existing axis. Parameters ---------- @@ -1166,6 +1166,7 @@ add_newdoc('numpy.core.multiarray', 'concatenate', hsplit : Split array into multiple sub-arrays horizontally (column wise) vsplit : Split array into multiple sub-arrays vertically (row wise) dsplit : Split array into multiple sub-arrays along the 3rd axis (depth). + stack : Stack a sequence of arrays along a new axis. hstack : Stack arrays in sequence horizontally (column wise) vstack : Stack arrays in sequence vertically (row wise) dstack : Stack arrays in sequence depth wise (along third dimension) diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index ae684fb42..3259f3b1d 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -1,6 +1,7 @@ from __future__ import division, absolute_import, print_function -__all__ = ['atleast_1d', 'atleast_2d', 'atleast_3d', 'vstack', 'hstack'] +__all__ = ['atleast_1d', 'atleast_2d', 'atleast_3d', 'vstack', 'hstack', + 'stack'] from . import numeric as _nx from .numeric import array, asanyarray, newaxis @@ -196,9 +197,10 @@ def vstack(tup): See Also -------- + stack : Join a sequence of arrays along a new axis. hstack : Stack arrays in sequence horizontally (column wise). dstack : Stack arrays in sequence depth wise (along third dimension). - concatenate : Join a sequence of arrays together. + concatenate : Join a sequence of arrays along an existing axis. vsplit : Split array into a list of multiple sub-arrays vertically. Notes @@ -246,9 +248,10 @@ def hstack(tup): See Also -------- + stack : Join a sequence of arrays along a new axis. vstack : Stack arrays in sequence vertically (row wise). dstack : Stack arrays in sequence depth wise (along third axis). - concatenate : Join a sequence of arrays together. + concatenate : Join a sequence of arrays along an existing axis. hsplit : Split array along second axis. Notes @@ -275,3 +278,73 @@ def hstack(tup): return _nx.concatenate(arrs, 0) else: return _nx.concatenate(arrs, 1) + +def stack(arrays, axis=0): + """ + Join a sequence of arrays along a new axis. + + The `axis` parameter specifies the index of the new axis in the dimensions + of the result. For example, if ``axis=0`` it will be the first dimension + and if ``axis=-1`` it will be the last dimension. + + .. versionadded:: 1.10.0 + + Parameters + ---------- + arrays : sequence of array_like + Each array must have the same shape. + axis : int, optional + The axis in the result array along which the input arrays are stacked. + + Returns + ------- + stacked : ndarray + The stacked array has one more dimension than the input arrays. + + See Also + -------- + concatenate : Join a sequence of arrays along an existing axis. + split : Split array into a list of multiple sub-arrays of equal size. + + Examples + -------- + >>> arrays = [np.random.randn(3, 4) for _ in range(10)] + >>> np.stack(arrays, axis=0).shape + (10, 3, 4) + + >>> np.stack(arrays, axis=1).shape + (3, 10, 4) + + >>> np.stack(arrays, axis=2).shape + (3, 4, 10) + + >>> a = np.array([1, 2, 3]) + >>> b = np.array([2, 3, 4]) + >>> np.stack((a, b)) + array([[1, 2, 3], + [2, 3, 4]]) + + >>> np.stack((a, b), axis=-1) + array([[1, 2], + [2, 3], + [3, 4]]) + + """ + arrays = [asanyarray(arr) for arr in arrays] + if not arrays: + raise ValueError('need at least one array to stack') + + shapes = set(arr.shape for arr in arrays) + if len(shapes) != 1: + raise ValueError('all input arrays must have the same shape') + + result_ndim = arrays[0].ndim + 1 + if not -result_ndim <= axis < result_ndim: + msg = 'axis {0} out of bounds [-{1}, {1})'.format(axis, result_ndim) + raise IndexError(msg) + if axis < 0: + axis += result_ndim + + sl = (slice(None),) * axis + (_nx.newaxis,) + expanded_arrays = [arr[sl] for arr in arrays] + return _nx.concatenate(expanded_arrays, axis=axis) diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py index f1f5311c9..c6399bb07 100644 --- a/numpy/core/tests/test_shape_base.py +++ b/numpy/core/tests/test_shape_base.py @@ -3,9 +3,9 @@ from __future__ import division, absolute_import, print_function import warnings import numpy as np from numpy.testing import (TestCase, assert_, assert_raises, assert_array_equal, - assert_equal, run_module_suite) + assert_equal, run_module_suite, assert_raises_regex) from numpy.core import (array, arange, atleast_1d, atleast_2d, atleast_3d, - vstack, hstack, newaxis, concatenate) + vstack, hstack, newaxis, concatenate, stack) from numpy.compat import long class TestAtleast1d(TestCase): @@ -246,5 +246,56 @@ def test_concatenate_sloppy0(): assert_raises(DeprecationWarning, concatenate, (r4, r3), 10) +def test_stack(): + # 0d input + for input_ in [(1, 2, 3), + [np.int32(1), np.int32(2), np.int32(3)], + [np.array(1), np.array(2), np.array(3)]]: + assert_array_equal(stack(input_), [1, 2, 3]) + # 1d input examples + a = np.array([1, 2, 3]) + b = np.array([4, 5, 6]) + r1 = array([[1, 2, 3], [4, 5, 6]]) + assert_array_equal(np.stack((a, b)), r1) + assert_array_equal(np.stack((a, b), axis=1), r1.T) + # all input types + assert_array_equal(np.stack(list([a, b])), r1) + assert_array_equal(np.stack(array([a, b])), r1) + # all shapes for 1d input + arrays = [np.random.randn(3) for _ in range(10)] + axes = [0, 1, -1, -2] + expected_shapes = [(10, 3), (3, 10), (3, 10), (10, 3)] + for axis, expected_shape in zip(axes, expected_shapes): + assert_equal(np.stack(arrays, axis).shape, expected_shape) + assert_raises_regex(IndexError, 'out of bounds', stack, arrays, axis=2) + assert_raises_regex(IndexError, 'out of bounds', stack, arrays, axis=-3) + # all shapes for 2d input + arrays = [np.random.randn(3, 4) for _ in range(10)] + axes = [0, 1, 2, -1, -2, -3] + expected_shapes = [(10, 3, 4), (3, 10, 4), (3, 4, 10), + (3, 4, 10), (3, 10, 4), (10, 3, 4)] + for axis, expected_shape in zip(axes, expected_shapes): + assert_equal(np.stack(arrays, axis).shape, expected_shape) + # empty arrays + assert stack([[], [], []]).shape == (3, 0) + assert stack([[], [], []], axis=1).shape == (0, 3) + # edge cases + assert_raises_regex(ValueError, 'need at least one array', stack, []) + assert_raises_regex(ValueError, 'must have the same shape', + stack, [1, np.arange(3)]) + assert_raises_regex(ValueError, 'must have the same shape', + stack, [np.arange(3), 1]) + assert_raises_regex(ValueError, 'must have the same shape', + stack, [np.arange(3), 1], axis=1) + assert_raises_regex(ValueError, 'must have the same shape', + stack, [np.zeros((3, 3)), np.zeros(3)], axis=1) + assert_raises_regex(ValueError, 'must have the same shape', + stack, [np.arange(2), np.arange(3)]) + # np.matrix + m = np.matrix([[1, 2], [3, 4]]) + assert_raises_regex(ValueError, 'shape too large to be a matrix', + stack, [m, m]) + + if __name__ == "__main__": run_module_suite() diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 2baf83830..5e1128761 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -3706,7 +3706,7 @@ def insert(arr, obj, values, axis=None): See Also -------- append : Append elements at the end of an array. - concatenate : Join a sequence of arrays together. + concatenate : Join a sequence of arrays along an existing axis. delete : Delete elements from an array. Notes diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index eb9aad6ad..113e10d90 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -404,7 +404,7 @@ class RClass(AxisConcatenator): See Also -------- - concatenate : Join a sequence of arrays together. + concatenate : Join a sequence of arrays along an existing axis. c_ : Translates slice objects to concatenation along the second axis. Examples diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 2d18c5bc8..47ac07eea 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -338,9 +338,10 @@ def dstack(tup): See Also -------- + stack : Join a sequence of arrays along a new axis. vstack : Stack along first axis. hstack : Stack along second axis. - concatenate : Join arrays. + concatenate : Join a sequence of arrays along an existing axis. dsplit : Split array along third axis. Notes @@ -477,7 +478,8 @@ def split(ary,indices_or_sections,axis=0): hsplit : Split array into multiple sub-arrays horizontally (column-wise). vsplit : Split array into multiple sub-arrays vertically (row wise). dsplit : Split array into multiple sub-arrays along the 3rd axis (depth). - concatenate : Join arrays together. + concatenate : Join a sequence of arrays along an existing axis. + stack : Join a sequence of arrays along a new axis. hstack : Stack arrays in sequence horizontally (column wise). vstack : Stack arrays in sequence vertically (row wise). dstack : Stack arrays in sequence depth wise (along third dimension). |