diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2022-10-05 19:09:56 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-05 19:09:56 +0200 |
commit | 02b68f16efaff733e77a12c67ea888dba6827720 (patch) | |
tree | aa7e5b339ba25234b3b163642f407e06fe782026 /numpy | |
parent | 110d75ddc86ad49b1fc9207b201c93eed431d22b (diff) | |
parent | 4213779a4a1108f5908a51dd96ed468057f2c1f2 (diff) | |
download | numpy-02b68f16efaff733e77a12c67ea888dba6827720.tar.gz |
Merge pull request #22046 from cmarmo/masked-invalid
BUG: Make `mask_invalid` consistent with `mask_where` if `copy` is set to `False`
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/ma/core.py | 14 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 12 |
2 files changed, 13 insertions, 13 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 313445be7..0ee630a97 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -2356,20 +2356,8 @@ def masked_invalid(a, copy=True): fill_value=1e+20) """ - a = np.array(a, copy=copy, subok=True) - mask = getattr(a, '_mask', None) - if mask is not None: - condition = ~(np.isfinite(getdata(a))) - if mask is not nomask: - condition |= mask - cls = type(a) - else: - condition = ~(np.isfinite(a)) - cls = MaskedArray - result = a.view(cls) - result._mask = condition - return result + return masked_where(~(np.isfinite(getdata(a))), a, copy=copy) ############################################################################### # Printing options # diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 3997f44e3..b1c080a56 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -4496,6 +4496,14 @@ class TestMaskedArrayFunctions: assert_equal(ma, expected) assert_equal(ma.mask, expected.mask) + def test_masked_invalid_error(self): + a = np.arange(5, dtype=object) + a[3] = np.PINF + a[2] = np.NaN + with pytest.raises(TypeError, + match="not supported for the input types"): + np.ma.masked_invalid(a) + def test_choose(self): # Test choose choices = [[0, 1, 2, 3], [10, 11, 12, 13], @@ -5309,6 +5317,10 @@ def test_masked_array_no_copy(): a = np.ma.array([1, 2, 3, 4], mask=[1, 0, 0, 0]) _ = np.ma.masked_where(a == 3, a, copy=False) assert_array_equal(a.mask, [True, False, True, False]) + # check masked array with masked_invalid is updated in place + a = np.ma.array([np.inf, 1, 2, 3, 4]) + _ = np.ma.masked_invalid(a, copy=False) + assert_array_equal(a.mask, [True, False, False, False, False]) def test_append_masked_array(): a = np.ma.masked_equal([1,2,3], value=2) |