diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/testing/tests/test_utils.py | 23 | ||||
-rw-r--r-- | numpy/testing/utils.py | 38 |
2 files changed, 60 insertions, 1 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index fabf7a2e0..5106c1184 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -336,6 +336,29 @@ class TestWarns(unittest.TestCase): if failed: raise AssertionError("wrong warning caught by assert_warn") +class TestAssertAllclose(unittest.TestCase): + def test_simple(self): + x = 1e-3 + y = 1e-9 + + assert_allclose(x, y, atol=1) + self.assertRaises(AssertionError, assert_allclose, x, y) + + a = np.array([x, y, x, y]) + b = np.array([x, y, x, x]) + + assert_allclose(a, b, atol=1) + self.assertRaises(AssertionError, assert_allclose, a, b) + + b[-1] = y * (1 + 1e-8) + assert_allclose(a, b) + self.assertRaises(AssertionError, assert_allclose, a, b, + rtol=1e-9) + + assert_allclose(6, 10, rtol=0.5) + self.assertRaises(AssertionError, assert_allclose, 10, 6, rtol=0.5) + + class TestArrayAlmostEqualNulp(unittest.TestCase): def test_simple(self): dev = np.random.randn(10) diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 875027cd8..6f18b5468 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -16,7 +16,7 @@ __all__ = ['assert_equal', 'assert_almost_equal','assert_approx_equal', 'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal', 'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure', 'assert_', 'assert_array_almost_equal_nulp', - 'assert_array_max_ulp', 'assert_warns'] + 'assert_array_max_ulp', 'assert_warns', 'assert_allclose'] verbose = 0 @@ -1093,6 +1093,42 @@ def _assert_valid_refcount(op): assert(sys.getrefcount(i) >= rc) +def assert_allclose(actual, desired, rtol=1e-7, atol=0, + err_msg='', verbose=True): + """ + Raise an assertion if two objects are not equal up to desired tolerance. + + The test is equivalent to ``allclose(actual, desired, rtol, atol)`` + + Parameters + ---------- + actual : array_like + Array obtained. + desired : array_like + Array desired + rtol : float, optional + Relative tolerance + atol : float, optional + Absolute tolerance + err_msg : string + The error message to be printed in case of failure. + verbose : bool + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + """ + import numpy as np + def compare(x, y): + return np.allclose(x, y, rtol=rtol, atol=atol) + actual, desired = np.asanyarray(actual), np.asanyarray(desired) + header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol) + assert_array_compare(compare, actual, desired, err_msg=str(err_msg), + verbose=verbose, header=header) + def assert_array_almost_equal_nulp(x, y, nulp=1): """ Compare two arrays relatively to their spacing. |