diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-05-19 12:21:47 -0400 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-05-27 08:09:38 -0400 |
commit | 5718b3306303b4899c017641a9282714dcf8c9b8 (patch) | |
tree | 4a9dfd8c1f7ee5fd3dcdbb1fd48b0a1b64a3ccc3 /numpy | |
parent | 84f582f25e168dbfd59b3be470bc8ebc46ee2d92 (diff) | |
download | numpy-5718b3306303b4899c017641a9282714dcf8c9b8.tar.gz |
BUG,MAINT: Ensure masked elements can be tested against nan and inf.
The removal of nan and inf from arrays that are compared using
test routines like assert_array_equal treated the two arrays
separately, which for masked arrays meant that some elements would
not be removed when they should have been. This PR corrects this.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/testing/_private/utils.py | 67 | ||||
-rw-r--r-- | numpy/testing/tests/test_utils.py | 12 |
2 files changed, 41 insertions, 38 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index a7935f175..5b7e35366 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -685,7 +685,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', precision=6, equal_nan=True, equal_inf=True): __tracebackhide__ = True # Hide traceback for py.test - from numpy.core import array, isnan, isinf, any, inf + from numpy.core import array, isnan, any, inf, ndim x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) @@ -695,9 +695,14 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, def istime(x): return x.dtype.char in "Mm" - def chk_same_position(x_id, y_id, hasval='nan'): - """Handling nan/inf: check that x and y have the nan/inf at the same - locations.""" + def func_assert_same_pos(x, y, func=isnan, hasval='nan'): + """Handling nan/inf: combine results of running func on x and y, + checking that they are True at the same locations.""" + x_id = func(x) + y_id = func(y) + if not any(x_id) and not any(y_id): + return False + try: assert_array_equal(x_id, y_id) except AssertionError: @@ -706,6 +711,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, % (hasval), verbose=verbose, header=header, names=('x', 'y'), precision=precision) raise AssertionError(msg) + # If there is a scalar, then here we know the array has the same + # flag as it everywhere, so we should return the scalar flag. + return x_id if x_id.ndim == 0 else y_id try: cond = (x.shape == () or y.shape == ()) or x.shape == y.shape @@ -718,49 +726,32 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, names=('x', 'y'), precision=precision) raise AssertionError(msg) + flagged = False if isnumber(x) and isnumber(y): - has_nan = has_inf = False if equal_nan: - x_isnan, y_isnan = isnan(x), isnan(y) - # Validate that NaNs are in the same place - has_nan = any(x_isnan) or any(y_isnan) - if has_nan: - chk_same_position(x_isnan, y_isnan, hasval='nan') + flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan') if equal_inf: - x_isinf, y_isinf = isinf(x), isinf(y) - # Validate that infinite values are in the same place - has_inf = any(x_isinf) or any(y_isinf) - if has_inf: - # Check +inf and -inf separately, since they are different - chk_same_position(x == +inf, y == +inf, hasval='+inf') - chk_same_position(x == -inf, y == -inf, hasval='-inf') - - if has_nan and has_inf: - x = x[~(x_isnan | x_isinf)] - y = y[~(y_isnan | y_isinf)] - elif has_nan: - x = x[~x_isnan] - y = y[~y_isnan] - elif has_inf: - x = x[~x_isinf] - y = y[~y_isinf] - - # Only do the comparison if actual values are left - if x.size == 0: - return + flagged |= func_assert_same_pos(x, y, + func=lambda xy: xy == +inf, + hasval='+inf') + flagged |= func_assert_same_pos(x, y, + func=lambda xy: xy == -inf, + hasval='-inf') elif istime(x) and istime(y): # If one is datetime64 and the other timedelta64 there is no point if equal_nan and x.dtype.type == y.dtype.type: - x_isnat, y_isnat = isnat(x), isnat(y) + flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT") - if any(x_isnat) or any(y_isnat): - chk_same_position(x_isnat, y_isnat, hasval="NaT") - - if any(x_isnat) or any(y_isnat): - x = x[~x_isnat] - y = y[~y_isnat] + if ndim(flagged): + x, y = x[~flagged], y[~flagged] + # Only do the comparison if actual values are left + if x.size == 0: + return + elif flagged: + # no sense doing comparison if everything is flagged. + return val = comparison(x, y) diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 602cdf5f2..235b3750c 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -151,6 +151,17 @@ class TestArrayEqual(_GenericTest): self._test_not_equal(c, b) assert_equal(len(l), 1) + def test_masked_nan_inf(self): + # Regression test for gh-11121 + a = np.ma.MaskedArray([3., 4., 6.5], mask=[False, True, False]) + b = np.array([3., np.nan, 6.5]) + self._test_equal(a, b) + self._test_equal(b, a) + a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, False, False]) + b = np.array([np.inf, 4., 6.5]) + self._test_equal(a, b) + self._test_equal(b, a) + class TestBuildErrorMessage(object): @@ -650,6 +661,7 @@ class TestArrayAssertLess(object): assert_raises(AssertionError, lambda: self._assert_func(-ainf, -x)) self._assert_func(-ainf, x) + @pytest.mark.skip(reason="The raises decorator depends on Nose") class TestRaises(object): |