summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@gmail.com>2019-01-01 22:50:29 -0800
committerGitHub <noreply@github.com>2019-01-01 22:50:29 -0800
commita16fc9499eaa7cc9d7532f8a51725c6ed647cd1b (patch)
treeb9df23e3024937390d00135616379fc266722339 /numpy
parent43298265ab35b82e29ff772c466872a78531fabd (diff)
downloadnumpy-a16fc9499eaa7cc9d7532f8a51725c6ed647cd1b.tar.gz
ENH: add "max difference" messages to np.testing.assert_array_equal (#12591)
Example behavior: >>> x = np.array([1, 2, 3]) >>> y = np.array([1, 2, 3.0001]) >>> np.testing.assert_allclose(x, y) AssertionError: Not equal to tolerance rtol=1e-07, atol=0 Mismatch: 33.3% Max absolute difference: 0.0001 Max relative difference: 3.33322223e-05 x: array([1, 2, 3]) y: array([1. , 2. , 3.0001]) Motivation: when writing numerical algorithms, I frequently find myself experimenting to pick the right value of `atol` and `rtol` for `np.testing.assert_allclose()`. If I make the tolerance too generous, I risk missing regressions in accuracy, so I usually try to pick the smallest values for which tests pass. This change immediately reveals appropriate values to use for these parameters, so I don't need to guess and check.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/testing/_private/nosetester.py2
-rw-r--r--numpy/testing/_private/utils.py203
-rw-r--r--numpy/testing/tests/test_utils.py115
3 files changed, 185 insertions, 135 deletions
diff --git a/numpy/testing/_private/nosetester.py b/numpy/testing/_private/nosetester.py
index 1728d9d1f..19569a509 100644
--- a/numpy/testing/_private/nosetester.py
+++ b/numpy/testing/_private/nosetester.py
@@ -92,7 +92,7 @@ def run_module_suite(file_to_run=None, argv=None):
Alternatively, calling::
- >>> run_module_suite(file_to_run="numpy/tests/test_matlib.py")
+ >>> run_module_suite(file_to_run="numpy/tests/test_matlib.py") # doctest: +SKIP
from an interpreter will run all the test routine in 'test_matlib.py'.
"""
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index 4059f6ee6..1f7b516b3 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -318,8 +318,9 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
Examples
--------
>>> np.testing.assert_equal([4,5], [4,6])
- ...
- <type 'exceptions.AssertionError'>:
+ Traceback (most recent call last):
+ ...
+ AssertionError:
Items are not equal:
item=1
ACTUAL: 5
@@ -510,20 +511,24 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
>>> import numpy.testing as npt
>>> npt.assert_almost_equal(2.3333333333333, 2.33333334)
>>> npt.assert_almost_equal(2.3333333333333, 2.33333334, decimal=10)
- ...
- <type 'exceptions.AssertionError'>:
- Items are not equal:
- ACTUAL: 2.3333333333333002
- DESIRED: 2.3333333399999998
+ Traceback (most recent call last):
+ ...
+ AssertionError:
+ Arrays are not almost equal to 10 decimals
+ ACTUAL: 2.3333333333333
+ DESIRED: 2.33333334
>>> npt.assert_almost_equal(np.array([1.0,2.3333333333333]),
... np.array([1.0,2.33333334]), decimal=9)
- ...
- <type 'exceptions.AssertionError'>:
- Arrays are not almost equal
- (mismatch 50.0%)
- x: array([ 1. , 2.33333333])
- y: array([ 1. , 2.33333334])
+ Traceback (most recent call last):
+ ...
+ AssertionError:
+ Arrays are not almost equal to 9 decimals
+ Mismatch: 50%
+ Max absolute difference: 6.66669964e-09
+ Max relative difference: 2.85715698e-09
+ x: array([1. , 2.333333333])
+ y: array([1. , 2.33333334])
"""
__tracebackhide__ = True # Hide traceback for py.test
@@ -625,14 +630,15 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
--------
>>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20)
>>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20,
- significant=8)
+ ... significant=8)
>>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20,
- significant=8)
- ...
- <type 'exceptions.AssertionError'>:
+ ... significant=8)
+ Traceback (most recent call last):
+ ...
+ AssertionError:
Items are not equal to 8 significant digits:
- ACTUAL: 1.234567e-021
- DESIRED: 1.2345672000000001e-021
+ ACTUAL: 1.234567e-21
+ DESIRED: 1.2345672e-21
the evaluated condition that raises the exception is
@@ -659,10 +665,10 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
sc_actual = actual/scale
except ZeroDivisionError:
sc_actual = 0.0
- msg = build_err_msg([actual, desired], err_msg,
- header='Items are not equal to %d significant digits:' %
- significant,
- verbose=verbose)
+ msg = build_err_msg(
+ [actual, desired], err_msg,
+ header='Items are not equal to %d significant digits:' % significant,
+ verbose=verbose)
try:
# If one of desired/actual is not finite, handle it specially here:
# check that both are nan if any is a nan, and test for equality
@@ -685,7 +691,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header='', precision=6, equal_nan=True,
equal_inf=True):
__tracebackhide__ = True # Hide traceback for py.test
- from numpy.core import array, isnan, inf, bool_
+ from numpy.core import array, array2string, isnan, inf, bool_, errstate
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -781,15 +787,31 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
reduced = val.ravel()
cond = reduced.all()
reduced = reduced.tolist()
+
# The below comparison is a hack to ensure that fully masked
# results, for which val.ravel().all() returns np.ma.masked,
# do not trigger a failure (np.ma.masked != True evaluates as
# np.ma.masked, which is falsy).
if cond != True:
mismatch = 100.0 * reduced.count(0) / ox.size
- msg = build_err_msg([ox, oy],
- err_msg
- + '\n(mismatch %s%%)' % (mismatch,),
+ remarks = ['Mismatch: {:.3g}%'.format(mismatch)]
+
+ with errstate(invalid='ignore', divide='ignore'):
+ # ignore errors for non-numeric types
+ with contextlib.suppress(TypeError):
+ error = abs(x - y)
+ max_abs_error = error.max()
+ remarks.append('Max absolute difference: '
+ + array2string(max_abs_error))
+
+ # note: this definition of relative error matches that one
+ # used by assert_allclose (found in np.isclose)
+ max_rel_error = (error / abs(y)).max()
+ remarks.append('Max relative difference: '
+ + array2string(max_rel_error))
+
+ err_msg += '\n' + '\n'.join(remarks)
+ msg = build_err_msg([ox, oy], err_msg,
verbose=verbose, header=header,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)
@@ -849,13 +871,15 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
>>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
... [1, np.sqrt(np.pi)**2, np.nan])
- ...
- <type 'exceptions.ValueError'>:
+ Traceback (most recent call last):
+ ...
AssertionError:
Arrays are not equal
- (mismatch 50.0%)
- x: array([ 1. , 3.14159265, NaN])
- y: array([ 1. , 3.14159265, NaN])
+ Mismatch: 33.3%
+ Max absolute difference: 4.4408921e-16
+ Max relative difference: 1.41357986e-16
+ x: array([1. , 3.141593, nan])
+ y: array([1. , 3.141593, nan])
Use `assert_allclose` or one of the nulp (number of floating point values)
functions for these cases instead:
@@ -920,25 +944,29 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
the first assert does not raise an exception
>>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan],
- [1.0,2.333,np.nan])
+ ... [1.0,2.333,np.nan])
>>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
... [1.0,2.33339,np.nan], decimal=5)
- ...
- <type 'exceptions.AssertionError'>:
+ Traceback (most recent call last):
+ ...
AssertionError:
- Arrays are not almost equal
- (mismatch 50.0%)
- x: array([ 1. , 2.33333, NaN])
- y: array([ 1. , 2.33339, NaN])
+ Arrays are not almost equal to 5 decimals
+ Mismatch: 33.3%
+ Max absolute difference: 6.e-05
+ Max relative difference: 2.57136612e-05
+ x: array([1. , 2.33333, nan])
+ y: array([1. , 2.33339, nan])
>>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
... [1.0,2.33333, 5], decimal=5)
- <type 'exceptions.ValueError'>:
- ValueError:
- Arrays are not almost equal
- x: array([ 1. , 2.33333, NaN])
- y: array([ 1. , 2.33333, 5. ])
+ Traceback (most recent call last):
+ ...
+ AssertionError:
+ Arrays are not almost equal to 5 decimals
+ x and y nan location mismatch:
+ x: array([1. , 2.33333, nan])
+ y: array([1. , 2.33333, 5. ])
"""
__tracebackhide__ = True # Hide traceback for py.test
@@ -1019,27 +1047,34 @@ def assert_array_less(x, y, err_msg='', verbose=True):
--------
>>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan])
>>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan])
- ...
- <type 'exceptions.ValueError'>:
+ Traceback (most recent call last):
+ ...
+ AssertionError:
Arrays are not less-ordered
- (mismatch 50.0%)
- x: array([ 1., 1., NaN])
- y: array([ 1., 2., NaN])
+ Mismatch: 33.3%
+ Max absolute difference: 1.
+ Max relative difference: 0.5
+ x: array([ 1., 1., nan])
+ y: array([ 1., 2., nan])
>>> np.testing.assert_array_less([1.0, 4.0], 3)
- ...
- <type 'exceptions.ValueError'>:
+ Traceback (most recent call last):
+ ...
+ AssertionError:
Arrays are not less-ordered
- (mismatch 50.0%)
- x: array([ 1., 4.])
+ Mismatch: 50%
+ Max absolute difference: 2.
+ Max relative difference: 0.66666667
+ x: array([1., 4.])
y: array(3)
>>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4])
- ...
- <type 'exceptions.ValueError'>:
+ Traceback (most recent call last):
+ ...
+ AssertionError:
Arrays are not less-ordered
(shapes (3,), (1,) mismatch)
- x: array([ 1., 2., 3.])
+ x: array([1., 2., 3.])
y: array([4])
"""
@@ -1144,7 +1179,7 @@ def rundocs(filename=None, raise_on_error=True):
argument to the ``test()`` call. For example, to run all tests (including
doctests) for `numpy.lib`:
- >>> np.lib.test(doctests=True) #doctest: +SKIP
+ >>> np.lib.test(doctests=True) # doctest: +SKIP
"""
from numpy.compat import npy_load_module
import doctest
@@ -1326,7 +1361,7 @@ def decorate_methods(cls, decorator, testmatch=None):
return
-def measure(code_str,times=1,label=None):
+def measure(code_str, times=1, label=None):
"""
Return elapsed time for executing code in the namespace of the caller.
@@ -1353,9 +1388,9 @@ def measure(code_str,times=1,label=None):
Examples
--------
- >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)',
- ... times=times)
- >>> print("Time for a single execution : ", etime / times, "s")
+ >>> times = 10
+ >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)', times=times)
+ >>> print("Time for a single execution : ", etime / times, "s") # doctest: +SKIP
Time for a single execution : 0.005 s
"""
@@ -1440,7 +1475,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
--------
>>> x = [1e-5, 1e-3, 1e-1]
>>> y = np.arccos(np.cos(x))
- >>> assert_allclose(x, y, rtol=1e-5, atol=0)
+ >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0)
"""
__tracebackhide__ = True # Hide traceback for py.test
@@ -1894,7 +1929,8 @@ class clear_and_catch_warnings(warnings.catch_warnings):
Examples
--------
>>> import warnings
- >>> with clear_and_catch_warnings(modules=[np.core.fromnumeric]):
+ >>> with np.testing.clear_and_catch_warnings(
+ ... modules=[np.core.fromnumeric]):
... warnings.simplefilter('always')
... warnings.filterwarnings('ignore', module='np.core.fromnumeric')
... # do something that raises a warning but ignore those in
@@ -1975,25 +2011,28 @@ class suppress_warnings(object):
Examples
--------
- >>> with suppress_warnings() as sup:
- ... sup.filter(DeprecationWarning, "Some text")
- ... sup.filter(module=np.ma.core)
- ... log = sup.record(FutureWarning, "Does this occur?")
- ... command_giving_warnings()
- ... # The FutureWarning was given once, the filtered warnings were
- ... # ignored. All other warnings abide outside settings (may be
- ... # printed/error)
- ... assert_(len(log) == 1)
- ... assert_(len(sup.log) == 1) # also stored in log attribute
-
- Or as a decorator:
-
- >>> sup = suppress_warnings()
- >>> sup.filter(module=np.ma.core) # module must match exact
- >>> @sup
- >>> def some_function():
- ... # do something which causes a warning in np.ma.core
- ... pass
+
+ With a context manager::
+
+ with np.testing.suppress_warnings() as sup:
+ sup.filter(DeprecationWarning, "Some text")
+ sup.filter(module=np.ma.core)
+ log = sup.record(FutureWarning, "Does this occur?")
+ command_giving_warnings()
+ # The FutureWarning was given once, the filtered warnings were
+ # ignored. All other warnings abide outside settings (may be
+ # printed/error)
+ assert_(len(log) == 1)
+ assert_(len(sup.log) == 1) # also stored in log attribute
+
+ Or as a decorator::
+
+ sup = np.testing.suppress_warnings()
+ sup.filter(module=np.ma.core) # module must match exactly
+ @sup
+ def some_function():
+ # do something which causes a warning in np.ma.core
+ pass
"""
def __init__(self, forwarding_rule="always"):
self._entered = False
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 43afafaa8..c376a3852 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -327,24 +327,22 @@ class TestEqual(TestArrayEqual):
self._test_not_equal(x, y)
def test_error_message(self):
- try:
+ with pytest.raises(AssertionError) as exc_info:
self._assert_func(np.array([1, 2]), np.array([[1, 2]]))
- except AssertionError as e:
- msg = str(e)
- msg2 = msg.replace("shapes (2L,), (1L, 2L)", "shapes (2,), (1, 2)")
- msg_reference = textwrap.dedent("""\
+ msg = str(exc_info.value)
+ msg2 = msg.replace("shapes (2L,), (1L, 2L)", "shapes (2,), (1, 2)")
+ msg_reference = textwrap.dedent("""\
- Arrays are not equal
+ Arrays are not equal
- (shapes (2,), (1, 2) mismatch)
- x: array([1, 2])
- y: array([[1, 2]])""")
- try:
- assert_equal(msg, msg_reference)
- except AssertionError:
- assert_equal(msg2, msg_reference)
- else:
- raise AssertionError("Did not raise")
+ (shapes (2,), (1, 2) mismatch)
+ x: array([1, 2])
+ y: array([[1, 2]])""")
+
+ try:
+ assert_equal(msg, msg_reference)
+ except AssertionError:
+ assert_equal(msg2, msg_reference)
class TestArrayAlmostEqual(_GenericTest):
@@ -509,38 +507,53 @@ class TestAlmostEqual(_GenericTest):
x = np.array([1.00000000001, 2.00000000002, 3.00003])
y = np.array([1.00000000002, 2.00000000003, 3.00004])
- # test with a different amount of decimal digits
- # note that we only check for the formatting of the arrays themselves
- b = ('x: array([1.00000000001, 2.00000000002, 3.00003 '
- ' ])\n y: array([1.00000000002, 2.00000000003, 3.00004 ])')
- try:
+ # Test with a different amount of decimal digits
+ with pytest.raises(AssertionError) as exc_info:
self._assert_func(x, y, decimal=12)
- except AssertionError as e:
- # remove anything that's not the array string
- assert_equal(str(e).split('%)\n ')[1], b)
-
- # with the default value of decimal digits, only the 3rd element differs
- # note that we only check for the formatting of the arrays themselves
- b = ('x: array([1. , 2. , 3.00003])\n y: array([1. , '
- '2. , 3.00004])')
- try:
+ msgs = str(exc_info.value).split('\n')
+ assert_equal(msgs[3], 'Mismatch: 100%')
+ assert_equal(msgs[4], 'Max absolute difference: 1.e-05')
+ assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06')
+ assert_equal(
+ msgs[6],
+ ' x: array([1.00000000001, 2.00000000002, 3.00003 ])')
+ assert_equal(
+ msgs[7],
+ ' y: array([1.00000000002, 2.00000000003, 3.00004 ])')
+
+ # With the default value of decimal digits, only the 3rd element
+ # differs. Note that we only check for the formatting of the arrays
+ # themselves.
+ with pytest.raises(AssertionError) as exc_info:
self._assert_func(x, y)
- except AssertionError as e:
- # remove anything that's not the array string
- assert_equal(str(e).split('%)\n ')[1], b)
-
- # Check the error message when input includes inf or nan
+ msgs = str(exc_info.value).split('\n')
+ assert_equal(msgs[3], 'Mismatch: 33.3%')
+ assert_equal(msgs[4], 'Max absolute difference: 1.e-05')
+ assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06')
+ assert_equal(msgs[6], ' x: array([1. , 2. , 3.00003])')
+ assert_equal(msgs[7], ' y: array([1. , 2. , 3.00004])')
+
+ # Check the error message when input includes inf
x = np.array([np.inf, 0])
y = np.array([np.inf, 1])
- try:
+ with pytest.raises(AssertionError) as exc_info:
+ self._assert_func(x, y)
+ msgs = str(exc_info.value).split('\n')
+ assert_equal(msgs[3], 'Mismatch: 50%')
+ assert_equal(msgs[4], 'Max absolute difference: 1.')
+ assert_equal(msgs[5], 'Max relative difference: 1.')
+ assert_equal(msgs[6], ' x: array([inf, 0.])')
+ assert_equal(msgs[7], ' y: array([inf, 1.])')
+
+ # Check the error message when dividing by zero
+ x = np.array([1, 2])
+ y = np.array([0, 0])
+ with pytest.raises(AssertionError) as exc_info:
self._assert_func(x, y)
- except AssertionError as e:
- msgs = str(e).split('\n')
- # assert error percentage is 50%
- assert_equal(msgs[3], '(mismatch 50.0%)')
- # assert output array contains inf
- assert_equal(msgs[4], ' x: array([inf, 0.])')
- assert_equal(msgs[5], ' y: array([inf, 1.])')
+ msgs = str(exc_info.value).split('\n')
+ assert_equal(msgs[3], 'Mismatch: 100%')
+ assert_equal(msgs[4], 'Max absolute difference: 2')
+ assert_equal(msgs[5], 'Max relative difference: inf')
def test_subclass_that_cannot_be_bool(self):
# While we cannot guarantee testing functions will always work for
@@ -829,12 +842,12 @@ class TestAssertAllclose(object):
def test_report_fail_percentage(self):
a = np.array([1, 1, 1, 1])
b = np.array([1, 1, 1, 2])
- try:
+
+ with pytest.raises(AssertionError) as exc_info:
assert_allclose(a, b)
- msg = ''
- except AssertionError as exc:
- msg = exc.args[0]
- assert_("mismatch 25.0%" in msg)
+ msg = str(exc_info.value)
+ assert_('Mismatch: 25%\nMax absolute difference: 1\n'
+ 'Max relative difference: 0.5' in msg)
def test_equal_nan(self):
a = np.array([np.nan])
@@ -1117,12 +1130,10 @@ class TestStringEqual(object):
assert_string_equal("hello", "hello")
assert_string_equal("hello\nmultiline", "hello\nmultiline")
- try:
+ with pytest.raises(AssertionError) as exc_info:
assert_string_equal("foo\nbar", "hello\nbar")
- except AssertionError as exc:
- assert_equal(str(exc), "Differences in strings:\n- foo\n+ hello")
- else:
- raise AssertionError("exception not raised")
+ msg = str(exc_info.value)
+ assert_equal(msg, "Differences in strings:\n- foo\n+ hello")
assert_raises(AssertionError,
lambda: assert_string_equal("foo", "hello"))