summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJouke Witteveen <j.witteveen@gmail.com>2021-11-28 18:34:49 +0100
committerJouke Witteveen <j.witteveen@gmail.com>2022-05-10 18:19:21 +0200
commit4f1d95aa1fa0910b631e6aea91f5c2033593c11e (patch)
tree7f6b984305793db6d49b12917e9e192f154cf4da
parentff3a9dae0f4b6e539b1170a9d334dcefe862d28f (diff)
downloadnumpy-4f1d95aa1fa0910b631e6aea91f5c2033593c11e.tar.gz
ENH: Add compressed= argument to ma.ndenumerate
-rw-r--r--doc/release/upcoming_changes/20020.new_function.rst2
-rw-r--r--numpy/ma/extras.py34
-rw-r--r--numpy/ma/extras.pyi2
-rw-r--r--numpy/ma/tests/test_extras.py9
4 files changed, 39 insertions, 8 deletions
diff --git a/doc/release/upcoming_changes/20020.new_function.rst b/doc/release/upcoming_changes/20020.new_function.rst
index 135759a98..0f310ceac 100644
--- a/doc/release/upcoming_changes/20020.new_function.rst
+++ b/doc/release/upcoming_changes/20020.new_function.rst
@@ -1,4 +1,4 @@
`ndenumerate` specialization for masked arrays
----------------------------------------------
The masked array module now provides the `numpy.ma.ndenumerate` function,
-an alternative to `numpy.ndenumerate` that skips masked values.
+an alternative to `numpy.ndenumerate` that skips masked values by default.
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 5f9459b0a..127d4fda8 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -1520,18 +1520,26 @@ mr_ = mr_class()
#---- Find unmasked data ---
#####--------------------------------------------------------------------------
-def ndenumerate(a):
+def ndenumerate(a, compressed=True):
"""
Multidimensional index iterator.
- Return an iterator yielding pairs of array coordinates and values of
- elements that are not masked.
+ Return an iterator yielding pairs of array coordinates and values,
+ skipping elements that are masked. With `compressed=False`,
+ `ma.masked` is yielded as the value of masked elements. This
+ behavior differs from that of `numpy.ndenumerate`, which yields the
+ value of the underlying data array.
+
+ Notes
+ -----
.. versionadded:: 1.23.0
-
+
Parameters
----------
a : array_like
An array with (possibly) masked elements.
+ compressed : bool, optional
+ If True (default), masked elements are skipped.
See Also
--------
@@ -1560,10 +1568,24 @@ def ndenumerate(a):
(1, 1) 4
(2, 0) 6
(2, 2) 8
+
+ >>> for index, x in np.ma.ndenumerate(a, compressed=False):
+ ... print(index, x)
+ (0, 0) 0
+ (0, 1) 1
+ (0, 2) 2
+ (1, 0) --
+ (1, 1) 4
+ (1, 2) --
+ (2, 0) 6
+ (2, 1) --
+ (2, 2) 8
"""
- for it, masked in zip(np.ndenumerate(a), getmaskarray(a).flat):
- if not masked:
+ for it, mask in zip(np.ndenumerate(a), getmaskarray(a).flat):
+ if not mask:
yield it
+ elif not compressed:
+ yield it[0], masked
def flatnotmasked_edges(a):
diff --git a/numpy/ma/extras.pyi b/numpy/ma/extras.pyi
index 947c7ae99..5d3912ccc 100644
--- a/numpy/ma/extras.pyi
+++ b/numpy/ma/extras.pyi
@@ -74,7 +74,7 @@ class mr_class(MAxisConcatenator):
mr_: mr_class
-def ndenumerate(a): ...
+def ndenumerate(a, compressed=...): ...
def flatnotmasked_edges(a): ...
def notmasked_edges(a, axis=...): ...
def flatnotmasked_contiguous(a): ...
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index 5ea48ee51..74f744e09 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -1658,6 +1658,8 @@ class TestNDEnumerate:
list(ndenumerate(ordinary)))
assert_equal(list(ndenumerate(ordinary)),
list(ndenumerate(with_mask)))
+ assert_equal(list(ndenumerate(with_mask)),
+ list(ndenumerate(with_mask, compressed=False)))
def test_ndenumerate_allmasked(self):
a = masked_all(())
@@ -1665,7 +1667,11 @@ class TestNDEnumerate:
c = masked_all((2, 3, 4))
assert_equal(list(ndenumerate(a)), [])
assert_equal(list(ndenumerate(b)), [])
+ assert_equal(list(ndenumerate(b, compressed=False)),
+ list(zip(np.ndindex((100,)), 100 * [masked])))
assert_equal(list(ndenumerate(c)), [])
+ assert_equal(list(ndenumerate(c, compressed=False)),
+ list(zip(np.ndindex((2, 3, 4)), 2 * 3 * 4 * [masked])))
def test_ndenumerate_mixedmasked(self):
a = masked_array(np.arange(12).reshape((3, 4)),
@@ -1675,6 +1681,9 @@ class TestNDEnumerate:
items = [((1, 2), 6),
((2, 0), 8), ((2, 1), 9), ((2, 2), 10), ((2, 3), 11)]
assert_equal(list(ndenumerate(a)), items)
+ assert_equal(len(list(ndenumerate(a, compressed=False))), a.size)
+ for coordinate, value in ndenumerate(a, compressed=False):
+ assert_equal(a[coordinate], value)
class TestStack: