summaryrefslogtreecommitdiff
path: root/numpy/testing/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r--numpy/testing/utils.py950
1 files changed, 733 insertions, 217 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 3b20f9238..e2162acf9 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -9,32 +9,69 @@ import sys
import re
import operator
import warnings
-from functools import partial
+from functools import partial, wraps
import shutil
import contextlib
-from tempfile import mkdtemp
-from .nosetester import import_nose
+from tempfile import mkdtemp, mkstemp
+from unittest.case import SkipTest
+
from numpy.core import float32, empty, arange, array_repr, ndarray
+from numpy.lib.utils import deprecate
if sys.version_info[0] >= 3:
from io import StringIO
else:
from StringIO import StringIO
-__all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal',
- 'assert_array_equal', 'assert_array_less', 'assert_string_equal',
- 'assert_array_almost_equal', 'assert_raises', 'build_err_msg',
- 'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal',
- 'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure',
- 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex',
- 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings',
- 'assert_allclose', 'IgnoreException']
+__all__ = [
+ 'assert_equal', 'assert_almost_equal', 'assert_approx_equal',
+ 'assert_array_equal', 'assert_array_less', 'assert_string_equal',
+ 'assert_array_almost_equal', 'assert_raises', 'build_err_msg',
+ 'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal',
+ 'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure',
+ 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex',
+ 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings',
+ 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings',
+ 'SkipTest', 'KnownFailureException', 'temppath', 'tempdir', 'IS_PYPY',
+ 'HAS_REFCOUNT', 'suppress_warnings'
+ ]
+
+class KnownFailureException(Exception):
+ '''Raise this exception to mark a test as a known failing test.'''
+ pass
+
+KnownFailureTest = KnownFailureException # backwards compat
verbose = 0
+IS_PYPY = '__pypy__' in sys.modules
+HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None
+
+
+def import_nose():
+ """ Import nose only when needed.
+ """
+ nose_is_good = True
+ minimum_nose_version = (1, 0, 0)
+ try:
+ import nose
+ except ImportError:
+ nose_is_good = False
+ else:
+ if nose.__versioninfo__ < minimum_nose_version:
+ nose_is_good = False
+
+ if not nose_is_good:
+ msg = ('Need nose >= %d.%d.%d for tests - see '
+ 'http://nose.readthedocs.io' %
+ minimum_nose_version)
+ raise ImportError(msg)
+
+ return nose
-def assert_(val, msg='') :
+
+def assert_(val, msg=''):
"""
Assert that works in release mode.
Accepts callable msg to allow deferring evaluation until failure.
@@ -45,13 +82,14 @@ def assert_(val, msg='') :
For documentation on usage, refer to the Python documentation.
"""
- if not val :
+ if not val:
try:
smsg = msg()
except TypeError:
smsg = msg
raise AssertionError(smsg)
+
def gisnan(x):
"""like isnan, but always raise an error if type not supported instead of
returning a TypeError object.
@@ -69,6 +107,7 @@ def gisnan(x):
raise TypeError("isnan not supported for this type")
return st
+
def gisfinite(x):
"""like isfinite, but always raise an error if type not supported instead of
returning a TypeError object.
@@ -87,6 +126,7 @@ def gisfinite(x):
raise TypeError("isfinite not supported for this type")
return st
+
def gisinf(x):
"""like isinf, but always raise an error if type not supported instead of
returning a TypeError object.
@@ -105,6 +145,9 @@ def gisinf(x):
raise TypeError("isinf not supported for this type")
return st
+
+@deprecate(message="numpy.testing.rand is deprecated in numpy 1.11. "
+ "Use numpy.random.rand instead.")
def rand(*args):
"""Returns an array of random numbers with the given shape.
@@ -118,51 +161,11 @@ def rand(*args):
f[i] = random.random()
return results
-if sys.platform[:5]=='linux':
- def jiffies(_proc_pid_stat = '/proc/%s/stat'%(os.getpid()),
- _load_time=[]):
- """ Return number of jiffies (1/100ths of a second) that this
- process has been scheduled in user mode. See man 5 proc. """
- import time
- if not _load_time:
- _load_time.append(time.time())
- try:
- f=open(_proc_pid_stat, 'r')
- l = f.readline().split(' ')
- f.close()
- return int(l[13])
- except:
- return int(100*(time.time()-_load_time[0]))
-
- def memusage(_proc_pid_stat = '/proc/%s/stat'%(os.getpid())):
- """ Return virtual memory size in bytes of the running python.
- """
- try:
- f=open(_proc_pid_stat, 'r')
- l = f.readline().split(' ')
- f.close()
- return int(l[22])
- except:
- return
-else:
- # os.getpid is not in all platforms available.
- # Using time is safe but inaccurate, especially when process
- # was suspended or sleeping.
- def jiffies(_load_time=[]):
- """ Return number of jiffies (1/100ths of a second) that this
- process has been scheduled in user mode. [Emulation with time.time]. """
- import time
- if not _load_time:
- _load_time.append(time.time())
- return int(100*(time.time()-_load_time[0]))
- def memusage():
- """ Return memory usage of running python. [Not implemented]"""
- raise NotImplementedError
-if os.name=='nt':
+if os.name == 'nt':
# Code "stolen" from enthought/debug/memusage.py
- def GetPerformanceAttributes(object, counter, instance = None,
- inum=-1, format = None, machine=None):
+ def GetPerformanceAttributes(object, counter, instance=None,
+ inum=-1, format=None, machine=None):
# NOTE: Many counters require 2 samples to give accurate results,
# including "% Processor Time" (as by definition, at any instant, a
# thread's CPU usage is either 0 or 100). To read counters like this,
@@ -172,8 +175,9 @@ if os.name=='nt':
# My older explanation for this was that the "AddCounter" process forced
# the CPU to 100%, but the above makes more sense :)
import win32pdh
- if format is None: format = win32pdh.PDH_FMT_LONG
- path = win32pdh.MakeCounterPath( (machine, object, instance, None, inum, counter) )
+ if format is None:
+ format = win32pdh.PDH_FMT_LONG
+ path = win32pdh.MakeCounterPath( (machine, object, instance, None, inum, counter))
hq = win32pdh.OpenQuery()
try:
hc = win32pdh.AddCounter(hq, path)
@@ -192,6 +196,66 @@ if os.name=='nt':
return GetPerformanceAttributes("Process", "Virtual Bytes",
processName, instance,
win32pdh.PDH_FMT_LONG, None)
+elif sys.platform[:5] == 'linux':
+
+ def memusage(_proc_pid_stat='/proc/%s/stat' % (os.getpid())):
+ """
+ Return virtual memory size in bytes of the running python.
+
+ """
+ try:
+ f = open(_proc_pid_stat, 'r')
+ l = f.readline().split(' ')
+ f.close()
+ return int(l[22])
+ except:
+ return
+else:
+ def memusage():
+ """
+ Return memory usage of running python. [Not implemented]
+
+ """
+ raise NotImplementedError
+
+
+if sys.platform[:5] == 'linux':
+ def jiffies(_proc_pid_stat='/proc/%s/stat' % (os.getpid()),
+ _load_time=[]):
+ """
+ Return number of jiffies elapsed.
+
+ Return number of jiffies (1/100ths of a second) that this
+ process has been scheduled in user mode. See man 5 proc.
+
+ """
+ import time
+ if not _load_time:
+ _load_time.append(time.time())
+ try:
+ f = open(_proc_pid_stat, 'r')
+ l = f.readline().split(' ')
+ f.close()
+ return int(l[13])
+ except:
+ return int(100*(time.time()-_load_time[0]))
+else:
+ # os.getpid is not in all platforms available.
+ # Using time is safe but inaccurate, especially when process
+ # was suspended or sleeping.
+ def jiffies(_load_time=[]):
+ """
+ Return number of jiffies elapsed.
+
+ Return number of jiffies (1/100ths of a second) that this
+ process has been scheduled in user mode. See man 5 proc.
+
+ """
+ import time
+ if not _load_time:
+ _load_time.append(time.time())
+ return int(100*(time.time()-_load_time[0]))
+
def build_err_msg(arrays, err_msg, header='Items are not equal:',
verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
@@ -212,14 +276,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
try:
r = r_func(a)
- except:
- r = '[repr failed]'
+ except Exception as exc:
+ r = '[repr failed for <{}>: {}]'.format(type(a).__name__, exc)
if r.count('\n') > 3:
r = '\n'.join(r.splitlines()[:3])
r += '...'
msg.append(' %s: %s' % (names[i], r))
return '\n'.join(msg)
+
def assert_equal(actual,desired,err_msg='',verbose=True):
"""
Raises an AssertionError if two objects are not equal.
@@ -255,12 +320,13 @@ def assert_equal(actual,desired,err_msg='',verbose=True):
DESIRED: 6
"""
+ __tracebackhide__ = True # Hide traceback for py.test
if isinstance(desired, dict):
- if not isinstance(actual, dict) :
+ if not isinstance(actual, dict):
raise AssertionError(repr(type(actual)))
assert_equal(len(actual), len(desired), err_msg, verbose)
for k, i in desired.items():
- if k not in actual :
+ if k not in actual:
raise AssertionError(repr(k))
assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg), verbose)
return
@@ -330,8 +396,13 @@ def assert_equal(actual,desired,err_msg='',verbose=True):
pass
# Explicitly use __eq__ for comparison, ticket #2552
- if not (desired == actual):
- raise AssertionError(msg)
+ with suppress_warnings() as sup:
+ # TODO: Better handling will to needed when change happens!
+ sup.filter(DeprecationWarning, ".*NAT ==")
+ sup.filter(FutureWarning, ".*NAT ==")
+ if not (desired == actual):
+ raise AssertionError(msg)
+
def print_assert_equal(test_string, actual, desired):
"""
@@ -361,6 +432,7 @@ def print_assert_equal(test_string, actual, desired):
[0, 2]
"""
+ __tracebackhide__ = True # Hide traceback for py.test
import pprint
if not (actual == desired):
@@ -372,6 +444,7 @@ def print_assert_equal(test_string, actual, desired):
pprint.pprint(desired, msg)
raise AssertionError(msg.getvalue())
+
def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
"""
Raises an AssertionError if two items are not equal up to desired
@@ -382,11 +455,14 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
instead of this function for more consistent floating point
comparisons.
- The test is equivalent to ``abs(desired-actual) < 0.5 * 10**(-decimal)``.
+ The test verifies that the elements of ``actual`` and ``desired`` satisfy.
- Given two objects (numbers or ndarrays), check that all elements of these
- objects are almost equal. An exception is raised at conflicting values.
- For ndarrays this delegates to assert_array_almost_equal
+ ``abs(desired-actual) < 1.5 * 10**(-decimal)``
+
+ That is a looser test than originally documented, but agrees with what the
+ actual implementation in `assert_array_almost_equal` did up to rounding
+ vagaries. An exception is raised at conflicting values. For ndarrays this
+ delegates to assert_array_almost_equal
Parameters
----------
@@ -434,6 +510,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
y: array([ 1. , 2.33333334])
"""
+ __tracebackhide__ = True # Hide traceback for py.test
from numpy.core import ndarray
from numpy.lib import iscomplexobj, real, imag
@@ -486,7 +563,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
return
except (NotImplementedError, TypeError):
pass
- if round(abs(desired - actual), decimal) != 0 :
+ if abs(desired - actual) >= 1.5 * 10.0**(-decimal):
raise AssertionError(_build_err_msg())
@@ -547,10 +624,11 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
True
"""
+ __tracebackhide__ = True # Hide traceback for py.test
import numpy as np
(actual, desired) = map(float, (actual, desired))
- if desired==actual:
+ if desired == actual:
return
# Normalized the numbers to be in range (-10.0,10.0)
# scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual))))))
@@ -583,15 +661,41 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
return
except (TypeError, NotImplementedError):
pass
- if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)) :
+ if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)):
raise AssertionError(msg)
+
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
- header='', precision=6):
+ header='', precision=6, equal_nan=True):
+ __tracebackhide__ = True # Hide traceback for py.test
from numpy.core import array, isnan, isinf, any, all, inf
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
+ def safe_comparison(*args, **kwargs):
+ # There are a number of cases where comparing two arrays hits special
+ # cases in array_richcompare, specifically around strings and void
+ # dtypes. Basically, we just can't do comparisons involving these
+ # types, unless both arrays have exactly the *same* type. So
+ # e.g. you can apply == to two string arrays, or two arrays with
+ # identical structured dtypes. But if you compare a non-string array
+ # to a string array, or two arrays with non-identical structured
+ # dtypes, or anything like that, then internally stuff blows up.
+ # Currently, when things blow up, we just return a scalar False or
+ # True. But we also emit a DeprecationWarning, b/c eventually we
+ # should raise an error here. (Ideally we might even make this work
+ # properly, but since that will require rewriting a bunch of how
+ # ufuncs work then we are not counting on that.)
+ #
+ # The point of this little function is to let the DeprecationWarning
+ # pass (or maybe eventually catch the errors and return False, I
+ # dunno, that's a little trickier and we can figure that out when the
+ # time comes).
+ with suppress_warnings() as sup:
+ sup.filter(DeprecationWarning, ".*==")
+ sup.filter(FutureWarning, ".*==")
+ return comparison(*args, **kwargs)
+
def isnumber(x):
return x.dtype.char in '?bhilqpBHILQPefdgFDG'
@@ -602,13 +706,13 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
assert_array_equal(x_id, y_id)
except AssertionError:
msg = build_err_msg([x, y],
- err_msg + '\nx and y %s location mismatch:' \
+ err_msg + '\nx and y %s location mismatch:'
% (hasval), verbose=verbose, header=header,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)
try:
- cond = (x.shape==() or y.shape==()) or x.shape == y.shape
+ cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
if not cond:
msg = build_err_msg([x, y],
err_msg
@@ -616,36 +720,40 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
y.shape),
verbose=verbose, header=header,
names=('x', 'y'), precision=precision)
- if not cond :
+ if not cond:
raise AssertionError(msg)
if isnumber(x) and isnumber(y):
- x_isnan, y_isnan = isnan(x), isnan(y)
+ if equal_nan:
+ x_isnan, y_isnan = isnan(x), isnan(y)
+ # Validate that NaNs are in the same place
+ if any(x_isnan) or any(y_isnan):
+ chk_same_position(x_isnan, y_isnan, hasval='nan')
+
x_isinf, y_isinf = isinf(x), isinf(y)
- # Validate that the special values are in the same place
- if any(x_isnan) or any(y_isnan):
- chk_same_position(x_isnan, y_isnan, hasval='nan')
+ # Validate that infinite values are in the same place
if any(x_isinf) or any(y_isinf):
# Check +inf and -inf separately, since they are different
chk_same_position(x == +inf, y == +inf, hasval='+inf')
chk_same_position(x == -inf, y == -inf, hasval='-inf')
# Combine all the special values
- x_id, y_id = x_isnan, y_isnan
- x_id |= x_isinf
- y_id |= y_isinf
+ x_id, y_id = x_isinf, y_isinf
+ if equal_nan:
+ x_id |= x_isnan
+ y_id |= y_isnan
# Only do the comparison if actual values are left
if all(x_id):
return
if any(x_id):
- val = comparison(x[~x_id], y[~y_id])
+ val = safe_comparison(x[~x_id], y[~y_id])
else:
- val = comparison(x, y)
+ val = safe_comparison(x, y)
else:
- val = comparison(x, y)
+ val = safe_comparison(x, y)
if isinstance(val, bool):
cond = val
@@ -661,9 +769,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
+ '\n(mismatch %s%%)' % (match,),
verbose=verbose, header=header,
names=('x', 'y'), precision=precision)
- if not cond :
+ if not cond:
raise AssertionError(msg)
- except ValueError as e:
+ except ValueError:
import traceback
efmt = traceback.format_exc()
header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header)
@@ -672,6 +780,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
names=('x', 'y'), precision=precision)
raise ValueError(msg)
+
def assert_array_equal(x, y, err_msg='', verbose=True):
"""
Raises an AssertionError if two array_like objects are not equal.
@@ -738,6 +847,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
verbose=verbose, header='Arrays are not equal')
+
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
"""
Raises an AssertionError if two objects are not equal up to desired
@@ -748,14 +858,16 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
instead of this function for more consistent floating point
comparisons.
- The test verifies identical shapes and verifies values with
- ``abs(desired-actual) < 0.5 * 10**(-decimal)``.
+ The test verifies identical shapes and that the elements of ``actual`` and
+ ``desired`` satisfy.
- Given two array_like objects, check that the shape is equal and all
- elements of these objects are almost equal. An exception is raised at
- shape mismatch or conflicting values. In contrast to the standard usage
- in numpy, NaNs are compared like numbers, no assertion is raised if
- both objects have NaNs in the same positions.
+ ``abs(desired-actual) < 1.5 * 10**(-decimal)``
+
+ That is a looser test than originally documented, but agrees with what the
+ actual implementation did up to rounding vagaries. An exception is raised
+ at shape mismatch or conflicting values. In contrast to the standard usage
+ in numpy, NaNs are compared like numbers, no assertion is raised if both
+ objects have NaNs in the same positions.
Parameters
----------
@@ -808,9 +920,11 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
y: array([ 1. , 2.33333, 5. ])
"""
+ __tracebackhide__ = True # Hide traceback for py.test
from numpy.core import around, number, float_, result_type, array
from numpy.core.numerictypes import issubdtype
from numpy.core.fromnumeric import any as npany
+
def compare(x, y):
try:
if npany(gisinf(x)) or npany( gisinf(y)):
@@ -830,12 +944,12 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
# casting of x later.
dtype = result_type(y, 1.)
y = array(y, dtype=dtype, copy=False, subok=True)
- z = abs(x-y)
+ z = abs(x - y)
if not issubdtype(z.dtype, number):
- z = z.astype(float_) # handle object arrays
+ z = z.astype(float_) # handle object arrays
- return around(z, decimal) <= 10.0**(-decimal)
+ return z < 1.5 * 10.0**(-decimal)
assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
header=('Arrays are not almost equal to %d decimals' % decimal),
@@ -908,13 +1022,16 @@ def assert_array_less(x, y, err_msg='', verbose=True):
y: array([4])
"""
+ __tracebackhide__ = True # Hide traceback for py.test
assert_array_compare(operator.__lt__, x, y, err_msg=err_msg,
verbose=verbose,
header='Arrays are not less-ordered')
+
def runstring(astr, dict):
exec(astr, dict)
+
def assert_string_equal(actual, desired):
"""
Test if two strings are equal.
@@ -942,9 +1059,10 @@ def assert_string_equal(actual, desired):
"""
# delay import of difflib to reduce startup time
+ __tracebackhide__ = True # Hide traceback for py.test
import difflib
- if not isinstance(actual, str) :
+ if not isinstance(actual, str):
raise AssertionError(repr(type(actual)))
if not isinstance(desired, str):
raise AssertionError(repr(type(desired)))
@@ -963,14 +1081,15 @@ def assert_string_equal(actual, desired):
if d2.startswith('? '):
l.append(d2)
d2 = diff.pop(0)
- if not d2.startswith('+ ') :
+ if not d2.startswith('+ '):
raise AssertionError(repr(d2))
l.append(d2)
- d3 = diff.pop(0)
- if d3.startswith('? '):
- l.append(d3)
- else:
- diff.insert(0, d3)
+ if diff:
+ d3 = diff.pop(0)
+ if d3.startswith('? '):
+ l.append(d3)
+ else:
+ diff.insert(0, d3)
if re.match(r'\A'+d2[2:]+r'\Z', d1[2:]):
continue
diff_list.extend(l)
@@ -979,7 +1098,7 @@ def assert_string_equal(actual, desired):
if not diff_list:
return
msg = 'Differences in strings:\n%s' % (''.join(diff_list)).rstrip()
- if actual != desired :
+ if actual != desired:
raise AssertionError(msg)
@@ -1005,17 +1124,13 @@ def rundocs(filename=None, raise_on_error=True):
>>> np.lib.test(doctests=True) #doctest: +SKIP
"""
- import doctest, imp
+ from numpy.compat import npy_load_module
+ import doctest
if filename is None:
f = sys._getframe(1)
filename = f.f_globals['__file__']
name = os.path.splitext(os.path.basename(filename))[0]
- path = [os.path.dirname(filename)]
- file, pathname, description = imp.find_module(name, path)
- try:
- m = imp.load_module(name, file, pathname, description)
- finally:
- file.close()
+ m = npy_load_module(name, filename)
tests = doctest.DocTestFinder().find(m)
runner = doctest.DocTestRunner(verbose=False)
@@ -1038,9 +1153,10 @@ def raises(*args,**kwargs):
return nose.tools.raises(*args,**kwargs)
-def assert_raises(*args,**kwargs):
+def assert_raises(*args, **kwargs):
"""
assert_raises(exception_class, callable, *args, **kwargs)
+ assert_raises(exception_class)
Fail unless an exception of class exception_class is thrown
by callable when invoked with arguments args and keyword
@@ -1049,87 +1165,54 @@ def assert_raises(*args,**kwargs):
deemed to have suffered an error, exactly as for an
unexpected exception.
+ Alternatively, `assert_raises` can be used as a context manager:
+
+ >>> from numpy.testing import assert_raises
+ >>> with assert_raises(ZeroDivisionError):
+ ... 1 / 0
+
+ is equivalent to
+
+ >>> def div(x, y):
+ ... return x / y
+ >>> assert_raises(ZeroDivisionError, div, 1, 0)
+
"""
+ __tracebackhide__ = True # Hide traceback for py.test
nose = import_nose()
return nose.tools.assert_raises(*args,**kwargs)
-assert_raises_regex_impl = None
-
-
-def assert_raises_regex(exception_class, expected_regexp,
- callable_obj=None, *args, **kwargs):
+def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs):
"""
+ assert_raises_regex(exception_class, expected_regexp, callable, *args,
+ **kwargs)
+ assert_raises_regex(exception_class, expected_regexp)
+
Fail unless an exception of class exception_class and with message that
matches expected_regexp is thrown by callable when invoked with arguments
args and keyword arguments kwargs.
+ Alternatively, can be used as a context manager like `assert_raises`.
+
Name of this function adheres to Python 3.2+ reference, but should work in
all versions down to 2.6.
+ Notes
+ -----
+ .. versionadded:: 1.9.0
+
"""
+ __tracebackhide__ = True # Hide traceback for py.test
nose = import_nose()
- global assert_raises_regex_impl
- if assert_raises_regex_impl is None:
- try:
- # Python 3.2+
- assert_raises_regex_impl = nose.tools.assert_raises_regex
- except AttributeError:
- try:
- # 2.7+
- assert_raises_regex_impl = nose.tools.assert_raises_regexp
- except AttributeError:
- # 2.6
-
- # This class is copied from Python2.7 stdlib almost verbatim
- class _AssertRaisesContext(object):
- """A context manager used to implement TestCase.assertRaises* methods."""
-
- def __init__(self, expected, expected_regexp=None):
- self.expected = expected
- self.expected_regexp = expected_regexp
-
- def failureException(self, msg):
- return AssertionError(msg)
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_value, tb):
- if exc_type is None:
- try:
- exc_name = self.expected.__name__
- except AttributeError:
- exc_name = str(self.expected)
- raise self.failureException(
- "{0} not raised".format(exc_name))
- if not issubclass(exc_type, self.expected):
- # let unexpected exceptions pass through
- return False
- self.exception = exc_value # store for later retrieval
- if self.expected_regexp is None:
- return True
-
- expected_regexp = self.expected_regexp
- if isinstance(expected_regexp, basestring):
- expected_regexp = re.compile(expected_regexp)
- if not expected_regexp.search(str(exc_value)):
- raise self.failureException(
- '"%s" does not match "%s"' %
- (expected_regexp.pattern, str(exc_value)))
- return True
-
- def impl(cls, regex, callable_obj, *a, **kw):
- mgr = _AssertRaisesContext(cls, regex)
- if callable_obj is None:
- return mgr
- with mgr:
- callable_obj(*a, **kw)
- assert_raises_regex_impl = impl
-
- return assert_raises_regex_impl(exception_class, expected_regexp,
- callable_obj, *args, **kwargs)
+ if sys.version_info.major >= 3:
+ funcname = nose.tools.assert_raises_regex
+ else:
+ # Only present in Python 2.7, missing from unittest in 2.6
+ funcname = nose.tools.assert_raises_regexp
+
+ return funcname(exception_class, expected_regexp, *args, **kwargs)
def decorate_methods(cls, decorator, testmatch=None):
@@ -1208,7 +1291,7 @@ def measure(code_str,times=1,label=None):
--------
>>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)',
... times=times)
- >>> print "Time for a single execution : ", etime / times, "s"
+ >>> print("Time for a single execution : ", etime / times, "s")
Time for a single execution : 0.005 s
"""
@@ -1226,25 +1309,28 @@ def measure(code_str,times=1,label=None):
elapsed = jiffies() - elapsed
return 0.01*elapsed
+
def _assert_valid_refcount(op):
"""
Check that ufuncs don't mishandle refcount of object `1`.
Used in a few regression tests.
"""
+ if not HAS_REFCOUNT:
+ return True
import numpy as np
- a = np.arange(100 * 100)
+
b = np.arange(100*100).reshape(100, 100)
c = b
-
i = 1
rc = sys.getrefcount(i)
for j in range(15):
d = op(b, c)
-
assert_(sys.getrefcount(i) >= rc)
+ del d # for pyflakes
-def assert_allclose(actual, desired, rtol=1e-7, atol=0,
+
+def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
err_msg='', verbose=True):
"""
Raises an AssertionError if two objects are not equal up to desired
@@ -1266,6 +1352,8 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0,
Relative tolerance.
atol : float, optional
Absolute tolerance.
+ equal_nan : bool, optional.
+ If True, NaNs will compare equal.
err_msg : str, optional
The error message to be printed in case of failure.
verbose : bool, optional
@@ -1287,14 +1375,18 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0,
>>> assert_allclose(x, y, rtol=1e-5, atol=0)
"""
+ __tracebackhide__ = True # Hide traceback for py.test
import numpy as np
+
def compare(x, y):
- return np.core.numeric._allclose_points(x, y, rtol=rtol, atol=atol)
+ return np.core.numeric.isclose(x, y, rtol=rtol, atol=atol,
+ equal_nan=equal_nan)
actual, desired = np.asanyarray(actual), np.asanyarray(desired)
header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
- verbose=verbose, header=header)
+ verbose=verbose, header=header, equal_nan=equal_nan)
+
def assert_array_almost_equal_nulp(x, y, nulp=1):
"""
@@ -1331,7 +1423,7 @@ def assert_array_almost_equal_nulp(x, y, nulp=1):
-----
An assertion is raised if the following condition is not met::
- abs(x - y) <= nulps * spacing(max(abs(x), abs(y)))
+ abs(x - y) <= nulps * spacing(maximum(abs(x), abs(y)))
Examples
--------
@@ -1345,6 +1437,7 @@ def assert_array_almost_equal_nulp(x, y, nulp=1):
AssertionError: X and Y are not equal to 1 ULP (max is 2)
"""
+ __tracebackhide__ = True # Hide traceback for py.test
import numpy as np
ax = np.abs(x)
ay = np.abs(y)
@@ -1357,6 +1450,7 @@ def assert_array_almost_equal_nulp(x, y, nulp=1):
msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp)
raise AssertionError(msg)
+
def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
"""
Check that all items of arrays differ in at most N Units in the Last Place.
@@ -1393,13 +1487,15 @@ def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
>>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a)))
"""
+ __tracebackhide__ = True # Hide traceback for py.test
import numpy as np
ret = nulp_diff(a, b, dtype)
if not np.all(ret <= maxulp):
- raise AssertionError("Arrays are not almost equal up to %g ULP" % \
+ raise AssertionError("Arrays are not almost equal up to %g ULP" %
maxulp)
return ret
+
def nulp_diff(x, y, dtype=None):
"""For each item in x and y, return the number of representable floating
points between them.
@@ -1410,6 +1506,8 @@ def nulp_diff(x, y, dtype=None):
first input array
y : array_like
second input array
+ dtype : dtype, optional
+ Data-type to convert `x` and `y` to if given. Default is None.
Returns
-------
@@ -1440,7 +1538,7 @@ def nulp_diff(x, y, dtype=None):
y = np.array(y, dtype=t)
if not x.shape == y.shape:
- raise ValueError("x and y do not have the same shape: %s - %s" % \
+ raise ValueError("x and y do not have the same shape: %s - %s" %
(x.shape, y.shape))
def _diff(rx, ry, vdt):
@@ -1451,6 +1549,7 @@ def nulp_diff(x, y, dtype=None):
ry = integer_repr(y)
return _diff(rx, ry, t)
+
def _integer_repr(x, vdt, comp):
# Reinterpret binary representation of the float as sign-magnitude:
# take into account two-complement representation
@@ -1458,13 +1557,14 @@ def _integer_repr(x, vdt, comp):
# http://www.cygnus-software.com/papers/comparingfloats/comparingfloats.htm
rx = x.view(vdt)
if not (rx.size == 1):
- rx[rx < 0] = comp - rx[rx<0]
+ rx[rx < 0] = comp - rx[rx < 0]
else:
if rx < 0:
rx = comp - rx
return rx
+
def integer_repr(x):
"""Return the signed-magnitude interpretation of the binary representation of
x."""
@@ -1476,6 +1576,7 @@ def integer_repr(x):
else:
raise ValueError("Unsupported dtype %s" % x.dtype)
+
# The following two classes are copied from python 2.6 warnings module (context
# manager)
class WarningMessage(object):
@@ -1510,6 +1611,7 @@ class WarningMessage(object):
"line : %r}" % (self.message, self._category_name,
self.filename, self.lineno, self.line))
+
class WarningManager(object):
"""
A context manager that copies and restores the warnings filter upon
@@ -1534,6 +1636,7 @@ class WarningManager(object):
It is copied so it can be used in NumPy with older Python versions.
"""
+
def __init__(self, record=False, module=None):
self._record = record
if module is None:
@@ -1551,6 +1654,7 @@ class WarningManager(object):
self._showwarning = self._module.showwarning
if self._record:
log = []
+
def showwarning(*args, **kwargs):
log.append(WarningMessage(*args, **kwargs))
self._module.showwarning = showwarning
@@ -1565,14 +1669,32 @@ class WarningManager(object):
self._module.showwarning = self._showwarning
-def assert_warns(warning_class, func, *args, **kw):
+@contextlib.contextmanager
+def _assert_warns_context(warning_class, name=None):
+ __tracebackhide__ = True # Hide traceback for py.test
+ with suppress_warnings() as sup:
+ l = sup.record(warning_class)
+ yield
+ if not len(l) > 0:
+ name_str = " when calling %s" % name if name is not None else ""
+ raise AssertionError("No warning raised" + name_str)
+
+
+def assert_warns(warning_class, *args, **kwargs):
"""
Fail unless the given callable throws the specified warning.
A warning of class warning_class should be thrown by the callable when
invoked with arguments args and keyword arguments kwargs.
- If a different type of warning is thrown, it will not be caught, and the
- test case will be deemed to have suffered an error.
+ If a different type of warning is thrown, it will not be caught.
+
+ If called with all arguments other than the warning class omitted, may be
+ used as a context manager:
+
+ with assert_warns(SomeWarning):
+ do_something()
+
+ The ability to be used as a context manager is new in NumPy v1.11.0.
.. versionadded:: 1.4.0
@@ -1592,21 +1714,37 @@ def assert_warns(warning_class, func, *args, **kw):
The value returned by `func`.
"""
+ if not args:
+ return _assert_warns_context(warning_class)
+
+ func = args[0]
+ args = args[1:]
+ with _assert_warns_context(warning_class, name=func.__name__):
+ return func(*args, **kwargs)
+
+
+@contextlib.contextmanager
+def _assert_no_warnings_context(name=None):
+ __tracebackhide__ = True # Hide traceback for py.test
with warnings.catch_warnings(record=True) as l:
warnings.simplefilter('always')
- result = func(*args, **kw)
- if not len(l) > 0:
- raise AssertionError("No warning raised when calling %s"
- % func.__name__)
- if not l[0].category is warning_class:
- raise AssertionError("First warning for %s is not a " \
- "%s( is %s)" % (func.__name__, warning_class, l[0]))
- return result
-
-def assert_no_warnings(func, *args, **kw):
+ yield
+ if len(l) > 0:
+ name_str = " when calling %s" % name if name is not None else ""
+ raise AssertionError("Got warnings%s: %s" % (name_str, l))
+
+
+def assert_no_warnings(*args, **kwargs):
"""
Fail if the given callable produces any warnings.
+ If called with all arguments omitted, may be used as a context manager:
+
+ with assert_no_warnings():
+ do_something()
+
+ The ability to be used as a context manager is new in NumPy v1.11.0.
+
.. versionadded:: 1.7.0
Parameters
@@ -1623,13 +1761,13 @@ def assert_no_warnings(func, *args, **kw):
The value returned by `func`.
"""
- with warnings.catch_warnings(record=True) as l:
- warnings.simplefilter('always')
- result = func(*args, **kw)
- if len(l) > 0:
- raise AssertionError("Got warnings when calling %s: %s"
- % (func.__name__, l))
- return result
+ if not args:
+ return _assert_no_warnings_context()
+
+ func = args[0]
+ args = args[1:]
+ with _assert_no_warnings_context(name=func.__name__):
+ return func(*args, **kwargs)
def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
@@ -1662,10 +1800,11 @@ def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
for o in range(3):
for s in range(o + 2, max(o + 3, max_size)):
if type == 'unary':
- inp = lambda : arange(s, dtype=dtype)[o:]
+ inp = lambda: arange(s, dtype=dtype)[o:]
out = empty((s,), dtype=dtype)[o:]
yield out, inp(), ufmt % (o, o, s, dtype, 'out of place')
- yield inp(), inp(), ufmt % (o, o, s, dtype, 'in place')
+ d = inp()
+ yield d, d, ufmt % (o, o, s, dtype, 'in place')
yield out[1:], inp()[:-1], ufmt % \
(o + 1, o, s - 1, dtype, 'out of place')
yield out[:-1], inp()[1:], ufmt % \
@@ -1675,14 +1814,16 @@ def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
yield inp()[1:], inp()[:-1], ufmt % \
(o + 1, o, s - 1, dtype, 'aliased')
if type == 'binary':
- inp1 = lambda :arange(s, dtype=dtype)[o:]
- inp2 = lambda :arange(s, dtype=dtype)[o:]
+ inp1 = lambda: arange(s, dtype=dtype)[o:]
+ inp2 = lambda: arange(s, dtype=dtype)[o:]
out = empty((s,), dtype=dtype)[o:]
yield out, inp1(), inp2(), bfmt % \
(o, o, o, s, dtype, 'out of place')
- yield inp1(), inp1(), inp2(), bfmt % \
+ d = inp1()
+ yield d, d, inp2(), bfmt % \
(o, o, o, s, dtype, 'in place1')
- yield inp2(), inp1(), inp2(), bfmt % \
+ d = inp2()
+ yield d, inp1(), d, bfmt % \
(o, o, o, s, dtype, 'in place2')
yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % \
(o + 1, o, o, s - 1, dtype, 'out of place')
@@ -1711,5 +1852,380 @@ def tempdir(*args, **kwargs):
"""
tmpdir = mkdtemp(*args, **kwargs)
- yield tmpdir
- shutil.rmtree(tmpdir)
+ try:
+ yield tmpdir
+ finally:
+ shutil.rmtree(tmpdir)
+
+
+@contextlib.contextmanager
+def temppath(*args, **kwargs):
+ """Context manager for temporary files.
+
+ Context manager that returns the path to a closed temporary file. Its
+ parameters are the same as for tempfile.mkstemp and are passed directly
+ to that function. The underlying file is removed when the context is
+ exited, so it should be closed at that time.
+
+ Windows does not allow a temporary file to be opened if it is already
+ open, so the underlying file must be closed after opening before it
+ can be opened again.
+
+ """
+ fd, path = mkstemp(*args, **kwargs)
+ os.close(fd)
+ try:
+ yield path
+ finally:
+ os.remove(path)
+
+
+class clear_and_catch_warnings(warnings.catch_warnings):
+ """ Context manager that resets warning registry for catching warnings
+
+ Warnings can be slippery, because, whenever a warning is triggered, Python
+ adds a ``__warningregistry__`` member to the *calling* module. This makes
+ it impossible to retrigger the warning in this module, whatever you put in
+ the warnings filters. This context manager accepts a sequence of `modules`
+ as a keyword argument to its constructor and:
+
+ * stores and removes any ``__warningregistry__`` entries in given `modules`
+ on entry;
+ * resets ``__warningregistry__`` to its previous state on exit.
+
+ This makes it possible to trigger any warning afresh inside the context
+ manager without disturbing the state of warnings outside.
+
+ For compatibility with Python 3.0, please consider all arguments to be
+ keyword-only.
+
+ Parameters
+ ----------
+ record : bool, optional
+ Specifies whether warnings should be captured by a custom
+ implementation of ``warnings.showwarning()`` and be appended to a list
+ returned by the context manager. Otherwise None is returned by the
+ context manager. The objects appended to the list are arguments whose
+ attributes mirror the arguments to ``showwarning()``.
+ modules : sequence, optional
+ Sequence of modules for which to reset warnings registry on entry and
+ restore on exit. To work correctly, all 'ignore' filters should
+ filter by one of these modules.
+
+ Examples
+ --------
+ >>> import warnings
+ >>> with clear_and_catch_warnings(modules=[np.core.fromnumeric]):
+ ... warnings.simplefilter('always')
+ ... warnings.filterwarnings('ignore', module='np.core.fromnumeric')
+ ... # do something that raises a warning but ignore those in
+ ... # np.core.fromnumeric
+ """
+ class_modules = ()
+
+ def __init__(self, record=False, modules=()):
+ self.modules = set(modules).union(self.class_modules)
+ self._warnreg_copies = {}
+ super(clear_and_catch_warnings, self).__init__(record=record)
+
+ def __enter__(self):
+ for mod in self.modules:
+ if hasattr(mod, '__warningregistry__'):
+ mod_reg = mod.__warningregistry__
+ self._warnreg_copies[mod] = mod_reg.copy()
+ mod_reg.clear()
+ return super(clear_and_catch_warnings, self).__enter__()
+
+ def __exit__(self, *exc_info):
+ super(clear_and_catch_warnings, self).__exit__(*exc_info)
+ for mod in self.modules:
+ if hasattr(mod, '__warningregistry__'):
+ mod.__warningregistry__.clear()
+ if mod in self._warnreg_copies:
+ mod.__warningregistry__.update(self._warnreg_copies[mod])
+
+
+class suppress_warnings(object):
+ """
+ Context manager and decorator doing much the same as
+ ``warnings.catch_warnings``.
+
+ However, it also provides a filter mechanism to work around
+ http://bugs.python.org/issue4180.
+
+ This bug causes Python before 3.4 to not reliably show warnings again
+ after they have been ignored once (even within catch_warnings). It
+ means that no "ignore" filter can be used easily, since following
+ tests might need to see the warning. Additionally it allows easier
+ specificity for testing warnings and can be nested.
+
+ Parameters
+ ----------
+ forwarding_rule : str, optional
+ One of "always", "once", "module", or "location". Analogous to
+ the usual warnings module filter mode, it is useful to reduce
+ noise mostly on the outmost level. Unsuppressed and unrecorded
+ warnings will be forwarded based on this rule. Defaults to "always".
+ "location" is equivalent to the warnings "default", match by exact
+ location the warning warning originated from.
+
+ Notes
+ -----
+ Filters added inside the context manager will be discarded again
+ when leaving it. Upon entering all filters defined outside a
+ context will be applied automatically.
+
+ When a recording filter is added, matching warnings are stored in the
+ ``log`` attribute as well as in the list returned by ``record``.
+
+ If filters are added and the ``module`` keyword is given, the
+ warning registry of this module will additionally be cleared when
+ applying it, entering the context, or exiting it. This could cause
+ warnings to appear a second time after leaving the context if they
+ were configured to be printed once (default) and were already
+ printed before the context was entered.
+
+ Nesting this context manager will work as expected when the
+ forwarding rule is "always" (default). Unfiltered and unrecorded
+ warnings will be passed out and be matched by the outer level.
+ On the outmost level they will be printed (or caught by another
+ warnings context). The forwarding rule argument can modify this
+ behaviour.
+
+ Like ``catch_warnings`` this context manager is not threadsafe.
+
+ Examples
+ --------
+ >>> with suppress_warnings() as sup:
+ ... sup.filter(DeprecationWarning, "Some text")
+ ... sup.filter(module=np.ma.core)
+ ... log = sup.record(FutureWarning, "Does this occur?")
+ ... command_giving_warnings()
+ ... # The FutureWarning was given once, the filtered warnings were
+ ... # ignored. All other warnings abide outside settings (may be
+ ... # printed/error)
+ ... assert_(len(log) == 1)
+ ... assert_(len(sup.log) == 1) # also stored in log attribute
+
+ Or as a decorator:
+
+ >>> sup = suppress_warnings()
+ >>> sup.filter(module=np.ma.core) # module must match exact
+ >>> @sup
+ >>> def some_function():
+ ... # do something which causes a warning in np.ma.core
+ ... pass
+ """
+ def __init__(self, forwarding_rule="always"):
+ self._entered = False
+
+ # Suppressions are either instance or defined inside one with block:
+ self._suppressions = []
+
+ if forwarding_rule not in {"always", "module", "once", "location"}:
+ raise ValueError("unsupported forwarding rule.")
+ self._forwarding_rule = forwarding_rule
+
+ def _clear_registries(self):
+ if hasattr(warnings, "_filters_mutated"):
+ # clearing the registry should not be necessary on new pythons,
+ # instead the filters should be mutated.
+ warnings._filters_mutated()
+ return
+ # Simply clear the registry, this should normally be harmless,
+ # note that on new pythons it would be invalidated anyway.
+ for module in self._tmp_modules:
+ if hasattr(module, "__warningregistry__"):
+ module.__warningregistry__.clear()
+
+ def _filter(self, category=Warning, message="", module=None, record=False):
+ if record:
+ record = [] # The log where to store warnings
+ else:
+ record = None
+ if self._entered:
+ if module is None:
+ warnings.filterwarnings(
+ "always", category=category, message=message)
+ else:
+ module_regex = module.__name__.replace('.', '\.') + '$'
+ warnings.filterwarnings(
+ "always", category=category, message=message,
+ module=module_regex)
+ self._tmp_modules.add(module)
+ self._clear_registries()
+
+ self._tmp_suppressions.append(
+ (category, message, re.compile(message, re.I), module, record))
+ else:
+ self._suppressions.append(
+ (category, message, re.compile(message, re.I), module, record))
+
+ return record
+
+ def filter(self, category=Warning, message="", module=None):
+ """
+ Add a new suppressing filter or apply it if the state is entered.
+
+ Parameters
+ ----------
+ category : class, optional
+ Warning class to filter
+ message : string, optional
+ Regular expression matching the warning message.
+ module : module, optional
+ Module to filter for. Note that the module (and its file)
+ must match exactly and cannot be a submodule. This may make
+ it unreliable for external modules.
+
+ Notes
+ -----
+ When added within a context, filters are only added inside
+ the context and will be forgotten when the context is exited.
+ """
+ self._filter(category=category, message=message, module=module,
+ record=False)
+
+ def record(self, category=Warning, message="", module=None):
+ """
+ Append a new recording filter or apply it if the state is entered.
+
+ All warnings matching will be appended to the ``log`` attribute.
+
+ Parameters
+ ----------
+ category : class, optional
+ Warning class to filter
+ message : string, optional
+ Regular expression matching the warning message.
+ module : module, optional
+ Module to filter for. Note that the module (and its file)
+ must match exactly and cannot be a submodule. This may make
+ it unreliable for external modules.
+
+ Returns
+ -------
+ log : list
+ A list which will be filled with all matched warnings.
+
+ Notes
+ -----
+ When added within a context, filters are only added inside
+ the context and will be forgotten when the context is exited.
+ """
+ return self._filter(category=category, message=message, module=module,
+ record=True)
+
+ def __enter__(self):
+ if self._entered:
+ raise RuntimeError("cannot enter suppress_warnings twice.")
+
+ self._orig_show = warnings.showwarning
+ if hasattr(warnings, "_showwarnmsg"):
+ self._orig_showmsg = warnings._showwarnmsg
+ self._filters = warnings.filters
+ warnings.filters = self._filters[:]
+
+ self._entered = True
+ self._tmp_suppressions = []
+ self._tmp_modules = set()
+ self._forwarded = set()
+
+ self.log = [] # reset global log (no need to keep same list)
+
+ for cat, mess, _, mod, log in self._suppressions:
+ if log is not None:
+ del log[:] # clear the log
+ if mod is None:
+ warnings.filterwarnings(
+ "always", category=cat, message=mess)
+ else:
+ module_regex = mod.__name__.replace('.', '\.') + '$'
+ warnings.filterwarnings(
+ "always", category=cat, message=mess,
+ module=module_regex)
+ self._tmp_modules.add(mod)
+ warnings.showwarning = self._showwarning
+ if hasattr(warnings, "_showwarnmsg"):
+ warnings._showwarnmsg = self._showwarnmsg
+ self._clear_registries()
+
+ return self
+
+ def __exit__(self, *exc_info):
+ warnings.showwarning = self._orig_show
+ if hasattr(warnings, "_showwarnmsg"):
+ warnings._showwarnmsg = self._orig_showmsg
+ warnings.filters = self._filters
+ self._clear_registries()
+ self._entered = False
+ del self._orig_show
+ del self._filters
+
+ def _showwarnmsg(self, msg):
+ self._showwarning(msg.message, msg.category, msg.filename, msg.lineno,
+ msg.file, msg.line, use_warnmsg=msg)
+
+ def _showwarning(self, message, category, filename, lineno,
+ *args, **kwargs):
+ use_warnmsg = kwargs.pop("use_warnmsg", None)
+ for cat, _, pattern, mod, rec in (
+ self._suppressions + self._tmp_suppressions)[::-1]:
+ if (issubclass(category, cat) and
+ pattern.match(message.args[0]) is not None):
+ if mod is None:
+ # Message and category match, either recorded or ignored
+ if rec is not None:
+ msg = WarningMessage(message, category, filename,
+ lineno, **kwargs)
+ self.log.append(msg)
+ rec.append(msg)
+ return
+ # Use startswith, because warnings strips the c or o from
+ # .pyc/.pyo files.
+ elif mod.__file__.startswith(filename):
+ # The message and module (filename) match
+ if rec is not None:
+ msg = WarningMessage(message, category, filename,
+ lineno, **kwargs)
+ self.log.append(msg)
+ rec.append(msg)
+ return
+
+ # There is no filter in place, so pass to the outside handler
+ # unless we should only pass it once
+ if self._forwarding_rule == "always":
+ if use_warnmsg is None:
+ self._orig_show(message, category, filename, lineno,
+ *args, **kwargs)
+ else:
+ self._orig_showmsg(use_warnmsg)
+ return
+
+ if self._forwarding_rule == "once":
+ signature = (message.args, category)
+ elif self._forwarding_rule == "module":
+ signature = (message.args, category, filename)
+ elif self._forwarding_rule == "location":
+ signature = (message.args, category, filename, lineno)
+
+ if signature in self._forwarded:
+ return
+ self._forwarded.add(signature)
+ if use_warnmsg is None:
+ self._orig_show(message, category, filename, lineno, *args,
+ **kwargs)
+ else:
+ self._orig_showmsg(use_warnmsg)
+
+ def __call__(self, func):
+ """
+ Function decorator to apply certain suppressions to a whole
+ function.
+ """
+ @wraps(func)
+ def new_func(*args, **kwargs):
+ with self:
+ return func(*args, **kwargs)
+
+ return new_func