diff options
author | Antti Kaihola <antti.kaihola@eniram.fi> | 2016-10-12 23:17:23 +0300 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2016-10-17 16:19:28 -0600 |
commit | a233689a9837a4aeb71efe144eae578aa22ca58f (patch) | |
tree | bab0a700762cc91262803e35e9be05ec8b096279 /numpy/testing/utils.py | |
parent | 8077d85832ba45d24c62dcae0111b9320de1602c (diff) | |
download | numpy-a233689a9837a4aeb71efe144eae578aa22ca58f.tar.gz |
BUG: Make assert_allclose(..., equal_nan=False) work.
As discussed in my comments for issue #8145, this patch adds the
equal_nan argument to assert_array_compare(), and assert_allclose()
passes the value it receives for the same argument through to
assert_array_compare().
Closes #8142.
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r-- | numpy/testing/utils.py | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 599e73cb0..b01de173d 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -666,7 +666,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): def assert_array_compare(comparison, x, y, err_msg='', verbose=True, - header='', precision=6): + header='', precision=6, equal_nan=True): __tracebackhide__ = True # Hide traceback for py.test from numpy.core import array, isnan, isinf, any, all, inf x = array(x, copy=False, subok=True) @@ -724,21 +724,25 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, raise AssertionError(msg) if isnumber(x) and isnumber(y): - x_isnan, y_isnan = isnan(x), isnan(y) + if equal_nan: + x_isnan, y_isnan = isnan(x), isnan(y) + # Validate that NaNs are in the same place + if any(x_isnan) or any(y_isnan): + chk_same_position(x_isnan, y_isnan, hasval='nan') + x_isinf, y_isinf = isinf(x), isinf(y) - # Validate that the special values are in the same place - if any(x_isnan) or any(y_isnan): - chk_same_position(x_isnan, y_isnan, hasval='nan') + # Validate that infinite values are in the same place if any(x_isinf) or any(y_isinf): # Check +inf and -inf separately, since they are different chk_same_position(x == +inf, y == +inf, hasval='+inf') chk_same_position(x == -inf, y == -inf, hasval='-inf') # Combine all the special values - x_id, y_id = x_isnan, y_isnan - x_id |= x_isinf - y_id |= y_isinf + x_id, y_id = x_isinf, y_isinf + if equal_nan: + x_id |= x_isnan + y_id |= y_isnan # Only do the comparison if actual values are left if all(x_id): @@ -1381,7 +1385,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=False, actual, desired = np.asanyarray(actual), np.asanyarray(desired) header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol) assert_array_compare(compare, actual, desired, err_msg=str(err_msg), - verbose=verbose, header=header) + verbose=verbose, header=header, equal_nan=equal_nan) def assert_array_almost_equal_nulp(x, y, nulp=1): |