summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2018-10-26 09:35:23 -0500
committerGitHub <noreply@github.com>2018-10-26 09:35:23 -0500
commit3debe9772ea1b68d997dba3440929a467ad11c52 (patch)
treee9eaef760e19f86622cafa19ff64a0a46a60fbc3 /numpy
parent763be37f77b572f46637db4032cebfd0f96175d9 (diff)
parent173f65c02c8d654fc55f381665ec43c64d43d636 (diff)
downloadnumpy-3debe9772ea1b68d997dba3440929a467ad11c52.tar.gz
Merge pull request #12257 from charris/fix-ma-compare-fill_value
BUG: Fix fill value in masked array '==' and '!=' ops.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py23
-rw-r--r--numpy/ma/tests/test_core.py150
2 files changed, 163 insertions, 10 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 9259aa4c7..9ee44e9ff 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -446,6 +446,7 @@ def _check_fill_value(fill_value, ndtype):
If fill_value is not None, its value is forced to the given dtype.
The result is always a 0d array.
+
"""
ndtype = np.dtype(ndtype)
fields = ndtype.fields
@@ -465,17 +466,19 @@ def _check_fill_value(fill_value, ndtype):
dtype=ndtype)
else:
if isinstance(fill_value, basestring) and (ndtype.char not in 'OSVU'):
+ # Note this check doesn't work if fill_value is not a scalar
err_msg = "Cannot set fill value of string with array of dtype %s"
raise TypeError(err_msg % ndtype)
else:
# In case we want to convert 1e20 to int.
+ # Also in case of converting string arrays.
try:
fill_value = np.array(fill_value, copy=False, dtype=ndtype)
- except OverflowError:
- # Raise TypeError instead of OverflowError. OverflowError
- # is seldom used, and the real problem here is that the
- # passed fill_value is not compatible with the ndtype.
- err_msg = "Fill value %s overflows dtype %s"
+ except (OverflowError, ValueError):
+ # Raise TypeError instead of OverflowError or ValueError.
+ # OverflowError is seldom used, and the real problem here is
+ # that the passed fill_value is not compatible with the ndtype.
+ err_msg = "Cannot convert fill_value %s to dtype %s"
raise TypeError(err_msg % (fill_value, ndtype))
return np.array(fill_value)
@@ -4014,6 +4017,16 @@ class MaskedArray(ndarray):
check = check.view(type(self))
check._update_from(self)
check._mask = mask
+
+ # Cast fill value to bool_ if needed. If it cannot be cast, the
+ # default boolean fill value is used.
+ if check._fill_value is not None:
+ try:
+ fill = _check_fill_value(check._fill_value, np.bool_)
+ except (TypeError, ValueError):
+ fill = _check_fill_value(None, np.bool_)
+ check._fill_value = fill
+
return check
def __eq__(self, other):
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 50ab6e1de..8a015e609 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -60,6 +60,11 @@ suppress_copy_mask_on_assignment.filter(
"setting an item on a masked array which has a shared mask will not copy")
+# For parametrized numeric testing
+num_dts = [np.dtype(dt_) for dt_ in '?bhilqBHILQefdgFD']
+num_ids = [dt_.char for dt_ in num_dts]
+
+
class TestMaskedArray(object):
# Base test class for MaskedArrays.
@@ -1415,23 +1420,34 @@ class TestMaskedArrayArithmetic(object):
# Test the equality of structured arrays
ndtype = [('A', int), ('B', int)]
a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype)
+
test = (a == a)
assert_equal(test.data, [True, True])
assert_equal(test.mask, [False, False])
+ assert_(test.fill_value == True)
+
test = (a == a[0])
assert_equal(test.data, [True, False])
assert_equal(test.mask, [False, False])
+ assert_(test.fill_value == True)
+
b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype)
test = (a == b)
assert_equal(test.data, [False, True])
assert_equal(test.mask, [True, False])
+ assert_(test.fill_value == True)
+
test = (a[0] == b)
assert_equal(test.data, [False, False])
assert_equal(test.mask, [True, False])
+ assert_(test.fill_value == True)
+
b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
test = (a == b)
assert_equal(test.data, [True, True])
assert_equal(test.mask, [False, False])
+ assert_(test.fill_value == True)
+
# complicated dtype, 2-dimensional array.
ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])]
a = array([[(1, (1, 1)), (2, (2, 2))],
@@ -1441,28 +1457,40 @@ class TestMaskedArrayArithmetic(object):
test = (a[0, 0] == a)
assert_equal(test.data, [[True, False], [False, False]])
assert_equal(test.mask, [[False, False], [False, True]])
+ assert_(test.fill_value == True)
def test_ne_on_structured(self):
# Test the equality of structured arrays
ndtype = [('A', int), ('B', int)]
a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype)
+
test = (a != a)
assert_equal(test.data, [False, False])
assert_equal(test.mask, [False, False])
+ assert_(test.fill_value == True)
+
test = (a != a[0])
assert_equal(test.data, [False, True])
assert_equal(test.mask, [False, False])
+ assert_(test.fill_value == True)
+
b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype)
test = (a != b)
assert_equal(test.data, [True, False])
assert_equal(test.mask, [True, False])
+ assert_(test.fill_value == True)
+
test = (a[0] != b)
assert_equal(test.data, [True, True])
assert_equal(test.mask, [True, False])
+ assert_(test.fill_value == True)
+
b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
test = (a != b)
assert_equal(test.data, [False, False])
assert_equal(test.mask, [False, False])
+ assert_(test.fill_value == True)
+
# complicated dtype, 2-dimensional array.
ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])]
a = array([[(1, (1, 1)), (2, (2, 2))],
@@ -1472,6 +1500,7 @@ class TestMaskedArrayArithmetic(object):
test = (a[0, 0] != a)
assert_equal(test.data, [[False, True], [True, True]])
assert_equal(test.mask, [[False, False], [False, True]])
+ assert_(test.fill_value == True)
def test_eq_ne_structured_extra(self):
# ensure simple examples are symmetric and make sense.
@@ -1507,6 +1536,120 @@ class TestMaskedArrayArithmetic(object):
el_by_el = [m1[name] != m2[name] for name in dt.names]
assert_equal(array(el_by_el, dtype=bool).any(), ne_expected)
+ @pytest.mark.parametrize('dt', ['S', 'U'])
+ @pytest.mark.parametrize('fill', [None, 'A'])
+ def test_eq_for_strings(self, dt, fill):
+ # Test the equality of structured arrays
+ a = array(['a', 'b'], dtype=dt, mask=[0, 1], fill_value=fill)
+
+ test = (a == a)
+ assert_equal(test.data, [True, True])
+ assert_equal(test.mask, [False, True])
+ assert_(test.fill_value == True)
+
+ test = (a == a[0])
+ assert_equal(test.data, [True, False])
+ assert_equal(test.mask, [False, True])
+ assert_(test.fill_value == True)
+
+ b = array(['a', 'b'], dtype=dt, mask=[1, 0], fill_value=fill)
+ test = (a == b)
+ assert_equal(test.data, [False, False])
+ assert_equal(test.mask, [True, True])
+ assert_(test.fill_value == True)
+
+ # test = (a[0] == b) # doesn't work in Python2
+ test = (b == a[0])
+ assert_equal(test.data, [False, False])
+ assert_equal(test.mask, [True, False])
+ assert_(test.fill_value == True)
+
+ @pytest.mark.parametrize('dt', ['S', 'U'])
+ @pytest.mark.parametrize('fill', [None, 'A'])
+ def test_ne_for_strings(self, dt, fill):
+ # Test the equality of structured arrays
+ a = array(['a', 'b'], dtype=dt, mask=[0, 1], fill_value=fill)
+
+ test = (a != a)
+ assert_equal(test.data, [False, False])
+ assert_equal(test.mask, [False, True])
+ assert_(test.fill_value == True)
+
+ test = (a != a[0])
+ assert_equal(test.data, [False, True])
+ assert_equal(test.mask, [False, True])
+ assert_(test.fill_value == True)
+
+ b = array(['a', 'b'], dtype=dt, mask=[1, 0], fill_value=fill)
+ test = (a != b)
+ assert_equal(test.data, [True, True])
+ assert_equal(test.mask, [True, True])
+ assert_(test.fill_value == True)
+
+ # test = (a[0] != b) # doesn't work in Python2
+ test = (b != a[0])
+ assert_equal(test.data, [True, True])
+ assert_equal(test.mask, [True, False])
+ assert_(test.fill_value == True)
+
+ @pytest.mark.parametrize('dt1', num_dts, ids=num_ids)
+ @pytest.mark.parametrize('dt2', num_dts, ids=num_ids)
+ @pytest.mark.parametrize('fill', [None, 1])
+ def test_eq_for_numeric(self, dt1, dt2, fill):
+ # Test the equality of structured arrays
+ a = array([0, 1], dtype=dt1, mask=[0, 1], fill_value=fill)
+
+ test = (a == a)
+ assert_equal(test.data, [True, True])
+ assert_equal(test.mask, [False, True])
+ assert_(test.fill_value == True)
+
+ test = (a == a[0])
+ assert_equal(test.data, [True, False])
+ assert_equal(test.mask, [False, True])
+ assert_(test.fill_value == True)
+
+ b = array([0, 1], dtype=dt2, mask=[1, 0], fill_value=fill)
+ test = (a == b)
+ assert_equal(test.data, [False, False])
+ assert_equal(test.mask, [True, True])
+ assert_(test.fill_value == True)
+
+ # test = (a[0] == b) # doesn't work in Python2
+ test = (b == a[0])
+ assert_equal(test.data, [False, False])
+ assert_equal(test.mask, [True, False])
+ assert_(test.fill_value == True)
+
+ @pytest.mark.parametrize('dt1', num_dts, ids=num_ids)
+ @pytest.mark.parametrize('dt2', num_dts, ids=num_ids)
+ @pytest.mark.parametrize('fill', [None, 1])
+ def test_ne_for_numeric(self, dt1, dt2, fill):
+ # Test the equality of structured arrays
+ a = array([0, 1], dtype=dt1, mask=[0, 1], fill_value=fill)
+
+ test = (a != a)
+ assert_equal(test.data, [False, False])
+ assert_equal(test.mask, [False, True])
+ assert_(test.fill_value == True)
+
+ test = (a != a[0])
+ assert_equal(test.data, [False, True])
+ assert_equal(test.mask, [False, True])
+ assert_(test.fill_value == True)
+
+ b = array([0, 1], dtype=dt2, mask=[1, 0], fill_value=fill)
+ test = (a != b)
+ assert_equal(test.data, [True, True])
+ assert_equal(test.mask, [True, True])
+ assert_(test.fill_value == True)
+
+ # test = (a[0] != b) # doesn't work in Python2
+ test = (b != a[0])
+ assert_equal(test.data, [True, True])
+ assert_equal(test.mask, [True, False])
+ assert_(test.fill_value == True)
+
def test_eq_with_None(self):
# Really, comparisons with None should not be done, but check them
# anyway. Note that pep8 will flag these tests.
@@ -5017,11 +5160,8 @@ def test_astype_mask_ordering():
assert_(x_f2.mask.flags.f_contiguous)
-dts = [np.dtype(dt_) for dt_ in '?bhilqBHILQefdgFD']
-ids = [dt_.char for dt_ in dts]
-
-@pytest.mark.parametrize('dt1', dts, ids=ids)
-@pytest.mark.parametrize('dt2', dts, ids=ids)
+@pytest.mark.parametrize('dt1', num_dts, ids=num_ids)
+@pytest.mark.parametrize('dt2', num_dts, ids=num_ids)
@pytest.mark.filterwarnings('ignore::numpy.ComplexWarning')
def test_astype_basic(dt1, dt2):
# See gh-12070