summaryrefslogtreecommitdiff
path: root/numpy/ma
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-12-17 16:21:17 -0800
committerGitHub <noreply@github.com>2017-12-17 16:21:17 -0800
commitceef078370176c877a75b4c22c756952efc8cb7c (patch)
tree5609d48bc958d3950782be72e4e2a48e0c8fb294 /numpy/ma
parent9d054c141b084546810486f91e763f3eb89af633 (diff)
parent4e112ef679284d931d71cf8f4b20acf0f0254cf0 (diff)
downloadnumpy-ceef078370176c877a75b4c22c756952efc8cb7c.tar.gz
Merge pull request #10223 from lzkelley/ma-stack-10127
ENH: added masked version of 'numpy.stack' with tests.
Diffstat (limited to 'numpy/ma')
-rw-r--r--numpy/ma/extras.py3
-rw-r--r--numpy/ma/tests/test_extras.py84
2 files changed, 85 insertions, 2 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 323fbce38..360d50d8a 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -19,7 +19,7 @@ __all__ = [
'hsplit', 'hstack', 'isin', 'in1d', 'intersect1d', 'mask_cols', 'mask_rowcols',
'mask_rows', 'masked_all', 'masked_all_like', 'median', 'mr_',
'notmasked_contiguous', 'notmasked_edges', 'polyfit', 'row_stack',
- 'setdiff1d', 'setxor1d', 'unique', 'union1d', 'vander', 'vstack',
+ 'setdiff1d', 'setxor1d', 'stack', 'unique', 'union1d', 'vander', 'vstack',
]
import itertools
@@ -357,6 +357,7 @@ vstack = row_stack = _fromnxfunction_seq('vstack')
hstack = _fromnxfunction_seq('hstack')
column_stack = _fromnxfunction_seq('column_stack')
dstack = _fromnxfunction_seq('dstack')
+stack = _fromnxfunction_seq('stack')
hsplit = _fromnxfunction_single('hsplit')
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index af9f42c2a..7687514fa 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -29,7 +29,7 @@ from numpy.ma.extras import (
ediff1d, apply_over_axes, apply_along_axis, compress_nd, compress_rowcols,
mask_rowcols, clump_masked, clump_unmasked, flatnotmasked_contiguous,
notmasked_contiguous, notmasked_edges, masked_all, masked_all_like, isin,
- diagflat
+ diagflat, stack, vstack, hstack
)
import numpy.ma.extras as mae
@@ -1589,5 +1589,87 @@ class TestShapeBase(object):
assert_equal(b.mask.shape, b.data.shape)
+class TestStack(object):
+
+ def test_stack_1d(self):
+ a = masked_array([0, 1, 2], mask=[0, 1, 0])
+ b = masked_array([9, 8, 7], mask=[1, 0, 0])
+
+ c = stack([a, b], axis=0)
+ assert_equal(c.shape, (2, 3))
+ assert_array_equal(a.mask, c[0].mask)
+ assert_array_equal(b.mask, c[1].mask)
+
+ d = vstack([a, b])
+ assert_array_equal(c.data, d.data)
+ assert_array_equal(c.mask, d.mask)
+
+ c = stack([a, b], axis=1)
+ assert_equal(c.shape, (3, 2))
+ assert_array_equal(a.mask, c[:, 0].mask)
+ assert_array_equal(b.mask, c[:, 1].mask)
+
+ def test_stack_masks(self):
+ a = masked_array([0, 1, 2], mask=True)
+ b = masked_array([9, 8, 7], mask=False)
+
+ c = stack([a, b], axis=0)
+ assert_equal(c.shape, (2, 3))
+ assert_array_equal(a.mask, c[0].mask)
+ assert_array_equal(b.mask, c[1].mask)
+
+ d = vstack([a, b])
+ assert_array_equal(c.data, d.data)
+ assert_array_equal(c.mask, d.mask)
+
+ c = stack([a, b], axis=1)
+ assert_equal(c.shape, (3, 2))
+ assert_array_equal(a.mask, c[:, 0].mask)
+ assert_array_equal(b.mask, c[:, 1].mask)
+
+ def test_stack_nd(self):
+ # 2D
+ shp = (3, 2)
+ d1 = np.random.randint(0, 10, shp)
+ d2 = np.random.randint(0, 10, shp)
+ m1 = np.random.randint(0, 2, shp).astype(bool)
+ m2 = np.random.randint(0, 2, shp).astype(bool)
+ a1 = masked_array(d1, mask=m1)
+ a2 = masked_array(d2, mask=m2)
+
+ c = stack([a1, a2], axis=0)
+ c_shp = (2,) + shp
+ assert_equal(c.shape, c_shp)
+ assert_array_equal(a1.mask, c[0].mask)
+ assert_array_equal(a2.mask, c[1].mask)
+
+ c = stack([a1, a2], axis=-1)
+ c_shp = shp + (2,)
+ assert_equal(c.shape, c_shp)
+ assert_array_equal(a1.mask, c[..., 0].mask)
+ assert_array_equal(a2.mask, c[..., 1].mask)
+
+ # 4D
+ shp = (3, 2, 4, 5,)
+ d1 = np.random.randint(0, 10, shp)
+ d2 = np.random.randint(0, 10, shp)
+ m1 = np.random.randint(0, 2, shp).astype(bool)
+ m2 = np.random.randint(0, 2, shp).astype(bool)
+ a1 = masked_array(d1, mask=m1)
+ a2 = masked_array(d2, mask=m2)
+
+ c = stack([a1, a2], axis=0)
+ c_shp = (2,) + shp
+ assert_equal(c.shape, c_shp)
+ assert_array_equal(a1.mask, c[0].mask)
+ assert_array_equal(a2.mask, c[1].mask)
+
+ c = stack([a1, a2], axis=-1)
+ c_shp = shp + (2,)
+ assert_equal(c.shape, c_shp)
+ assert_array_equal(a1.mask, c[..., 0].mask)
+ assert_array_equal(a2.mask, c[..., 1].mask)
+
+
if __name__ == "__main__":
run_module_suite()