summaryrefslogtreecommitdiff
path: root/numpy/testing/utils.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-10-17 19:16:04 -0400
committerGitHub <noreply@github.com>2016-10-17 19:16:04 -0400
commit15e82e2181d9ae92b6a1a6fb80454bc3a4aa2cb5 (patch)
treee46a481305ff27c442c62d629c2da789804dbbff /numpy/testing/utils.py
parent6d05b0cf555e454ca3858ceaf890274e0b079b0b (diff)
parenta233689a9837a4aeb71efe144eae578aa22ca58f (diff)
downloadnumpy-15e82e2181d9ae92b6a1a6fb80454bc3a4aa2cb5.tar.gz
Merge pull request #8165 from charris/fixup-8152
Fixup 8152, BUG: assert_allclose(..., equal_nan=False) doesn't raise AssertionError
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r--numpy/testing/utils.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 599e73cb0..b01de173d 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -666,7 +666,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
- header='', precision=6):
+ header='', precision=6, equal_nan=True):
__tracebackhide__ = True # Hide traceback for py.test
from numpy.core import array, isnan, isinf, any, all, inf
x = array(x, copy=False, subok=True)
@@ -724,21 +724,25 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
raise AssertionError(msg)
if isnumber(x) and isnumber(y):
- x_isnan, y_isnan = isnan(x), isnan(y)
+ if equal_nan:
+ x_isnan, y_isnan = isnan(x), isnan(y)
+ # Validate that NaNs are in the same place
+ if any(x_isnan) or any(y_isnan):
+ chk_same_position(x_isnan, y_isnan, hasval='nan')
+
x_isinf, y_isinf = isinf(x), isinf(y)
- # Validate that the special values are in the same place
- if any(x_isnan) or any(y_isnan):
- chk_same_position(x_isnan, y_isnan, hasval='nan')
+ # Validate that infinite values are in the same place
if any(x_isinf) or any(y_isinf):
# Check +inf and -inf separately, since they are different
chk_same_position(x == +inf, y == +inf, hasval='+inf')
chk_same_position(x == -inf, y == -inf, hasval='-inf')
# Combine all the special values
- x_id, y_id = x_isnan, y_isnan
- x_id |= x_isinf
- y_id |= y_isinf
+ x_id, y_id = x_isinf, y_isinf
+ if equal_nan:
+ x_id |= x_isnan
+ y_id |= y_isnan
# Only do the comparison if actual values are left
if all(x_id):
@@ -1381,7 +1385,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=False,
actual, desired = np.asanyarray(actual), np.asanyarray(desired)
header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
- verbose=verbose, header=header)
+ verbose=verbose, header=header, equal_nan=equal_nan)
def assert_array_almost_equal_nulp(x, y, nulp=1):