diff options
author | cookedm <cookedm@localhost> | 2006-06-14 01:21:22 +0000 |
---|---|---|
committer | cookedm <cookedm@localhost> | 2006-06-14 01:21:22 +0000 |
commit | 6a4ff2dcd299ec2cc154726fbd04cad284edbb10 (patch) | |
tree | f0ed5f8a32eae48084c5c35cd82b5ba6e1596036 /numpy/testing/utils.py | |
parent | b65ac291abf9fb443f7e581635ee3ffe04415e66 (diff) | |
download | numpy-6a4ff2dcd299ec2cc154726fbd04cad284edbb10.tar.gz |
Rework numpy.testing.utils.
This tightens up equality tests a bit; some tests in numpy an scipy fail.
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r-- | numpy/testing/utils.py | 138 |
1 files changed, 52 insertions, 86 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index fcffc9002..11e3b95d3 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -4,6 +4,7 @@ Utility function to facilitate testing. import os import sys +import operator __all__ = ['assert_equal', 'assert_almost_equal','assert_approx_equal', 'assert_array_equal', 'assert_array_less', @@ -98,8 +99,9 @@ if os.name=='nt' and sys.version[:3] > '2.3': processName, instance, win32pdh.PDH_FMT_LONG, None) -def build_err_msg(actual, desired, err_msg, header='Items are not equal:', - verbose=True): +def build_err_msg(arrays, err_msg, header='Items are not equal:', + verbose=True, + names=('ACTUAL', 'DESIRED')): msg = ['\n' + header] if err_msg: if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header): @@ -107,18 +109,13 @@ def build_err_msg(actual, desired, err_msg, header='Items are not equal:', else: msg.append(err_msg) if verbose: - try: - rd = repr(desired) - except: - rd = '[repr failed]' - rd = rd[:100] - msg.append(' DESIRED: ' + rd) - try: - ra = repr(actual) - except: - ra = '[repr failed]' - rd = ra[:100] - msg.append(' ACTUAL: ' + ra) + for i, a in enumerate(arrays): + try: + r = repr(a) + except: + r = '[repr failed]' + r = r[:100] + msg.append(' %s: %s' % (names[i], r)) return '\n'.join(msg) def assert_equal(actual,desired,err_msg='',verbose=True): @@ -140,7 +137,7 @@ def assert_equal(actual,desired,err_msg='',verbose=True): from numpy.core import ndarray if isinstance(actual, ndarray) or isinstance(desired, ndarray): return assert_array_equal(actual, desired, err_msg) - msg = build_err_msg(actual, desired, err_msg, verbose=verbose) + msg = build_err_msg([actual, desired], err_msg, verbose=verbose) assert desired == actual, msg def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): @@ -153,7 +150,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): from numpy.core import ndarray if isinstance(actual, ndarray) or isinstance(desired, ndarray): return assert_array_almost_equal(actual, desired, decimal, err_msg) - msg = build_err_msg(actual, desired, err_msg, verbose=verbose) + msg = build_err_msg([actual, desired], err_msg, verbose=verbose) assert round(abs(desired - actual),decimal) == 0, msg @@ -177,88 +174,57 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): sc_actual = actual/scale except ZeroDivisionError: sc_actual = 0.0 - msg = build_err_msg(actual, desired, err_msg, + msg = build_err_msg([actual, desired], err_msg, header='Items are not equal to %d significant digits:' % significant, verbose=verbose) assert math.fabs(sc_desired - sc_actual) < pow(10.,-1*significant), msg - -def assert_array_equal(x,y,err_msg=''): - from numpy.core import asarray, alltrue, equal, shape, ravel, array2string - x,y = asarray(x), asarray(y) - msg = '\nArrays are not equal' - try: - assert 0 in [len(shape(x)),len(shape(y))] \ - or (len(shape(x))==len(shape(y)) and \ - alltrue(equal(shape(x),shape(y)))),\ - msg + ' (shapes %s, %s mismatch):\n\t' \ - % (shape(x),shape(y)) + err_msg - reduced = ravel(equal(x,y)) - cond = alltrue(reduced) - if not cond: - s1 = array2string(x,precision=16) - s2 = array2string(y,precision=16) - if len(s1)>120: s1 = s1[:120] + '...' - if len(s2)>120: s2 = s2[:120] + '...' - match = 100-100.0*reduced.tolist().count(1)/len(reduced) - msg = msg + ' (mismatch %s%%):\n\tArray 1: %s\n\tArray 2: %s' % (match,s1,s2) - assert cond,\ - msg + '\n\t' + err_msg - except ValueError: - raise ValueError, msg - - -def assert_array_almost_equal(x,y,decimal=6,err_msg=''): - from numpy.core import asarray, alltrue, equal, shape, ravel,\ - array2string, less_equal, around +def assert_array_compare(comparision, x, y, err_msg='', verbose=True, + header=''): + from numpy.core import asarray x = asarray(x) y = asarray(y) - msg = '\nArrays are not almost equal' try: - cond = alltrue(equal(shape(x),shape(y))) - if not cond: - msg = msg + ' (shapes mismatch):\n\t'\ - 'Shape of array 1: %s\n\tShape of array 2: %s' % (shape(x),shape(y)) - assert cond, msg + '\n\t' + err_msg - reduced = ravel(equal(less_equal(around(abs(x-y),decimal),10.0**(-decimal)),1)) - cond = alltrue(reduced) + cond = (x.shape==() or y.shape==()) or x.shape == y.shape if not cond: - s1 = array2string(x,precision=decimal+1) - s2 = array2string(y,precision=decimal+1) - if len(s1)>120: s1 = s1[:120] + '...' - if len(s2)>120: s2 = s2[:120] + '...' - match = 100-100.0*reduced.tolist().count(1)/len(reduced) - msg = msg + ' (mismatch %s%%):\n\tArray 1: %s\n\tArray 2: %s' % (match,s1,s2) - assert cond,\ - msg + '\n\t' + err_msg - except ValueError: - print sys.exc_value - print shape(x),shape(y) - print x, y - raise ValueError, 'arrays are not almost equal' - -def assert_array_less(x,y,err_msg=''): - from numpy.core import asarray, alltrue, less, equal, shape, ravel, array2string - x,y = asarray(x), asarray(y) - msg = '\nArrays are not less-ordered' - try: - assert alltrue(equal(shape(x),shape(y))),\ - msg + ' (shapes mismatch):\n\t' + err_msg - reduced = ravel(less(x,y)) - cond = alltrue(reduced) + msg = build_err_msg([x, y], + err_msg + + '\n(shapes %s, %s mismatch)' % (x.shape, + y.shape), + verbose=verbose, header=header, + names=('x', 'y')) + assert cond, msg + reduced = comparision(x, y).ravel() + cond = reduced.all() if not cond: - s1 = array2string(x,precision=16) - s2 = array2string(y,precision=16) - if len(s1)>120: s1 = s1[:120] + '...' - if len(s2)>120: s2 = s2[:120] + '...' match = 100-100.0*reduced.tolist().count(1)/len(reduced) - msg = msg + ' (mismatch %s%%):\n\tArray 1: %s\n\tArray 2: %s' % (match,s1,s2) - assert cond,\ - msg + '\n\t' + err_msg + msg = build_err_msg([x, y], + err_msg + + '\n(mismatch %s%%)' % (match,), + verbose=verbose, header=header, + names=('x', 'y')) + assert cond, msg except ValueError: - print shape(x),shape(y) - raise ValueError, 'arrays are not less-ordered' + msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header, + names=('x', 'y')) + raise ValueError(msg) + +def assert_array_equal(x, y, err_msg='', verbose=True): + assert_array_compare(operator.__eq__, x, y, err_msg=err_msg, + verbose=verbose, header='Arrays are not equal') + +def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): + from numpy.core import around + def compare(x, y): + return around(abs(x-y),decimal) <= 10.0**(-decimal) + assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, + header='Arrays are not almost equal') + +def assert_array_less(x, y, err_msg='', verbose=True): + assert_array_compare(operator.__lt__, x, y, err_msg=err_msg, + verbose=verbose, + header='Arrays are not less-ordered') def runstring(astr, dict): exec astr in dict |