diff options
author | Jouke Witteveen <j.witteveen@gmail.com> | 2021-11-28 18:34:49 +0100 |
---|---|---|
committer | Jouke Witteveen <j.witteveen@gmail.com> | 2022-05-10 18:19:21 +0200 |
commit | 4f1d95aa1fa0910b631e6aea91f5c2033593c11e (patch) | |
tree | 7f6b984305793db6d49b12917e9e192f154cf4da | |
parent | ff3a9dae0f4b6e539b1170a9d334dcefe862d28f (diff) | |
download | numpy-4f1d95aa1fa0910b631e6aea91f5c2033593c11e.tar.gz |
ENH: Add compressed= argument to ma.ndenumerate
-rw-r--r-- | doc/release/upcoming_changes/20020.new_function.rst | 2 | ||||
-rw-r--r-- | numpy/ma/extras.py | 34 | ||||
-rw-r--r-- | numpy/ma/extras.pyi | 2 | ||||
-rw-r--r-- | numpy/ma/tests/test_extras.py | 9 |
4 files changed, 39 insertions, 8 deletions
diff --git a/doc/release/upcoming_changes/20020.new_function.rst b/doc/release/upcoming_changes/20020.new_function.rst index 135759a98..0f310ceac 100644 --- a/doc/release/upcoming_changes/20020.new_function.rst +++ b/doc/release/upcoming_changes/20020.new_function.rst @@ -1,4 +1,4 @@ `ndenumerate` specialization for masked arrays ---------------------------------------------- The masked array module now provides the `numpy.ma.ndenumerate` function, -an alternative to `numpy.ndenumerate` that skips masked values. +an alternative to `numpy.ndenumerate` that skips masked values by default. diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 5f9459b0a..127d4fda8 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -1520,18 +1520,26 @@ mr_ = mr_class() #---- Find unmasked data --- #####-------------------------------------------------------------------------- -def ndenumerate(a): +def ndenumerate(a, compressed=True): """ Multidimensional index iterator. - Return an iterator yielding pairs of array coordinates and values of - elements that are not masked. + Return an iterator yielding pairs of array coordinates and values, + skipping elements that are masked. With `compressed=False`, + `ma.masked` is yielded as the value of masked elements. This + behavior differs from that of `numpy.ndenumerate`, which yields the + value of the underlying data array. + + Notes + ----- .. versionadded:: 1.23.0 - + Parameters ---------- a : array_like An array with (possibly) masked elements. + compressed : bool, optional + If True (default), masked elements are skipped. See Also -------- @@ -1560,10 +1568,24 @@ def ndenumerate(a): (1, 1) 4 (2, 0) 6 (2, 2) 8 + + >>> for index, x in np.ma.ndenumerate(a, compressed=False): + ... print(index, x) + (0, 0) 0 + (0, 1) 1 + (0, 2) 2 + (1, 0) -- + (1, 1) 4 + (1, 2) -- + (2, 0) 6 + (2, 1) -- + (2, 2) 8 """ - for it, masked in zip(np.ndenumerate(a), getmaskarray(a).flat): - if not masked: + for it, mask in zip(np.ndenumerate(a), getmaskarray(a).flat): + if not mask: yield it + elif not compressed: + yield it[0], masked def flatnotmasked_edges(a): diff --git a/numpy/ma/extras.pyi b/numpy/ma/extras.pyi index 947c7ae99..5d3912ccc 100644 --- a/numpy/ma/extras.pyi +++ b/numpy/ma/extras.pyi @@ -74,7 +74,7 @@ class mr_class(MAxisConcatenator): mr_: mr_class -def ndenumerate(a): ... +def ndenumerate(a, compressed=...): ... def flatnotmasked_edges(a): ... def notmasked_edges(a, axis=...): ... def flatnotmasked_contiguous(a): ... diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py index 5ea48ee51..74f744e09 100644 --- a/numpy/ma/tests/test_extras.py +++ b/numpy/ma/tests/test_extras.py @@ -1658,6 +1658,8 @@ class TestNDEnumerate: list(ndenumerate(ordinary))) assert_equal(list(ndenumerate(ordinary)), list(ndenumerate(with_mask))) + assert_equal(list(ndenumerate(with_mask)), + list(ndenumerate(with_mask, compressed=False))) def test_ndenumerate_allmasked(self): a = masked_all(()) @@ -1665,7 +1667,11 @@ class TestNDEnumerate: c = masked_all((2, 3, 4)) assert_equal(list(ndenumerate(a)), []) assert_equal(list(ndenumerate(b)), []) + assert_equal(list(ndenumerate(b, compressed=False)), + list(zip(np.ndindex((100,)), 100 * [masked]))) assert_equal(list(ndenumerate(c)), []) + assert_equal(list(ndenumerate(c, compressed=False)), + list(zip(np.ndindex((2, 3, 4)), 2 * 3 * 4 * [masked]))) def test_ndenumerate_mixedmasked(self): a = masked_array(np.arange(12).reshape((3, 4)), @@ -1675,6 +1681,9 @@ class TestNDEnumerate: items = [((1, 2), 6), ((2, 0), 8), ((2, 1), 9), ((2, 2), 10), ((2, 3), 11)] assert_equal(list(ndenumerate(a)), items) + assert_equal(len(list(ndenumerate(a, compressed=False))), a.size) + for coordinate, value in ndenumerate(a, compressed=False): + assert_equal(a[coordinate], value) class TestStack: |