summaryrefslogtreecommitdiff
path: root/numpy/testing/utils.py
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@climate.com>2016-01-14 22:06:15 -0800
committerStephan Hoyer <shoyer@climate.com>2016-01-14 22:12:56 -0800
commitd588b48a0e2fd4a78cadc1336571f59ba6be83c6 (patch)
tree3e1b59a65862916c40385ddb998b6df6e476ae38 /numpy/testing/utils.py
parentaa6335c494e4807d65404d91e0e9d25a7d2fe338 (diff)
downloadnumpy-d588b48a0e2fd4a78cadc1336571f59ba6be83c6.tar.gz
TST: Make assert_warns an optional contextmanager
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r--numpy/testing/utils.py69
1 files changed, 51 insertions, 18 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 8e71a3399..72105ca31 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -1706,7 +1706,22 @@ class WarningManager(object):
self._module.showwarning = self._showwarning
-def assert_warns(warning_class, func, *args, **kw):
+@contextlib.contextmanager
+def _assert_warns_context(warning_class, name=None):
+ __tracebackhide__ = True # Hide traceback for py.test
+ with warnings.catch_warnings(record=True) as l:
+ warnings.simplefilter('always')
+ yield
+ if not len(l) > 0:
+ name_str = " when calling %s" % name if name is not None else ""
+ raise AssertionError("No warning raised" + name_str)
+ if not l[0].category is warning_class:
+ name_str = "%s " % name if name is not None else ""
+ raise AssertionError("First warning %sis not a %s (is %s)"
+ % (name_str, warning_class, l[0]))
+
+
+def assert_warns(warning_class, *args, **kwargs):
"""
Fail unless the given callable throws the specified warning.
@@ -1715,6 +1730,12 @@ def assert_warns(warning_class, func, *args, **kw):
If a different type of warning is thrown, it will not be caught, and the
test case will be deemed to have suffered an error.
+ If called with all arguments other than the warning class omitted, may be
+ used as a context manager:
+
+ with assert_warns(SomeWarning):
+ do_something()
+
.. versionadded:: 1.4.0
Parameters
@@ -1733,22 +1754,35 @@ def assert_warns(warning_class, func, *args, **kw):
The value returned by `func`.
"""
+ if not args:
+ return _assert_warns_context(warning_class)
+
+ func = args[0]
+ args = args[1:]
+ with _assert_warns_context(warning_class, name=func.__name__):
+ return func(*args, **kwargs)
+
+
+@contextlib.contextmanager
+def _assert_no_warnings_context(name=None):
__tracebackhide__ = True # Hide traceback for py.test
with warnings.catch_warnings(record=True) as l:
warnings.simplefilter('always')
- result = func(*args, **kw)
- if not len(l) > 0:
- raise AssertionError("No warning raised when calling %s"
- % func.__name__)
- if not l[0].category is warning_class:
- raise AssertionError("First warning for %s is not a "
- "%s( is %s)" % (func.__name__, warning_class, l[0]))
- return result
+ yield
+ if len(l) > 0:
+ name_str = " when calling %s" % name if name is not None else ""
+ raise AssertionError("Got warnings%s: %s" % (name_str, l))
+
-def assert_no_warnings(func, *args, **kw):
+def assert_no_warnings(*args, **kwargs):
"""
Fail if the given callable produces any warnings.
+ If called with all arguments omitted, may be used as a context manager:
+
+ with assert_no_warnings():
+ do_something()
+
.. versionadded:: 1.7.0
Parameters
@@ -1765,14 +1799,13 @@ def assert_no_warnings(func, *args, **kw):
The value returned by `func`.
"""
- __tracebackhide__ = True # Hide traceback for py.test
- with warnings.catch_warnings(record=True) as l:
- warnings.simplefilter('always')
- result = func(*args, **kw)
- if len(l) > 0:
- raise AssertionError("Got warnings when calling %s: %s"
- % (func.__name__, l))
- return result
+ if not args:
+ return _assert_no_warnings_context()
+
+ func = args[0]
+ args = args[1:]
+ with _assert_no_warnings_context(name=func.__name__):
+ return func(*args, **kwargs)
def _gen_alignment_data(dtype=float32, type='binary', max_size=24):