summaryrefslogtreecommitdiff
path: root/numpy/testing/utils.py
diff options
context:
space:
mode:
authorAntti Kaihola <antti.kaihola@eniram.fi>2016-10-12 23:17:23 +0300
committerCharles Harris <charlesr.harris@gmail.com>2016-10-17 16:19:28 -0600
commita233689a9837a4aeb71efe144eae578aa22ca58f (patch)
treebab0a700762cc91262803e35e9be05ec8b096279 /numpy/testing/utils.py
parent8077d85832ba45d24c62dcae0111b9320de1602c (diff)
downloadnumpy-a233689a9837a4aeb71efe144eae578aa22ca58f.tar.gz
BUG: Make assert_allclose(..., equal_nan=False) work.
As discussed in my comments for issue #8145, this patch adds the equal_nan argument to assert_array_compare(), and assert_allclose() passes the value it receives for the same argument through to assert_array_compare(). Closes #8142.
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r--numpy/testing/utils.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 599e73cb0..b01de173d 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -666,7 +666,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
- header='', precision=6):
+ header='', precision=6, equal_nan=True):
__tracebackhide__ = True # Hide traceback for py.test
from numpy.core import array, isnan, isinf, any, all, inf
x = array(x, copy=False, subok=True)
@@ -724,21 +724,25 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
raise AssertionError(msg)
if isnumber(x) and isnumber(y):
- x_isnan, y_isnan = isnan(x), isnan(y)
+ if equal_nan:
+ x_isnan, y_isnan = isnan(x), isnan(y)
+ # Validate that NaNs are in the same place
+ if any(x_isnan) or any(y_isnan):
+ chk_same_position(x_isnan, y_isnan, hasval='nan')
+
x_isinf, y_isinf = isinf(x), isinf(y)
- # Validate that the special values are in the same place
- if any(x_isnan) or any(y_isnan):
- chk_same_position(x_isnan, y_isnan, hasval='nan')
+ # Validate that infinite values are in the same place
if any(x_isinf) or any(y_isinf):
# Check +inf and -inf separately, since they are different
chk_same_position(x == +inf, y == +inf, hasval='+inf')
chk_same_position(x == -inf, y == -inf, hasval='-inf')
# Combine all the special values
- x_id, y_id = x_isnan, y_isnan
- x_id |= x_isinf
- y_id |= y_isinf
+ x_id, y_id = x_isinf, y_isinf
+ if equal_nan:
+ x_id |= x_isnan
+ y_id |= y_isnan
# Only do the comparison if actual values are left
if all(x_id):
@@ -1381,7 +1385,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=False,
actual, desired = np.asanyarray(actual), np.asanyarray(desired)
header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
- verbose=verbose, header=header)
+ verbose=verbose, header=header, equal_nan=equal_nan)
def assert_array_almost_equal_nulp(x, y, nulp=1):