diff options
| author | Eric Wieser <wieser.eric@gmail.com> | 2017-12-17 16:21:17 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-12-17 16:21:17 -0800 |
| commit | ceef078370176c877a75b4c22c756952efc8cb7c (patch) | |
| tree | 5609d48bc958d3950782be72e4e2a48e0c8fb294 /numpy/ma | |
| parent | 9d054c141b084546810486f91e763f3eb89af633 (diff) | |
| parent | 4e112ef679284d931d71cf8f4b20acf0f0254cf0 (diff) | |
| download | numpy-ceef078370176c877a75b4c22c756952efc8cb7c.tar.gz | |
Merge pull request #10223 from lzkelley/ma-stack-10127
ENH: added masked version of 'numpy.stack' with tests.
Diffstat (limited to 'numpy/ma')
| -rw-r--r-- | numpy/ma/extras.py | 3 | ||||
| -rw-r--r-- | numpy/ma/tests/test_extras.py | 84 |
2 files changed, 85 insertions, 2 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 323fbce38..360d50d8a 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -19,7 +19,7 @@ __all__ = [ 'hsplit', 'hstack', 'isin', 'in1d', 'intersect1d', 'mask_cols', 'mask_rowcols', 'mask_rows', 'masked_all', 'masked_all_like', 'median', 'mr_', 'notmasked_contiguous', 'notmasked_edges', 'polyfit', 'row_stack', - 'setdiff1d', 'setxor1d', 'unique', 'union1d', 'vander', 'vstack', + 'setdiff1d', 'setxor1d', 'stack', 'unique', 'union1d', 'vander', 'vstack', ] import itertools @@ -357,6 +357,7 @@ vstack = row_stack = _fromnxfunction_seq('vstack') hstack = _fromnxfunction_seq('hstack') column_stack = _fromnxfunction_seq('column_stack') dstack = _fromnxfunction_seq('dstack') +stack = _fromnxfunction_seq('stack') hsplit = _fromnxfunction_single('hsplit') diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py index af9f42c2a..7687514fa 100644 --- a/numpy/ma/tests/test_extras.py +++ b/numpy/ma/tests/test_extras.py @@ -29,7 +29,7 @@ from numpy.ma.extras import ( ediff1d, apply_over_axes, apply_along_axis, compress_nd, compress_rowcols, mask_rowcols, clump_masked, clump_unmasked, flatnotmasked_contiguous, notmasked_contiguous, notmasked_edges, masked_all, masked_all_like, isin, - diagflat + diagflat, stack, vstack, hstack ) import numpy.ma.extras as mae @@ -1589,5 +1589,87 @@ class TestShapeBase(object): assert_equal(b.mask.shape, b.data.shape) +class TestStack(object): + + def test_stack_1d(self): + a = masked_array([0, 1, 2], mask=[0, 1, 0]) + b = masked_array([9, 8, 7], mask=[1, 0, 0]) + + c = stack([a, b], axis=0) + assert_equal(c.shape, (2, 3)) + assert_array_equal(a.mask, c[0].mask) + assert_array_equal(b.mask, c[1].mask) + + d = vstack([a, b]) + assert_array_equal(c.data, d.data) + assert_array_equal(c.mask, d.mask) + + c = stack([a, b], axis=1) + assert_equal(c.shape, (3, 2)) + assert_array_equal(a.mask, c[:, 0].mask) + assert_array_equal(b.mask, c[:, 1].mask) + + def test_stack_masks(self): + a = masked_array([0, 1, 2], mask=True) + b = masked_array([9, 8, 7], mask=False) + + c = stack([a, b], axis=0) + assert_equal(c.shape, (2, 3)) + assert_array_equal(a.mask, c[0].mask) + assert_array_equal(b.mask, c[1].mask) + + d = vstack([a, b]) + assert_array_equal(c.data, d.data) + assert_array_equal(c.mask, d.mask) + + c = stack([a, b], axis=1) + assert_equal(c.shape, (3, 2)) + assert_array_equal(a.mask, c[:, 0].mask) + assert_array_equal(b.mask, c[:, 1].mask) + + def test_stack_nd(self): + # 2D + shp = (3, 2) + d1 = np.random.randint(0, 10, shp) + d2 = np.random.randint(0, 10, shp) + m1 = np.random.randint(0, 2, shp).astype(bool) + m2 = np.random.randint(0, 2, shp).astype(bool) + a1 = masked_array(d1, mask=m1) + a2 = masked_array(d2, mask=m2) + + c = stack([a1, a2], axis=0) + c_shp = (2,) + shp + assert_equal(c.shape, c_shp) + assert_array_equal(a1.mask, c[0].mask) + assert_array_equal(a2.mask, c[1].mask) + + c = stack([a1, a2], axis=-1) + c_shp = shp + (2,) + assert_equal(c.shape, c_shp) + assert_array_equal(a1.mask, c[..., 0].mask) + assert_array_equal(a2.mask, c[..., 1].mask) + + # 4D + shp = (3, 2, 4, 5,) + d1 = np.random.randint(0, 10, shp) + d2 = np.random.randint(0, 10, shp) + m1 = np.random.randint(0, 2, shp).astype(bool) + m2 = np.random.randint(0, 2, shp).astype(bool) + a1 = masked_array(d1, mask=m1) + a2 = masked_array(d2, mask=m2) + + c = stack([a1, a2], axis=0) + c_shp = (2,) + shp + assert_equal(c.shape, c_shp) + assert_array_equal(a1.mask, c[0].mask) + assert_array_equal(a2.mask, c[1].mask) + + c = stack([a1, a2], axis=-1) + c_shp = shp + (2,) + assert_equal(c.shape, c_shp) + assert_array_equal(a1.mask, c[..., 0].mask) + assert_array_equal(a2.mask, c[..., 1].mask) + + if __name__ == "__main__": run_module_suite() |
