diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/testing/tests/test_utils.py | 68 | ||||
-rw-r--r-- | numpy/testing/utils.py | 28 |
2 files changed, 86 insertions, 10 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 5956a4294..aa0a2669f 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -134,6 +134,49 @@ class TestArrayEqual(_GenericTest, unittest.TestCase): self._test_not_equal(c, b) +class TestBuildErrorMessage(unittest.TestCase): + def test_build_err_msg_defaults(self): + x = np.array([1.00001, 2.00002, 3.00003]) + y = np.array([1.00002, 2.00003, 3.00004]) + err_msg = 'There is a mismatch' + + a = build_err_msg([x, y], err_msg) + b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ ' + '1.00001, 2.00002, 3.00003])\n DESIRED: array([ 1.00002, ' + '2.00003, 3.00004])') + self.assertEqual(a, b) + + def test_build_err_msg_no_verbose(self): + x = np.array([1.00001, 2.00002, 3.00003]) + y = np.array([1.00002, 2.00003, 3.00004]) + err_msg = 'There is a mismatch' + + a = build_err_msg([x, y], err_msg, verbose=False) + b = '\nItems are not equal: There is a mismatch' + self.assertEqual(a, b) + + def test_build_err_msg_custom_names(self): + x = np.array([1.00001, 2.00002, 3.00003]) + y = np.array([1.00002, 2.00003, 3.00004]) + err_msg = 'There is a mismatch' + + a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR')) + b = ('\nItems are not equal: There is a mismatch\n FOO: array([ ' + '1.00001, 2.00002, 3.00003])\n BAR: array([ 1.00002, 2.00003, ' + '3.00004])') + self.assertEqual(a, b) + + def test_build_err_msg_custom_precision(self): + x = np.array([1.000000001, 2.00002, 3.00003]) + y = np.array([1.000000002, 2.00003, 3.00004]) + err_msg = 'There is a mismatch' + + a = build_err_msg([x, y], err_msg, precision=10) + b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ ' + '1.000000001, 2.00002 , 3.00003 ])\n DESIRED: array([ ' + '1.000000002, 2.00003 , 3.00004 ])') + self.assertEqual(a, b) + class TestEqual(TestArrayEqual): def setUp(self): self._assert_func = assert_equal @@ -239,6 +282,31 @@ class TestAlmostEqual(_GenericTest, unittest.TestCase): self._test_not_equal(x, y) self._test_not_equal(x, z) + def test_error_message(self): + """Check the message is formatted correctly for the decimal value""" + x = np.array([1.00000000001, 2.00000000002, 3.00003]) + y = np.array([1.00000000002, 2.00000000003, 3.00004]) + + # test with a different amount of decimal digits + # note that we only check for the formatting of the arrays themselves + b = ('x: array([ 1.00000000001, 2.00000000002, 3.00003 ' + ' ])\n y: array([ 1.00000000002, 2.00000000003, 3.00004 ])') + try: + self._assert_func(x, y, decimal=12) + except AssertionError as e: + # remove anything that's not the array string + self.assertEqual(str(e).split('%)\n ')[1], b) + + # with the default value of decimal digits, only the 3rd element differs + # note that we only check for the formatting of the arrays themselves + b = ('x: array([ 1. , 2. , 3.00003])\n y: array([ 1. , ' + '2. , 3.00004])') + try: + self._assert_func(x, y) + except AssertionError as e: + # remove anything that's not the array string + self.assertEqual(str(e).split('%)\n ')[1], b) + class TestApproxEqual(unittest.TestCase): def setUp(self): self._assert_func = assert_approx_equal diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 01c0e769b..930900bf2 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -9,8 +9,9 @@ import sys import re import operator import warnings +from functools import partial from .nosetester import import_nose -from numpy.core import float32, empty, arange +from numpy.core import float32, empty, arange, array_repr, ndarray if sys.version_info[0] >= 3: from io import StringIO @@ -190,8 +191,7 @@ if os.name=='nt' and sys.version[:3] > '2.3': win32pdh.PDH_FMT_LONG, None) def build_err_msg(arrays, err_msg, header='Items are not equal:', - verbose=True, - names=('ACTUAL', 'DESIRED')): + verbose=True, names=('ACTUAL', 'DESIRED'), precision=8): msg = ['\n' + header] if err_msg: if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header): @@ -200,8 +200,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:', msg.append(err_msg) if verbose: for i, a in enumerate(arrays): + + if isinstance(a, ndarray): + # precision argument is only needed if the objects are ndarrays + r_func = partial(array_repr, precision=precision) + else: + r_func = repr + try: - r = repr(a) + r = r_func(a) except: r = '[repr failed]' if r.count('\n') > 3: @@ -575,7 +582,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): raise AssertionError(msg) def assert_array_compare(comparison, x, y, err_msg='', verbose=True, - header=''): + header='', precision=6): from numpy.core import array, isnan, isinf, any, all, inf x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) @@ -592,7 +599,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, msg = build_err_msg([x, y], err_msg + '\nx and y %s location mismatch:' \ % (hasval), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) raise AssertionError(msg) try: @@ -603,7 +610,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, + '\n(shapes %s, %s mismatch)' % (x.shape, y.shape), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) if not cond : raise AssertionError(msg) @@ -648,7 +655,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, err_msg + '\n(mismatch %s%%)' % (match,), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) if not cond : raise AssertionError(msg) except ValueError as e: @@ -657,7 +664,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header) msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) raise ValueError(msg) def assert_array_equal(x, y, err_msg='', verbose=True): @@ -825,7 +832,8 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): return around(z, decimal) <= 10.0**(-decimal) assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, - header=('Arrays are not almost equal to %d decimals' % decimal)) + header=('Arrays are not almost equal to %d decimals' % decimal), + precision=decimal) def assert_array_less(x, y, err_msg='', verbose=True): |