summaryrefslogtreecommitdiff
path: root/numpy/ma/tests/test_extras.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma/tests/test_extras.py')
-rw-r--r--numpy/ma/tests/test_extras.py40
1 files changed, 39 insertions, 1 deletions
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index 01a47bef8..05403344b 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -28,7 +28,7 @@ from numpy.ma.extras import (
ediff1d, apply_over_axes, apply_along_axis, compress_nd, compress_rowcols,
mask_rowcols, clump_masked, clump_unmasked, flatnotmasked_contiguous,
notmasked_contiguous, notmasked_edges, masked_all, masked_all_like, isin,
- diagflat, stack, vstack
+ diagflat, ndenumerate, stack, vstack
)
@@ -1671,6 +1671,44 @@ class TestShapeBase:
assert_equal(b.mask.shape, b.data.shape)
+class TestNDEnumerate:
+
+ def test_ndenumerate_nomasked(self):
+ ordinary = np.ndarray(6).reshape((1, 3, 2))
+ empty_mask = np.zeros_like(ordinary, dtype=bool)
+ with_mask = masked_array(ordinary, mask=empty_mask)
+ assert_equal(list(np.ndenumerate(ordinary)),
+ list(ndenumerate(ordinary)))
+ assert_equal(list(ndenumerate(ordinary)),
+ list(ndenumerate(with_mask)))
+ assert_equal(list(ndenumerate(with_mask)),
+ list(ndenumerate(with_mask, compressed=False)))
+
+ def test_ndenumerate_allmasked(self):
+ a = masked_all(())
+ b = masked_all((100,))
+ c = masked_all((2, 3, 4))
+ assert_equal(list(ndenumerate(a)), [])
+ assert_equal(list(ndenumerate(b)), [])
+ assert_equal(list(ndenumerate(b, compressed=False)),
+ list(zip(np.ndindex((100,)), 100 * [masked])))
+ assert_equal(list(ndenumerate(c)), [])
+ assert_equal(list(ndenumerate(c, compressed=False)),
+ list(zip(np.ndindex((2, 3, 4)), 2 * 3 * 4 * [masked])))
+
+ def test_ndenumerate_mixedmasked(self):
+ a = masked_array(np.arange(12).reshape((3, 4)),
+ mask=[[1, 1, 1, 1],
+ [1, 1, 0, 1],
+ [0, 0, 0, 0]])
+ items = [((1, 2), 6),
+ ((2, 0), 8), ((2, 1), 9), ((2, 2), 10), ((2, 3), 11)]
+ assert_equal(list(ndenumerate(a)), items)
+ assert_equal(len(list(ndenumerate(a, compressed=False))), a.size)
+ for coordinate, value in ndenumerate(a, compressed=False):
+ assert_equal(a[coordinate], value)
+
+
class TestStack:
def test_stack_1d(self):