summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/testing/tests/test_utils.py23
-rw-r--r--numpy/testing/utils.py38
2 files changed, 60 insertions, 1 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index fabf7a2e0..5106c1184 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -336,6 +336,29 @@ class TestWarns(unittest.TestCase):
if failed:
raise AssertionError("wrong warning caught by assert_warn")
+class TestAssertAllclose(unittest.TestCase):
+ def test_simple(self):
+ x = 1e-3
+ y = 1e-9
+
+ assert_allclose(x, y, atol=1)
+ self.assertRaises(AssertionError, assert_allclose, x, y)
+
+ a = np.array([x, y, x, y])
+ b = np.array([x, y, x, x])
+
+ assert_allclose(a, b, atol=1)
+ self.assertRaises(AssertionError, assert_allclose, a, b)
+
+ b[-1] = y * (1 + 1e-8)
+ assert_allclose(a, b)
+ self.assertRaises(AssertionError, assert_allclose, a, b,
+ rtol=1e-9)
+
+ assert_allclose(6, 10, rtol=0.5)
+ self.assertRaises(AssertionError, assert_allclose, 10, 6, rtol=0.5)
+
+
class TestArrayAlmostEqualNulp(unittest.TestCase):
def test_simple(self):
dev = np.random.randn(10)
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 875027cd8..6f18b5468 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -16,7 +16,7 @@ __all__ = ['assert_equal', 'assert_almost_equal','assert_approx_equal',
'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal',
'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure',
'assert_', 'assert_array_almost_equal_nulp',
- 'assert_array_max_ulp', 'assert_warns']
+ 'assert_array_max_ulp', 'assert_warns', 'assert_allclose']
verbose = 0
@@ -1093,6 +1093,42 @@ def _assert_valid_refcount(op):
assert(sys.getrefcount(i) >= rc)
+def assert_allclose(actual, desired, rtol=1e-7, atol=0,
+ err_msg='', verbose=True):
+ """
+ Raise an assertion if two objects are not equal up to desired tolerance.
+
+ The test is equivalent to ``allclose(actual, desired, rtol, atol)``
+
+ Parameters
+ ----------
+ actual : array_like
+ Array obtained.
+ desired : array_like
+ Array desired
+ rtol : float, optional
+ Relative tolerance
+ atol : float, optional
+ Absolute tolerance
+ err_msg : string
+ The error message to be printed in case of failure.
+ verbose : bool
+ If True, the conflicting values are appended to the error message.
+
+ Raises
+ ------
+ AssertionError
+ If actual and desired are not equal up to specified precision.
+
+ """
+ import numpy as np
+ def compare(x, y):
+ return np.allclose(x, y, rtol=rtol, atol=atol)
+ actual, desired = np.asanyarray(actual), np.asanyarray(desired)
+ header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
+ assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
+ verbose=verbose, header=header)
+
def assert_array_almost_equal_nulp(x, y, nulp=1):
"""
Compare two arrays relatively to their spacing.