diff options
author | Stephan Hoyer <shoyer@gmail.com> | 2019-01-01 22:50:29 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-01-01 22:50:29 -0800 |
commit | a16fc9499eaa7cc9d7532f8a51725c6ed647cd1b (patch) | |
tree | b9df23e3024937390d00135616379fc266722339 /numpy | |
parent | 43298265ab35b82e29ff772c466872a78531fabd (diff) | |
download | numpy-a16fc9499eaa7cc9d7532f8a51725c6ed647cd1b.tar.gz |
ENH: add "max difference" messages to np.testing.assert_array_equal (#12591)
Example behavior:

```
>>> x = np.array([1, 2, 3])
>>> y = np.array([1, 2, 3.0001])
>>> np.testing.assert_allclose(x, y)
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0
Mismatch: 33.3%
Max absolute difference: 0.0001
Max relative difference: 3.33322223e-05
x: array([1, 2, 3])
y: array([1. , 2. , 3.0001])
```

Motivation: when writing numerical algorithms, I frequently find myself
experimenting to pick the right value of `atol` and `rtol` for
`np.testing.assert_allclose()`. If I make the tolerance too generous, I risk
missing regressions in accuracy, so I usually try to pick the smallest values
for which tests pass. This change immediately reveals appropriate values to
use for these parameters, so I don't need to guess and check.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/testing/_private/nosetester.py | 2 | ||||
-rw-r--r-- | numpy/testing/_private/utils.py | 203 | ||||
-rw-r--r-- | numpy/testing/tests/test_utils.py | 115 |
3 files changed, 185 insertions, 135 deletions
diff --git a/numpy/testing/_private/nosetester.py b/numpy/testing/_private/nosetester.py index 1728d9d1f..19569a509 100644 --- a/numpy/testing/_private/nosetester.py +++ b/numpy/testing/_private/nosetester.py @@ -92,7 +92,7 @@ def run_module_suite(file_to_run=None, argv=None): Alternatively, calling:: - >>> run_module_suite(file_to_run="numpy/tests/test_matlib.py") + >>> run_module_suite(file_to_run="numpy/tests/test_matlib.py") # doctest: +SKIP from an interpreter will run all the test routine in 'test_matlib.py'. """ diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 4059f6ee6..1f7b516b3 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -318,8 +318,9 @@ def assert_equal(actual, desired, err_msg='', verbose=True): Examples -------- >>> np.testing.assert_equal([4,5], [4,6]) - ... - <type 'exceptions.AssertionError'>: + Traceback (most recent call last): + ... + AssertionError: Items are not equal: item=1 ACTUAL: 5 @@ -510,20 +511,24 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): >>> import numpy.testing as npt >>> npt.assert_almost_equal(2.3333333333333, 2.33333334) >>> npt.assert_almost_equal(2.3333333333333, 2.33333334, decimal=10) - ... - <type 'exceptions.AssertionError'>: - Items are not equal: - ACTUAL: 2.3333333333333002 - DESIRED: 2.3333333399999998 + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 10 decimals + ACTUAL: 2.3333333333333 + DESIRED: 2.33333334 >>> npt.assert_almost_equal(np.array([1.0,2.3333333333333]), ... np.array([1.0,2.33333334]), decimal=9) - ... - <type 'exceptions.AssertionError'>: - Arrays are not almost equal - (mismatch 50.0%) - x: array([ 1. , 2.33333333]) - y: array([ 1. , 2.33333334]) + Traceback (most recent call last): + ... 
+ AssertionError: + Arrays are not almost equal to 9 decimals + Mismatch: 50% + Max absolute difference: 6.66669964e-09 + Max relative difference: 2.85715698e-09 + x: array([1. , 2.333333333]) + y: array([1. , 2.33333334]) """ __tracebackhide__ = True # Hide traceback for py.test @@ -625,14 +630,15 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): -------- >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20) >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20, - significant=8) + ... significant=8) >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20, - significant=8) - ... - <type 'exceptions.AssertionError'>: + ... significant=8) + Traceback (most recent call last): + ... + AssertionError: Items are not equal to 8 significant digits: - ACTUAL: 1.234567e-021 - DESIRED: 1.2345672000000001e-021 + ACTUAL: 1.234567e-21 + DESIRED: 1.2345672e-21 the evaluated condition that raises the exception is @@ -659,10 +665,10 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): sc_actual = actual/scale except ZeroDivisionError: sc_actual = 0.0 - msg = build_err_msg([actual, desired], err_msg, - header='Items are not equal to %d significant digits:' % - significant, - verbose=verbose) + msg = build_err_msg( + [actual, desired], err_msg, + header='Items are not equal to %d significant digits:' % significant, + verbose=verbose) try: # If one of desired/actual is not finite, handle it specially here: # check that both are nan if any is a nan, and test for equality @@ -685,7 +691,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', precision=6, equal_nan=True, equal_inf=True): __tracebackhide__ = True # Hide traceback for py.test - from numpy.core import array, isnan, inf, bool_ + from numpy.core import array, array2string, isnan, inf, bool_, errstate x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) @@ -781,15 +787,31 @@ def 
assert_array_compare(comparison, x, y, err_msg='', verbose=True, reduced = val.ravel() cond = reduced.all() reduced = reduced.tolist() + # The below comparison is a hack to ensure that fully masked # results, for which val.ravel().all() returns np.ma.masked, # do not trigger a failure (np.ma.masked != True evaluates as # np.ma.masked, which is falsy). if cond != True: mismatch = 100.0 * reduced.count(0) / ox.size - msg = build_err_msg([ox, oy], - err_msg - + '\n(mismatch %s%%)' % (mismatch,), + remarks = ['Mismatch: {:.3g}%'.format(mismatch)] + + with errstate(invalid='ignore', divide='ignore'): + # ignore errors for non-numeric types + with contextlib.suppress(TypeError): + error = abs(x - y) + max_abs_error = error.max() + remarks.append('Max absolute difference: ' + + array2string(max_abs_error)) + + # note: this definition of relative error matches that one + # used by assert_allclose (found in np.isclose) + max_rel_error = (error / abs(y)).max() + remarks.append('Max relative difference: ' + + array2string(max_rel_error)) + + err_msg += '\n' + '\n'.join(remarks) + msg = build_err_msg([ox, oy], err_msg, verbose=verbose, header=header, names=('x', 'y'), precision=precision) raise AssertionError(msg) @@ -849,13 +871,15 @@ def assert_array_equal(x, y, err_msg='', verbose=True): >>> np.testing.assert_array_equal([1.0,np.pi,np.nan], ... [1, np.sqrt(np.pi)**2, np.nan]) - ... - <type 'exceptions.ValueError'>: + Traceback (most recent call last): + ... AssertionError: Arrays are not equal - (mismatch 50.0%) - x: array([ 1. , 3.14159265, NaN]) - y: array([ 1. , 3.14159265, NaN]) + Mismatch: 33.3% + Max absolute difference: 4.4408921e-16 + Max relative difference: 1.41357986e-16 + x: array([1. , 3.141593, nan]) + y: array([1. 
, 3.141593, nan]) Use `assert_allclose` or one of the nulp (number of floating point values) functions for these cases instead: @@ -920,25 +944,29 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): the first assert does not raise an exception >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan], - [1.0,2.333,np.nan]) + ... [1.0,2.333,np.nan]) >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], ... [1.0,2.33339,np.nan], decimal=5) - ... - <type 'exceptions.AssertionError'>: + Traceback (most recent call last): + ... AssertionError: - Arrays are not almost equal - (mismatch 50.0%) - x: array([ 1. , 2.33333, NaN]) - y: array([ 1. , 2.33339, NaN]) + Arrays are not almost equal to 5 decimals + Mismatch: 33.3% + Max absolute difference: 6.e-05 + Max relative difference: 2.57136612e-05 + x: array([1. , 2.33333, nan]) + y: array([1. , 2.33339, nan]) >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], ... [1.0,2.33333, 5], decimal=5) - <type 'exceptions.ValueError'>: - ValueError: - Arrays are not almost equal - x: array([ 1. , 2.33333, NaN]) - y: array([ 1. , 2.33333, 5. ]) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 5 decimals + x and y nan location mismatch: + x: array([1. , 2.33333, nan]) + y: array([1. , 2.33333, 5. ]) """ __tracebackhide__ = True # Hide traceback for py.test @@ -1019,27 +1047,34 @@ def assert_array_less(x, y, err_msg='', verbose=True): -------- >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan]) >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan]) - ... - <type 'exceptions.ValueError'>: + Traceback (most recent call last): + ... + AssertionError: Arrays are not less-ordered - (mismatch 50.0%) - x: array([ 1., 1., NaN]) - y: array([ 1., 2., NaN]) + Mismatch: 33.3% + Max absolute difference: 1. 
+ Max relative difference: 0.5 + x: array([ 1., 1., nan]) + y: array([ 1., 2., nan]) >>> np.testing.assert_array_less([1.0, 4.0], 3) - ... - <type 'exceptions.ValueError'>: + Traceback (most recent call last): + ... + AssertionError: Arrays are not less-ordered - (mismatch 50.0%) - x: array([ 1., 4.]) + Mismatch: 50% + Max absolute difference: 2. + Max relative difference: 0.66666667 + x: array([1., 4.]) y: array(3) >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4]) - ... - <type 'exceptions.ValueError'>: + Traceback (most recent call last): + ... + AssertionError: Arrays are not less-ordered (shapes (3,), (1,) mismatch) - x: array([ 1., 2., 3.]) + x: array([1., 2., 3.]) y: array([4]) """ @@ -1144,7 +1179,7 @@ def rundocs(filename=None, raise_on_error=True): argument to the ``test()`` call. For example, to run all tests (including doctests) for `numpy.lib`: - >>> np.lib.test(doctests=True) #doctest: +SKIP + >>> np.lib.test(doctests=True) # doctest: +SKIP """ from numpy.compat import npy_load_module import doctest @@ -1326,7 +1361,7 @@ def decorate_methods(cls, decorator, testmatch=None): return -def measure(code_str,times=1,label=None): +def measure(code_str, times=1, label=None): """ Return elapsed time for executing code in the namespace of the caller. @@ -1353,9 +1388,9 @@ def measure(code_str,times=1,label=None): Examples -------- - >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)', - ... 
times=times) - >>> print("Time for a single execution : ", etime / times, "s") + >>> times = 10 + >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)', times=times) + >>> print("Time for a single execution : ", etime / times, "s") # doctest: +SKIP Time for a single execution : 0.005 s """ @@ -1440,7 +1475,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True, -------- >>> x = [1e-5, 1e-3, 1e-1] >>> y = np.arccos(np.cos(x)) - >>> assert_allclose(x, y, rtol=1e-5, atol=0) + >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0) """ __tracebackhide__ = True # Hide traceback for py.test @@ -1894,7 +1929,8 @@ class clear_and_catch_warnings(warnings.catch_warnings): Examples -------- >>> import warnings - >>> with clear_and_catch_warnings(modules=[np.core.fromnumeric]): + >>> with np.testing.clear_and_catch_warnings( + ... modules=[np.core.fromnumeric]): ... warnings.simplefilter('always') ... warnings.filterwarnings('ignore', module='np.core.fromnumeric') ... # do something that raises a warning but ignore those in @@ -1975,25 +2011,28 @@ class suppress_warnings(object): Examples -------- - >>> with suppress_warnings() as sup: - ... sup.filter(DeprecationWarning, "Some text") - ... sup.filter(module=np.ma.core) - ... log = sup.record(FutureWarning, "Does this occur?") - ... command_giving_warnings() - ... # The FutureWarning was given once, the filtered warnings were - ... # ignored. All other warnings abide outside settings (may be - ... # printed/error) - ... assert_(len(log) == 1) - ... assert_(len(sup.log) == 1) # also stored in log attribute - - Or as a decorator: - - >>> sup = suppress_warnings() - >>> sup.filter(module=np.ma.core) # module must match exact - >>> @sup - >>> def some_function(): - ... # do something which causes a warning in np.ma.core - ... 
pass + + With a context manager:: + + with np.testing.suppress_warnings() as sup: + sup.filter(DeprecationWarning, "Some text") + sup.filter(module=np.ma.core) + log = sup.record(FutureWarning, "Does this occur?") + command_giving_warnings() + # The FutureWarning was given once, the filtered warnings were + # ignored. All other warnings abide outside settings (may be + # printed/error) + assert_(len(log) == 1) + assert_(len(sup.log) == 1) # also stored in log attribute + + Or as a decorator:: + + sup = np.testing.suppress_warnings() + sup.filter(module=np.ma.core) # module must match exactly + @sup + def some_function(): + # do something which causes a warning in np.ma.core + pass """ def __init__(self, forwarding_rule="always"): self._entered = False diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 43afafaa8..c376a3852 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -327,24 +327,22 @@ class TestEqual(TestArrayEqual): self._test_not_equal(x, y) def test_error_message(self): - try: + with pytest.raises(AssertionError) as exc_info: self._assert_func(np.array([1, 2]), np.array([[1, 2]])) - except AssertionError as e: - msg = str(e) - msg2 = msg.replace("shapes (2L,), (1L, 2L)", "shapes (2,), (1, 2)") - msg_reference = textwrap.dedent("""\ + msg = str(exc_info.value) + msg2 = msg.replace("shapes (2L,), (1L, 2L)", "shapes (2,), (1, 2)") + msg_reference = textwrap.dedent("""\ - Arrays are not equal + Arrays are not equal - (shapes (2,), (1, 2) mismatch) - x: array([1, 2]) - y: array([[1, 2]])""") - try: - assert_equal(msg, msg_reference) - except AssertionError: - assert_equal(msg2, msg_reference) - else: - raise AssertionError("Did not raise") + (shapes (2,), (1, 2) mismatch) + x: array([1, 2]) + y: array([[1, 2]])""") + + try: + assert_equal(msg, msg_reference) + except AssertionError: + assert_equal(msg2, msg_reference) class TestArrayAlmostEqual(_GenericTest): @@ -509,38 +507,53 @@ 
class TestAlmostEqual(_GenericTest): x = np.array([1.00000000001, 2.00000000002, 3.00003]) y = np.array([1.00000000002, 2.00000000003, 3.00004]) - # test with a different amount of decimal digits - # note that we only check for the formatting of the arrays themselves - b = ('x: array([1.00000000001, 2.00000000002, 3.00003 ' - ' ])\n y: array([1.00000000002, 2.00000000003, 3.00004 ])') - try: + # Test with a different amount of decimal digits + with pytest.raises(AssertionError) as exc_info: self._assert_func(x, y, decimal=12) - except AssertionError as e: - # remove anything that's not the array string - assert_equal(str(e).split('%)\n ')[1], b) - - # with the default value of decimal digits, only the 3rd element differs - # note that we only check for the formatting of the arrays themselves - b = ('x: array([1. , 2. , 3.00003])\n y: array([1. , ' - '2. , 3.00004])') - try: + msgs = str(exc_info.value).split('\n') + assert_equal(msgs[3], 'Mismatch: 100%') + assert_equal(msgs[4], 'Max absolute difference: 1.e-05') + assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06') + assert_equal( + msgs[6], + ' x: array([1.00000000001, 2.00000000002, 3.00003 ])') + assert_equal( + msgs[7], + ' y: array([1.00000000002, 2.00000000003, 3.00004 ])') + + # With the default value of decimal digits, only the 3rd element + # differs. Note that we only check for the formatting of the arrays + # themselves. + with pytest.raises(AssertionError) as exc_info: self._assert_func(x, y) - except AssertionError as e: - # remove anything that's not the array string - assert_equal(str(e).split('%)\n ')[1], b) - - # Check the error message when input includes inf or nan + msgs = str(exc_info.value).split('\n') + assert_equal(msgs[3], 'Mismatch: 33.3%') + assert_equal(msgs[4], 'Max absolute difference: 1.e-05') + assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06') + assert_equal(msgs[6], ' x: array([1. , 2. , 3.00003])') + assert_equal(msgs[7], ' y: array([1. , 2. 
, 3.00004])') + + # Check the error message when input includes inf x = np.array([np.inf, 0]) y = np.array([np.inf, 1]) - try: + with pytest.raises(AssertionError) as exc_info: + self._assert_func(x, y) + msgs = str(exc_info.value).split('\n') + assert_equal(msgs[3], 'Mismatch: 50%') + assert_equal(msgs[4], 'Max absolute difference: 1.') + assert_equal(msgs[5], 'Max relative difference: 1.') + assert_equal(msgs[6], ' x: array([inf, 0.])') + assert_equal(msgs[7], ' y: array([inf, 1.])') + + # Check the error message when dividing by zero + x = np.array([1, 2]) + y = np.array([0, 0]) + with pytest.raises(AssertionError) as exc_info: self._assert_func(x, y) - except AssertionError as e: - msgs = str(e).split('\n') - # assert error percentage is 50% - assert_equal(msgs[3], '(mismatch 50.0%)') - # assert output array contains inf - assert_equal(msgs[4], ' x: array([inf, 0.])') - assert_equal(msgs[5], ' y: array([inf, 1.])') + msgs = str(exc_info.value).split('\n') + assert_equal(msgs[3], 'Mismatch: 100%') + assert_equal(msgs[4], 'Max absolute difference: 2') + assert_equal(msgs[5], 'Max relative difference: inf') def test_subclass_that_cannot_be_bool(self): # While we cannot guarantee testing functions will always work for @@ -829,12 +842,12 @@ class TestAssertAllclose(object): def test_report_fail_percentage(self): a = np.array([1, 1, 1, 1]) b = np.array([1, 1, 1, 2]) - try: + + with pytest.raises(AssertionError) as exc_info: assert_allclose(a, b) - msg = '' - except AssertionError as exc: - msg = exc.args[0] - assert_("mismatch 25.0%" in msg) + msg = str(exc_info.value) + assert_('Mismatch: 25%\nMax absolute difference: 1\n' + 'Max relative difference: 0.5' in msg) def test_equal_nan(self): a = np.array([np.nan]) @@ -1117,12 +1130,10 @@ class TestStringEqual(object): assert_string_equal("hello", "hello") assert_string_equal("hello\nmultiline", "hello\nmultiline") - try: + with pytest.raises(AssertionError) as exc_info: assert_string_equal("foo\nbar", "hello\nbar") - 
except AssertionError as exc: - assert_equal(str(exc), "Differences in strings:\n- foo\n+ hello") - else: - raise AssertionError("exception not raised") + msg = str(exc_info.value) + assert_equal(msg, "Differences in strings:\n- foo\n+ hello") assert_raises(AssertionError, lambda: assert_string_equal("foo", "hello")) |