diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2018-10-26 09:35:23 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-10-26 09:35:23 -0500 |
commit | 3debe9772ea1b68d997dba3440929a467ad11c52 (patch) | |
tree | e9eaef760e19f86622cafa19ff64a0a46a60fbc3 | |
parent | 763be37f77b572f46637db4032cebfd0f96175d9 (diff) | |
parent | 173f65c02c8d654fc55f381665ec43c64d43d636 (diff) | |
download | numpy-3debe9772ea1b68d997dba3440929a467ad11c52.tar.gz |
Merge pull request #12257 from charris/fix-ma-compare-fill_value
BUG: Fix fill value in masked array '==' and '!=' ops.
-rw-r--r-- | numpy/ma/core.py | 23 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 150 |
2 files changed, 163 insertions, 10 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 9259aa4c7..9ee44e9ff 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -446,6 +446,7 @@ def _check_fill_value(fill_value, ndtype): If fill_value is not None, its value is forced to the given dtype. The result is always a 0d array. + """ ndtype = np.dtype(ndtype) fields = ndtype.fields @@ -465,17 +466,19 @@ def _check_fill_value(fill_value, ndtype): dtype=ndtype) else: if isinstance(fill_value, basestring) and (ndtype.char not in 'OSVU'): + # Note this check doesn't work if fill_value is not a scalar err_msg = "Cannot set fill value of string with array of dtype %s" raise TypeError(err_msg % ndtype) else: # In case we want to convert 1e20 to int. + # Also in case of converting string arrays. try: fill_value = np.array(fill_value, copy=False, dtype=ndtype) - except OverflowError: - # Raise TypeError instead of OverflowError. OverflowError - # is seldom used, and the real problem here is that the - # passed fill_value is not compatible with the ndtype. - err_msg = "Fill value %s overflows dtype %s" + except (OverflowError, ValueError): + # Raise TypeError instead of OverflowError or ValueError. + # OverflowError is seldom used, and the real problem here is + # that the passed fill_value is not compatible with the ndtype. + err_msg = "Cannot convert fill_value %s to dtype %s" raise TypeError(err_msg % (fill_value, ndtype)) return np.array(fill_value) @@ -4014,6 +4017,16 @@ class MaskedArray(ndarray): check = check.view(type(self)) check._update_from(self) check._mask = mask + + # Cast fill value to bool_ if needed. If it cannot be cast, the + # default boolean fill value is used. + if check._fill_value is not None: + try: + fill = _check_fill_value(check._fill_value, np.bool_) + except (TypeError, ValueError): + fill = _check_fill_value(None, np.bool_) + check._fill_value = fill + return check def __eq__(self, other): diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 50ab6e1de..8a015e609 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -60,6 +60,11 @@ suppress_copy_mask_on_assignment.filter( "setting an item on a masked array which has a shared mask will not copy") +# For parametrized numeric testing +num_dts = [np.dtype(dt_) for dt_ in '?bhilqBHILQefdgFD'] +num_ids = [dt_.char for dt_ in num_dts] + + class TestMaskedArray(object): # Base test class for MaskedArrays. @@ -1415,23 +1420,34 @@ class TestMaskedArrayArithmetic(object): # Test the equality of structured arrays ndtype = [('A', int), ('B', int)] a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype) + test = (a == a) assert_equal(test.data, [True, True]) assert_equal(test.mask, [False, False]) + assert_(test.fill_value == True) + test = (a == a[0]) assert_equal(test.data, [True, False]) assert_equal(test.mask, [False, False]) + assert_(test.fill_value == True) + b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype) test = (a == b) assert_equal(test.data, [False, True]) assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + test = (a[0] == b) assert_equal(test.data, [False, False]) assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype) test = (a == b) assert_equal(test.data, [True, True]) assert_equal(test.mask, [False, False]) + assert_(test.fill_value == True) + # complicated dtype, 2-dimensional array. ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])] a = array([[(1, (1, 1)), (2, (2, 2))], @@ -1441,28 +1457,40 @@ class TestMaskedArrayArithmetic(object): test = (a[0, 0] == a) assert_equal(test.data, [[True, False], [False, False]]) assert_equal(test.mask, [[False, False], [False, True]]) + assert_(test.fill_value == True) def test_ne_on_structured(self): # Test the equality of structured arrays ndtype = [('A', int), ('B', int)] a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype) + test = (a != a) assert_equal(test.data, [False, False]) assert_equal(test.mask, [False, False]) + assert_(test.fill_value == True) + test = (a != a[0]) assert_equal(test.data, [False, True]) assert_equal(test.mask, [False, False]) + assert_(test.fill_value == True) + b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype) test = (a != b) assert_equal(test.data, [True, False]) assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + test = (a[0] != b) assert_equal(test.data, [True, True]) assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype) test = (a != b) assert_equal(test.data, [False, False]) assert_equal(test.mask, [False, False]) + assert_(test.fill_value == True) + # complicated dtype, 2-dimensional array. ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])] a = array([[(1, (1, 1)), (2, (2, 2))], @@ -1472,6 +1500,7 @@ class TestMaskedArrayArithmetic(object): test = (a[0, 0] != a) assert_equal(test.data, [[False, True], [True, True]]) assert_equal(test.mask, [[False, False], [False, True]]) + assert_(test.fill_value == True) def test_eq_ne_structured_extra(self): # ensure simple examples are symmetric and make sense. @@ -1507,6 +1536,120 @@ class TestMaskedArrayArithmetic(object): el_by_el = [m1[name] != m2[name] for name in dt.names] assert_equal(array(el_by_el, dtype=bool).any(), ne_expected) + @pytest.mark.parametrize('dt', ['S', 'U']) + @pytest.mark.parametrize('fill', [None, 'A']) + def test_eq_for_strings(self, dt, fill): + # Test the equality of structured arrays + a = array(['a', 'b'], dtype=dt, mask=[0, 1], fill_value=fill) + + test = (a == a) + assert_equal(test.data, [True, True]) + assert_equal(test.mask, [False, True]) + assert_(test.fill_value == True) + + test = (a == a[0]) + assert_equal(test.data, [True, False]) + assert_equal(test.mask, [False, True]) + assert_(test.fill_value == True) + + b = array(['a', 'b'], dtype=dt, mask=[1, 0], fill_value=fill) + test = (a == b) + assert_equal(test.data, [False, False]) + assert_equal(test.mask, [True, True]) + assert_(test.fill_value == True) + + # test = (a[0] == b) # doesn't work in Python2 + test = (b == a[0]) + assert_equal(test.data, [False, False]) + assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + + @pytest.mark.parametrize('dt', ['S', 'U']) + @pytest.mark.parametrize('fill', [None, 'A']) + def test_ne_for_strings(self, dt, fill): + # Test the equality of structured arrays + a = array(['a', 'b'], dtype=dt, mask=[0, 1], fill_value=fill) + + test = (a != a) + assert_equal(test.data, [False, False]) + assert_equal(test.mask, [False, True]) + assert_(test.fill_value == True) + + test = (a != a[0]) + assert_equal(test.data, [False, True]) + assert_equal(test.mask, [False, True]) + assert_(test.fill_value == True) + + b = array(['a', 'b'], dtype=dt, mask=[1, 0], fill_value=fill) + test = (a != b) + assert_equal(test.data, [True, True]) + assert_equal(test.mask, [True, True]) + assert_(test.fill_value == True) + + # test = (a[0] != b) # doesn't work in Python2 + test = (b != a[0]) + assert_equal(test.data, [True, True]) + assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + + @pytest.mark.parametrize('dt1', num_dts, ids=num_ids) + @pytest.mark.parametrize('dt2', num_dts, ids=num_ids) + @pytest.mark.parametrize('fill', [None, 1]) + def test_eq_for_numeric(self, dt1, dt2, fill): + # Test the equality of structured arrays + a = array([0, 1], dtype=dt1, mask=[0, 1], fill_value=fill) + + test = (a == a) + assert_equal(test.data, [True, True]) + assert_equal(test.mask, [False, True]) + assert_(test.fill_value == True) + + test = (a == a[0]) + assert_equal(test.data, [True, False]) + assert_equal(test.mask, [False, True]) + assert_(test.fill_value == True) + + b = array([0, 1], dtype=dt2, mask=[1, 0], fill_value=fill) + test = (a == b) + assert_equal(test.data, [False, False]) + assert_equal(test.mask, [True, True]) + assert_(test.fill_value == True) + + # test = (a[0] == b) # doesn't work in Python2 + test = (b == a[0]) + assert_equal(test.data, [False, False]) + assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + + @pytest.mark.parametrize('dt1', num_dts, ids=num_ids) + @pytest.mark.parametrize('dt2', num_dts, ids=num_ids) + @pytest.mark.parametrize('fill', [None, 1]) + def test_ne_for_numeric(self, dt1, dt2, fill): + # Test the equality of structured arrays + a = array([0, 1], dtype=dt1, mask=[0, 1], fill_value=fill) + + test = (a != a) + assert_equal(test.data, [False, False]) + assert_equal(test.mask, [False, True]) + assert_(test.fill_value == True) + + test = (a != a[0]) + assert_equal(test.data, [False, True]) + assert_equal(test.mask, [False, True]) + assert_(test.fill_value == True) + + b = array([0, 1], dtype=dt2, mask=[1, 0], fill_value=fill) + test = (a != b) + assert_equal(test.data, [True, True]) + assert_equal(test.mask, [True, True]) + assert_(test.fill_value == True) + + # test = (a[0] != b) # doesn't work in Python2 + test = (b != a[0]) + assert_equal(test.data, [True, True]) + assert_equal(test.mask, [True, False]) + assert_(test.fill_value == True) + def test_eq_with_None(self): # Really, comparisons with None should not be done, but check them # anyway. Note that pep8 will flag these tests. @@ -5017,11 +5160,8 @@ def test_astype_mask_ordering(): assert_(x_f2.mask.flags.f_contiguous) -dts = [np.dtype(dt_) for dt_ in '?bhilqBHILQefdgFD'] -ids = [dt_.char for dt_ in dts] - -@pytest.mark.parametrize('dt1', dts, ids=ids) -@pytest.mark.parametrize('dt2', dts, ids=ids) +@pytest.mark.parametrize('dt1', num_dts, ids=num_ids) +@pytest.mark.parametrize('dt2', num_dts, ids=num_ids) @pytest.mark.filterwarnings('ignore::numpy.ComplexWarning') def test_astype_basic(dt1, dt2): # See gh-12070 |