diff options
author | Stephan Hoyer <shoyer@climate.com> | 2016-01-14 22:06:15 -0800 |
---|---|---|
committer | Stephan Hoyer <shoyer@climate.com> | 2016-01-14 22:12:56 -0800 |
commit | d588b48a0e2fd4a78cadc1336571f59ba6be83c6 (patch) | |
tree | 3e1b59a65862916c40385ddb998b6df6e476ae38 /numpy/testing/utils.py | |
parent | aa6335c494e4807d65404d91e0e9d25a7d2fe338 (diff) | |
download | numpy-d588b48a0e2fd4a78cadc1336571f59ba6be83c6.tar.gz |
TST: Make assert_warns an optional contextmanager
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r-- | numpy/testing/utils.py | 69 |
1 files changed, 51 insertions, 18 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 8e71a3399..72105ca31 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -1706,7 +1706,22 @@ class WarningManager(object): self._module.showwarning = self._showwarning -def assert_warns(warning_class, func, *args, **kw): +@contextlib.contextmanager +def _assert_warns_context(warning_class, name=None): + __tracebackhide__ = True # Hide traceback for py.test + with warnings.catch_warnings(record=True) as l: + warnings.simplefilter('always') + yield + if not len(l) > 0: + name_str = " when calling %s" % name if name is not None else "" + raise AssertionError("No warning raised" + name_str) + if not l[0].category is warning_class: + name_str = "%s " % name if name is not None else "" + raise AssertionError("First warning %sis not a %s (is %s)" + % (name_str, warning_class, l[0])) + + +def assert_warns(warning_class, *args, **kwargs): """ Fail unless the given callable throws the specified warning. @@ -1715,6 +1730,12 @@ def assert_warns(warning_class, func, *args, **kw): If a different type of warning is thrown, it will not be caught, and the test case will be deemed to have suffered an error. + If called with all arguments other than the warning class omitted, may be + used as a context manager: + + with assert_warns(SomeWarning): + do_something() + .. versionadded:: 1.4.0 Parameters @@ -1733,22 +1754,35 @@ def assert_warns(warning_class, func, *args, **kw): The value returned by `func`. """ + if not args: + return _assert_warns_context(warning_class) + + func = args[0] + args = args[1:] + with _assert_warns_context(warning_class, name=func.__name__): + return func(*args, **kwargs) + + +@contextlib.contextmanager +def _assert_no_warnings_context(name=None): __tracebackhide__ = True # Hide traceback for py.test with warnings.catch_warnings(record=True) as l: warnings.simplefilter('always') - result = func(*args, **kw) - if not len(l) > 0: - raise AssertionError("No warning raised when calling %s" - % func.__name__) - if not l[0].category is warning_class: - raise AssertionError("First warning for %s is not a " - "%s( is %s)" % (func.__name__, warning_class, l[0])) - return result + yield + if len(l) > 0: + name_str = " when calling %s" % name if name is not None else "" + raise AssertionError("Got warnings%s: %s" % (name_str, l)) + -def assert_no_warnings(func, *args, **kw): +def assert_no_warnings(*args, **kwargs): """ Fail if the given callable produces any warnings. + If called with all arguments omitted, may be used as a context manager: + + with assert_no_warnings(): + do_something() + .. versionadded:: 1.7.0 Parameters @@ -1765,14 +1799,13 @@ def assert_no_warnings(func, *args, **kw): The value returned by `func`. """ - __tracebackhide__ = True # Hide traceback for py.test - with warnings.catch_warnings(record=True) as l: - warnings.simplefilter('always') - result = func(*args, **kw) - if len(l) > 0: - raise AssertionError("Got warnings when calling %s: %s" - % (func.__name__, l)) - return result + if not args: + return _assert_no_warnings_context() + + func = args[0] + args = args[1:] + with _assert_no_warnings_context(name=func.__name__): + return func(*args, **kwargs) def _gen_alignment_data(dtype=float32, type='binary', max_size=24): |