diff options
author | rgommers <ralf.gommers@googlemail.com> | 2011-03-12 19:58:10 +0800 |
---|---|---|
committer | rgommers <ralf.gommers@googlemail.com> | 2011-03-12 20:08:45 +0800 |
commit | 45269ee1c906269bc3cd481f7ed94c8100e2b858 (patch) | |
tree | b56b84875caef3459575ee75e000c1373b58c714 /numpy/testing/utils.py | |
parent | 0de7890ac6cad3d04e4cb32fe8431950eda065fa (diff) | |
download | numpy-45269ee1c906269bc3cd481f7ed94c8100e2b858.tar.gz |
BUG: fix assert_almost_equal and co. to work with infs.
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r-- | numpy/testing/utils.py | 43 |
1 files changed, 26 insertions, 17 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 0c0483f39..01ce31c4a 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -568,13 +568,25 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header=''): - from numpy.core import array, isnan, any + from numpy.core import array, isnan, isinf, any x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) def isnumber(x): return x.dtype.char in '?bhilqpBHILQPfdgFDG' + def chk_same_position(x_id, y_id, hasval='nan'): + """Handling nan/inf: check that x and y have the nan/inf at the same + locations.""" + try: + assert_array_equal(x_id, y_id) + except AssertionError: + msg = build_err_msg([x, y], + err_msg + '\nx and y %s location mismatch:' \ + % (hasval), verbose=verbose, header=header, + names=('x', 'y')) + raise AssertionError(msg) + try: cond = (x.shape==() or y.shape==()) or x.shape == y.shape if not cond: @@ -588,27 +600,24 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, raise AssertionError(msg) if (isnumber(x) and isnumber(y)) and (any(isnan(x)) or any(isnan(y))): - # Handling nan: we first check that x and y have the nan at the - # same locations, and then we mask the nan and do the comparison as - # usual. - xnanid = isnan(x) - ynanid = isnan(y) - try: - assert_array_equal(xnanid, ynanid) - except AssertionError: - msg = build_err_msg([x, y], - err_msg - + '\n(x and y nan location mismatch %s, ' \ - '%s mismatch)' % (xnanid, ynanid), - verbose=verbose, header=header, - names=('x', 'y')) - raise AssertionError(msg) + x_id = isnan(x) + y_id = isnan(y) + chk_same_position(x_id, y_id, hasval='nan') # If only one item, it was a nan, so just return if x.size == y.size == 1: return - val = comparison(x[~xnanid], y[~ynanid]) + val = comparison(x[~x_id], y[~y_id]) + elif (isnumber(x) and isnumber(y)) and (any(isinf(x)) or any(isinf(y))): + x_id = isinf(x) + y_id = isinf(y) + chk_same_position(x_id, y_id, hasval='inf') + # If only one item, it was a inf, so just return + if x.size == y.size == 1: + return + val = comparison(x[~x_id], y[~y_id]) else: val = comparison(x,y) + if isinstance(val, bool): cond = val reduced = [0] |