diff options
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r-- | numpy/testing/utils.py | 950 |
1 files changed, 733 insertions, 217 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 3b20f9238..e2162acf9 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -9,32 +9,69 @@ import sys import re import operator import warnings -from functools import partial +from functools import partial, wraps import shutil import contextlib -from tempfile import mkdtemp -from .nosetester import import_nose +from tempfile import mkdtemp, mkstemp +from unittest.case import SkipTest + from numpy.core import float32, empty, arange, array_repr, ndarray +from numpy.lib.utils import deprecate if sys.version_info[0] >= 3: from io import StringIO else: from StringIO import StringIO -__all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal', - 'assert_array_equal', 'assert_array_less', 'assert_string_equal', - 'assert_array_almost_equal', 'assert_raises', 'build_err_msg', - 'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal', - 'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure', - 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex', - 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings', - 'assert_allclose', 'IgnoreException'] +__all__ = [ + 'assert_equal', 'assert_almost_equal', 'assert_approx_equal', + 'assert_array_equal', 'assert_array_less', 'assert_string_equal', + 'assert_array_almost_equal', 'assert_raises', 'build_err_msg', + 'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal', + 'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure', + 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex', + 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings', + 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings', + 'SkipTest', 'KnownFailureException', 'temppath', 'tempdir', 'IS_PYPY', + 'HAS_REFCOUNT', 'suppress_warnings' + ] + +class KnownFailureException(Exception): + '''Raise this exception to mark a test as a known failing test.''' + pass + +KnownFailureTest = KnownFailureException # backwards compat verbose = 0 +IS_PYPY = '__pypy__' in sys.modules +HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None + + +def import_nose(): + """ Import nose only when needed. + """ + nose_is_good = True + minimum_nose_version = (1, 0, 0) + try: + import nose + except ImportError: + nose_is_good = False + else: + if nose.__versioninfo__ < minimum_nose_version: + nose_is_good = False + + if not nose_is_good: + msg = ('Need nose >= %d.%d.%d for tests - see ' + 'http://nose.readthedocs.io' % + minimum_nose_version) + raise ImportError(msg) + + return nose -def assert_(val, msg='') : + +def assert_(val, msg=''): """ Assert that works in release mode. Accepts callable msg to allow deferring evaluation until failure. @@ -45,13 +82,14 @@ def assert_(val, msg='') : For documentation on usage, refer to the Python documentation. """ - if not val : + if not val: try: smsg = msg() except TypeError: smsg = msg raise AssertionError(smsg) + def gisnan(x): """like isnan, but always raise an error if type not supported instead of returning a TypeError object. @@ -69,6 +107,7 @@ def gisnan(x): raise TypeError("isnan not supported for this type") return st + def gisfinite(x): """like isfinite, but always raise an error if type not supported instead of returning a TypeError object. @@ -87,6 +126,7 @@ def gisfinite(x): raise TypeError("isfinite not supported for this type") return st + def gisinf(x): """like isinf, but always raise an error if type not supported instead of returning a TypeError object. @@ -105,6 +145,9 @@ def gisinf(x): raise TypeError("isinf not supported for this type") return st + +@deprecate(message="numpy.testing.rand is deprecated in numpy 1.11. " + "Use numpy.random.rand instead.") def rand(*args): """Returns an array of random numbers with the given shape. @@ -118,51 +161,11 @@ def rand(*args): f[i] = random.random() return results -if sys.platform[:5]=='linux': - def jiffies(_proc_pid_stat = '/proc/%s/stat'%(os.getpid()), - _load_time=[]): - """ Return number of jiffies (1/100ths of a second) that this - process has been scheduled in user mode. See man 5 proc. """ - import time - if not _load_time: - _load_time.append(time.time()) - try: - f=open(_proc_pid_stat, 'r') - l = f.readline().split(' ') - f.close() - return int(l[13]) - except: - return int(100*(time.time()-_load_time[0])) - - def memusage(_proc_pid_stat = '/proc/%s/stat'%(os.getpid())): - """ Return virtual memory size in bytes of the running python. - """ - try: - f=open(_proc_pid_stat, 'r') - l = f.readline().split(' ') - f.close() - return int(l[22]) - except: - return -else: - # os.getpid is not in all platforms available. - # Using time is safe but inaccurate, especially when process - # was suspended or sleeping. - def jiffies(_load_time=[]): - """ Return number of jiffies (1/100ths of a second) that this - process has been scheduled in user mode. [Emulation with time.time]. """ - import time - if not _load_time: - _load_time.append(time.time()) - return int(100*(time.time()-_load_time[0])) - def memusage(): - """ Return memory usage of running python. [Not implemented]""" - raise NotImplementedError -if os.name=='nt': +if os.name == 'nt': # Code "stolen" from enthought/debug/memusage.py - def GetPerformanceAttributes(object, counter, instance = None, - inum=-1, format = None, machine=None): + def GetPerformanceAttributes(object, counter, instance=None, + inum=-1, format=None, machine=None): # NOTE: Many counters require 2 samples to give accurate results, # including "% Processor Time" (as by definition, at any instant, a # thread's CPU usage is either 0 or 100). To read counters like this, @@ -172,8 +175,9 @@ if os.name=='nt': # My older explanation for this was that the "AddCounter" process forced # the CPU to 100%, but the above makes more sense :) import win32pdh - if format is None: format = win32pdh.PDH_FMT_LONG - path = win32pdh.MakeCounterPath( (machine, object, instance, None, inum, counter) ) + if format is None: + format = win32pdh.PDH_FMT_LONG + path = win32pdh.MakeCounterPath( (machine, object, instance, None, inum, counter)) hq = win32pdh.OpenQuery() try: hc = win32pdh.AddCounter(hq, path) @@ -192,6 +196,66 @@ if os.name=='nt': return GetPerformanceAttributes("Process", "Virtual Bytes", processName, instance, win32pdh.PDH_FMT_LONG, None) +elif sys.platform[:5] == 'linux': + + def memusage(_proc_pid_stat='/proc/%s/stat' % (os.getpid())): + """ + Return virtual memory size in bytes of the running python. + + """ + try: + f = open(_proc_pid_stat, 'r') + l = f.readline().split(' ') + f.close() + return int(l[22]) + except: + return +else: + def memusage(): + """ + Return memory usage of running python. [Not implemented] + + """ + raise NotImplementedError + + +if sys.platform[:5] == 'linux': + def jiffies(_proc_pid_stat='/proc/%s/stat' % (os.getpid()), + _load_time=[]): + """ + Return number of jiffies elapsed. + + Return number of jiffies (1/100ths of a second) that this + process has been scheduled in user mode. See man 5 proc. + + """ + import time + if not _load_time: + _load_time.append(time.time()) + try: + f = open(_proc_pid_stat, 'r') + l = f.readline().split(' ') + f.close() + return int(l[13]) + except: + return int(100*(time.time()-_load_time[0])) +else: + # os.getpid is not in all platforms available. + # Using time is safe but inaccurate, especially when process + # was suspended or sleeping. + def jiffies(_load_time=[]): + """ + Return number of jiffies elapsed. + + Return number of jiffies (1/100ths of a second) that this + process has been scheduled in user mode. See man 5 proc. + + """ + import time + if not _load_time: + _load_time.append(time.time()) + return int(100*(time.time()-_load_time[0])) + def build_err_msg(arrays, err_msg, header='Items are not equal:', verbose=True, names=('ACTUAL', 'DESIRED'), precision=8): @@ -212,14 +276,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:', try: r = r_func(a) - except: - r = '[repr failed]' + except Exception as exc: + r = '[repr failed for <{}>: {}]'.format(type(a).__name__, exc) if r.count('\n') > 3: r = '\n'.join(r.splitlines()[:3]) r += '...' msg.append(' %s: %s' % (names[i], r)) return '\n'.join(msg) + def assert_equal(actual,desired,err_msg='',verbose=True): """ Raises an AssertionError if two objects are not equal. @@ -255,12 +320,13 @@ def assert_equal(actual,desired,err_msg='',verbose=True): DESIRED: 6 """ + __tracebackhide__ = True # Hide traceback for py.test if isinstance(desired, dict): - if not isinstance(actual, dict) : + if not isinstance(actual, dict): raise AssertionError(repr(type(actual))) assert_equal(len(actual), len(desired), err_msg, verbose) for k, i in desired.items(): - if k not in actual : + if k not in actual: raise AssertionError(repr(k)) assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg), verbose) return @@ -330,8 +396,13 @@ def assert_equal(actual,desired,err_msg='',verbose=True): pass # Explicitly use __eq__ for comparison, ticket #2552 - if not (desired == actual): - raise AssertionError(msg) + with suppress_warnings() as sup: + # TODO: Better handling will to needed when change happens! + sup.filter(DeprecationWarning, ".*NAT ==") + sup.filter(FutureWarning, ".*NAT ==") + if not (desired == actual): + raise AssertionError(msg) + def print_assert_equal(test_string, actual, desired): """ @@ -361,6 +432,7 @@ def print_assert_equal(test_string, actual, desired): [0, 2] """ + __tracebackhide__ = True # Hide traceback for py.test import pprint if not (actual == desired): @@ -372,6 +444,7 @@ def print_assert_equal(test_string, actual, desired): pprint.pprint(desired, msg) raise AssertionError(msg.getvalue()) + def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): """ Raises an AssertionError if two items are not equal up to desired @@ -382,11 +455,14 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): instead of this function for more consistent floating point comparisons. - The test is equivalent to ``abs(desired-actual) < 0.5 * 10**(-decimal)``. + The test verifies that the elements of ``actual`` and ``desired`` satisfy. - Given two objects (numbers or ndarrays), check that all elements of these - objects are almost equal. An exception is raised at conflicting values. - For ndarrays this delegates to assert_array_almost_equal + ``abs(desired-actual) < 1.5 * 10**(-decimal)`` + + That is a looser test than originally documented, but agrees with what the + actual implementation in `assert_array_almost_equal` did up to rounding + vagaries. An exception is raised at conflicting values. For ndarrays this + delegates to assert_array_almost_equal Parameters ---------- @@ -434,6 +510,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): y: array([ 1. , 2.33333334]) """ + __tracebackhide__ = True # Hide traceback for py.test from numpy.core import ndarray from numpy.lib import iscomplexobj, real, imag @@ -486,7 +563,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): return except (NotImplementedError, TypeError): pass - if round(abs(desired - actual), decimal) != 0 : + if abs(desired - actual) >= 1.5 * 10.0**(-decimal): raise AssertionError(_build_err_msg()) @@ -547,10 +624,11 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): True """ + __tracebackhide__ = True # Hide traceback for py.test import numpy as np (actual, desired) = map(float, (actual, desired)) - if desired==actual: + if desired == actual: return # Normalized the numbers to be in range (-10.0,10.0) # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual)))))) @@ -583,15 +661,41 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): return except (TypeError, NotImplementedError): pass - if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)) : + if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)): raise AssertionError(msg) + def assert_array_compare(comparison, x, y, err_msg='', verbose=True, - header='', precision=6): + header='', precision=6, equal_nan=True): + __tracebackhide__ = True # Hide traceback for py.test from numpy.core import array, isnan, isinf, any, all, inf x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) + def safe_comparison(*args, **kwargs): + # There are a number of cases where comparing two arrays hits special + # cases in array_richcompare, specifically around strings and void + # dtypes. Basically, we just can't do comparisons involving these + # types, unless both arrays have exactly the *same* type. So + # e.g. you can apply == to two string arrays, or two arrays with + # identical structured dtypes. But if you compare a non-string array + # to a string array, or two arrays with non-identical structured + # dtypes, or anything like that, then internally stuff blows up. + # Currently, when things blow up, we just return a scalar False or + # True. But we also emit a DeprecationWarning, b/c eventually we + # should raise an error here. (Ideally we might even make this work + # properly, but since that will require rewriting a bunch of how + # ufuncs work then we are not counting on that.) + # + # The point of this little function is to let the DeprecationWarning + # pass (or maybe eventually catch the errors and return False, I + # dunno, that's a little trickier and we can figure that out when the + # time comes). + with suppress_warnings() as sup: + sup.filter(DeprecationWarning, ".*==") + sup.filter(FutureWarning, ".*==") + return comparison(*args, **kwargs) + def isnumber(x): return x.dtype.char in '?bhilqpBHILQPefdgFDG' @@ -602,13 +706,13 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, assert_array_equal(x_id, y_id) except AssertionError: msg = build_err_msg([x, y], - err_msg + '\nx and y %s location mismatch:' \ + err_msg + '\nx and y %s location mismatch:' % (hasval), verbose=verbose, header=header, names=('x', 'y'), precision=precision) raise AssertionError(msg) try: - cond = (x.shape==() or y.shape==()) or x.shape == y.shape + cond = (x.shape == () or y.shape == ()) or x.shape == y.shape if not cond: msg = build_err_msg([x, y], err_msg @@ -616,36 +720,40 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, y.shape), verbose=verbose, header=header, names=('x', 'y'), precision=precision) - if not cond : + if not cond: raise AssertionError(msg) if isnumber(x) and isnumber(y): - x_isnan, y_isnan = isnan(x), isnan(y) + if equal_nan: + x_isnan, y_isnan = isnan(x), isnan(y) + # Validate that NaNs are in the same place + if any(x_isnan) or any(y_isnan): + chk_same_position(x_isnan, y_isnan, hasval='nan') + x_isinf, y_isinf = isinf(x), isinf(y) - # Validate that the special values are in the same place - if any(x_isnan) or any(y_isnan): - chk_same_position(x_isnan, y_isnan, hasval='nan') + # Validate that infinite values are in the same place if any(x_isinf) or any(y_isinf): # Check +inf and -inf separately, since they are different chk_same_position(x == +inf, y == +inf, hasval='+inf') chk_same_position(x == -inf, y == -inf, hasval='-inf') # Combine all the special values - x_id, y_id = x_isnan, y_isnan - x_id |= x_isinf - y_id |= y_isinf + x_id, y_id = x_isinf, y_isinf + if equal_nan: + x_id |= x_isnan + y_id |= y_isnan # Only do the comparison if actual values are left if all(x_id): return if any(x_id): - val = comparison(x[~x_id], y[~y_id]) + val = safe_comparison(x[~x_id], y[~y_id]) else: - val = comparison(x, y) + val = safe_comparison(x, y) else: - val = comparison(x, y) + val = safe_comparison(x, y) if isinstance(val, bool): cond = val @@ -661,9 +769,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, + '\n(mismatch %s%%)' % (match,), verbose=verbose, header=header, names=('x', 'y'), precision=precision) - if not cond : + if not cond: raise AssertionError(msg) - except ValueError as e: + except ValueError: import traceback efmt = traceback.format_exc() header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header) @@ -672,6 +780,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, names=('x', 'y'), precision=precision) raise ValueError(msg) + def assert_array_equal(x, y, err_msg='', verbose=True): """ Raises an AssertionError if two array_like objects are not equal. @@ -738,6 +847,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True): assert_array_compare(operator.__eq__, x, y, err_msg=err_msg, verbose=verbose, header='Arrays are not equal') + def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): """ Raises an AssertionError if two objects are not equal up to desired @@ -748,14 +858,16 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): instead of this function for more consistent floating point comparisons. - The test verifies identical shapes and verifies values with - ``abs(desired-actual) < 0.5 * 10**(-decimal)``. + The test verifies identical shapes and that the elements of ``actual`` and + ``desired`` satisfy. - Given two array_like objects, check that the shape is equal and all - elements of these objects are almost equal. An exception is raised at - shape mismatch or conflicting values. In contrast to the standard usage - in numpy, NaNs are compared like numbers, no assertion is raised if - both objects have NaNs in the same positions. + ``abs(desired-actual) < 1.5 * 10**(-decimal)`` + + That is a looser test than originally documented, but agrees with what the + actual implementation did up to rounding vagaries. An exception is raised + at shape mismatch or conflicting values. In contrast to the standard usage + in numpy, NaNs are compared like numbers, no assertion is raised if both + objects have NaNs in the same positions. Parameters ---------- @@ -808,9 +920,11 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): y: array([ 1. , 2.33333, 5. ]) """ + __tracebackhide__ = True # Hide traceback for py.test from numpy.core import around, number, float_, result_type, array from numpy.core.numerictypes import issubdtype from numpy.core.fromnumeric import any as npany + def compare(x, y): try: if npany(gisinf(x)) or npany( gisinf(y)): @@ -830,12 +944,12 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): # casting of x later. dtype = result_type(y, 1.) y = array(y, dtype=dtype, copy=False, subok=True) - z = abs(x-y) + z = abs(x - y) if not issubdtype(z.dtype, number): - z = z.astype(float_) # handle object arrays + z = z.astype(float_) # handle object arrays - return around(z, decimal) <= 10.0**(-decimal) + return z < 1.5 * 10.0**(-decimal) assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, header=('Arrays are not almost equal to %d decimals' % decimal), @@ -908,13 +1022,16 @@ def assert_array_less(x, y, err_msg='', verbose=True): y: array([4]) """ + __tracebackhide__ = True # Hide traceback for py.test assert_array_compare(operator.__lt__, x, y, err_msg=err_msg, verbose=verbose, header='Arrays are not less-ordered') + def runstring(astr, dict): exec(astr, dict) + def assert_string_equal(actual, desired): """ Test if two strings are equal. @@ -942,9 +1059,10 @@ def assert_string_equal(actual, desired): """ # delay import of difflib to reduce startup time + __tracebackhide__ = True # Hide traceback for py.test import difflib - if not isinstance(actual, str) : + if not isinstance(actual, str): raise AssertionError(repr(type(actual))) if not isinstance(desired, str): raise AssertionError(repr(type(desired))) @@ -963,14 +1081,15 @@ def assert_string_equal(actual, desired): if d2.startswith('? '): l.append(d2) d2 = diff.pop(0) - if not d2.startswith('+ ') : + if not d2.startswith('+ '): raise AssertionError(repr(d2)) l.append(d2) - d3 = diff.pop(0) - if d3.startswith('? '): - l.append(d3) - else: - diff.insert(0, d3) + if diff: + d3 = diff.pop(0) + if d3.startswith('? '): + l.append(d3) + else: + diff.insert(0, d3) if re.match(r'\A'+d2[2:]+r'\Z', d1[2:]): continue diff_list.extend(l) @@ -979,7 +1098,7 @@ def assert_string_equal(actual, desired): if not diff_list: return msg = 'Differences in strings:\n%s' % (''.join(diff_list)).rstrip() - if actual != desired : + if actual != desired: raise AssertionError(msg) @@ -1005,17 +1124,13 @@ def rundocs(filename=None, raise_on_error=True): >>> np.lib.test(doctests=True) #doctest: +SKIP """ - import doctest, imp + from numpy.compat import npy_load_module + import doctest if filename is None: f = sys._getframe(1) filename = f.f_globals['__file__'] name = os.path.splitext(os.path.basename(filename))[0] - path = [os.path.dirname(filename)] - file, pathname, description = imp.find_module(name, path) - try: - m = imp.load_module(name, file, pathname, description) - finally: - file.close() + m = npy_load_module(name, filename) tests = doctest.DocTestFinder().find(m) runner = doctest.DocTestRunner(verbose=False) @@ -1038,9 +1153,10 @@ def raises(*args,**kwargs): return nose.tools.raises(*args,**kwargs) -def assert_raises(*args,**kwargs): +def assert_raises(*args, **kwargs): """ assert_raises(exception_class, callable, *args, **kwargs) + assert_raises(exception_class) Fail unless an exception of class exception_class is thrown by callable when invoked with arguments args and keyword @@ -1049,87 +1165,54 @@ def assert_raises(*args,**kwargs): deemed to have suffered an error, exactly as for an unexpected exception. + Alternatively, `assert_raises` can be used as a context manager: + + >>> from numpy.testing import assert_raises + >>> with assert_raises(ZeroDivisionError): + ... 1 / 0 + + is equivalent to + + >>> def div(x, y): + ... return x / y + >>> assert_raises(ZeroDivisionError, div, 1, 0) + """ + __tracebackhide__ = True # Hide traceback for py.test nose = import_nose() return nose.tools.assert_raises(*args,**kwargs) -assert_raises_regex_impl = None - - -def assert_raises_regex(exception_class, expected_regexp, - callable_obj=None, *args, **kwargs): +def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs): """ + assert_raises_regex(exception_class, expected_regexp, callable, *args, + **kwargs) + assert_raises_regex(exception_class, expected_regexp) + Fail unless an exception of class exception_class and with message that matches expected_regexp is thrown by callable when invoked with arguments args and keyword arguments kwargs. + Alternatively, can be used as a context manager like `assert_raises`. + Name of this function adheres to Python 3.2+ reference, but should work in all versions down to 2.6. + Notes + ----- + .. versionadded:: 1.9.0 + """ + __tracebackhide__ = True # Hide traceback for py.test nose = import_nose() - global assert_raises_regex_impl - if assert_raises_regex_impl is None: - try: - # Python 3.2+ - assert_raises_regex_impl = nose.tools.assert_raises_regex - except AttributeError: - try: - # 2.7+ - assert_raises_regex_impl = nose.tools.assert_raises_regexp - except AttributeError: - # 2.6 - - # This class is copied from Python2.7 stdlib almost verbatim - class _AssertRaisesContext(object): - """A context manager used to implement TestCase.assertRaises* methods.""" - - def __init__(self, expected, expected_regexp=None): - self.expected = expected - self.expected_regexp = expected_regexp - - def failureException(self, msg): - return AssertionError(msg) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, tb): - if exc_type is None: - try: - exc_name = self.expected.__name__ - except AttributeError: - exc_name = str(self.expected) - raise self.failureException( - "{0} not raised".format(exc_name)) - if not issubclass(exc_type, self.expected): - # let unexpected exceptions pass through - return False - self.exception = exc_value # store for later retrieval - if self.expected_regexp is None: - return True - - expected_regexp = self.expected_regexp - if isinstance(expected_regexp, basestring): - expected_regexp = re.compile(expected_regexp) - if not expected_regexp.search(str(exc_value)): - raise self.failureException( - '"%s" does not match "%s"' % - (expected_regexp.pattern, str(exc_value))) - return True - - def impl(cls, regex, callable_obj, *a, **kw): - mgr = _AssertRaisesContext(cls, regex) - if callable_obj is None: - return mgr - with mgr: - callable_obj(*a, **kw) - assert_raises_regex_impl = impl - - return assert_raises_regex_impl(exception_class, expected_regexp, - callable_obj, *args, **kwargs) + if sys.version_info.major >= 3: + funcname = nose.tools.assert_raises_regex + else: + # Only present in Python 2.7, missing from unittest in 2.6 + funcname = nose.tools.assert_raises_regexp + + return funcname(exception_class, expected_regexp, *args, **kwargs) def decorate_methods(cls, decorator, testmatch=None): @@ -1208,7 +1291,7 @@ def measure(code_str,times=1,label=None): -------- >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)', ... times=times) - >>> print "Time for a single execution : ", etime / times, "s" + >>> print("Time for a single execution : ", etime / times, "s") Time for a single execution : 0.005 s """ @@ -1226,25 +1309,28 @@ def measure(code_str,times=1,label=None): elapsed = jiffies() - elapsed return 0.01*elapsed + def _assert_valid_refcount(op): """ Check that ufuncs don't mishandle refcount of object `1`. Used in a few regression tests. """ + if not HAS_REFCOUNT: + return True import numpy as np - a = np.arange(100 * 100) + b = np.arange(100*100).reshape(100, 100) c = b - i = 1 rc = sys.getrefcount(i) for j in range(15): d = op(b, c) - assert_(sys.getrefcount(i) >= rc) + del d # for pyflakes -def assert_allclose(actual, desired, rtol=1e-7, atol=0, + +def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True, err_msg='', verbose=True): """ Raises an AssertionError if two objects are not equal up to desired @@ -1266,6 +1352,8 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, Relative tolerance. atol : float, optional Absolute tolerance. + equal_nan : bool, optional. + If True, NaNs will compare equal. err_msg : str, optional The error message to be printed in case of failure. verbose : bool, optional @@ -1287,14 +1375,18 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, >>> assert_allclose(x, y, rtol=1e-5, atol=0) """ + __tracebackhide__ = True # Hide traceback for py.test import numpy as np + def compare(x, y): - return np.core.numeric._allclose_points(x, y, rtol=rtol, atol=atol) + return np.core.numeric.isclose(x, y, rtol=rtol, atol=atol, + equal_nan=equal_nan) actual, desired = np.asanyarray(actual), np.asanyarray(desired) header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol) assert_array_compare(compare, actual, desired, err_msg=str(err_msg), - verbose=verbose, header=header) + verbose=verbose, header=header, equal_nan=equal_nan) + def assert_array_almost_equal_nulp(x, y, nulp=1): """ @@ -1331,7 +1423,7 @@ def assert_array_almost_equal_nulp(x, y, nulp=1): ----- An assertion is raised if the following condition is not met:: - abs(x - y) <= nulps * spacing(max(abs(x), abs(y))) + abs(x - y) <= nulps * spacing(maximum(abs(x), abs(y))) Examples -------- @@ -1345,6 +1437,7 @@ def assert_array_almost_equal_nulp(x, y, nulp=1): AssertionError: X and Y are not equal to 1 ULP (max is 2) """ + __tracebackhide__ = True # Hide traceback for py.test import numpy as np ax = np.abs(x) ay = np.abs(y) @@ -1357,6 +1450,7 @@ def assert_array_almost_equal_nulp(x, y, nulp=1): msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp) raise AssertionError(msg) + def assert_array_max_ulp(a, b, maxulp=1, dtype=None): """ Check that all items of arrays differ in at most N Units in the Last Place. @@ -1393,13 +1487,15 @@ def assert_array_max_ulp(a, b, maxulp=1, dtype=None): >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a))) """ + __tracebackhide__ = True # Hide traceback for py.test import numpy as np ret = nulp_diff(a, b, dtype) if not np.all(ret <= maxulp): - raise AssertionError("Arrays are not almost equal up to %g ULP" % \ + raise AssertionError("Arrays are not almost equal up to %g ULP" % maxulp) return ret + def nulp_diff(x, y, dtype=None): """For each item in x and y, return the number of representable floating points between them. @@ -1410,6 +1506,8 @@ def nulp_diff(x, y, dtype=None): first input array y : array_like second input array + dtype : dtype, optional + Data-type to convert `x` and `y` to if given. Default is None. Returns ------- @@ -1440,7 +1538,7 @@ def nulp_diff(x, y, dtype=None): y = np.array(y, dtype=t) if not x.shape == y.shape: - raise ValueError("x and y do not have the same shape: %s - %s" % \ + raise ValueError("x and y do not have the same shape: %s - %s" % (x.shape, y.shape)) def _diff(rx, ry, vdt): @@ -1451,6 +1549,7 @@ def nulp_diff(x, y, dtype=None): ry = integer_repr(y) return _diff(rx, ry, t) + def _integer_repr(x, vdt, comp): # Reinterpret binary representation of the float as sign-magnitude: # take into account two-complement representation @@ -1458,13 +1557,14 @@ def _integer_repr(x, vdt, comp): # http://www.cygnus-software.com/papers/comparingfloats/comparingfloats.htm rx = x.view(vdt) if not (rx.size == 1): - rx[rx < 0] = comp - rx[rx<0] + rx[rx < 0] = comp - rx[rx < 0] else: if rx < 0: rx = comp - rx return rx + def integer_repr(x): """Return the signed-magnitude interpretation of the binary representation of x.""" @@ -1476,6 +1576,7 @@ def integer_repr(x): else: raise ValueError("Unsupported dtype %s" % x.dtype) + # The following two classes are copied from python 2.6 warnings module (context # manager) class WarningMessage(object): @@ -1510,6 +1611,7 @@ class WarningMessage(object): "line : %r}" % (self.message, self._category_name, self.filename, self.lineno, self.line)) + class WarningManager(object): """ A context manager that copies and restores the warnings filter upon @@ -1534,6 +1636,7 @@ class WarningManager(object): It is copied so it can be used in NumPy with older Python versions. """ + def __init__(self, record=False, module=None): self._record = record if module is None: @@ -1551,6 +1654,7 @@ class WarningManager(object): self._showwarning = self._module.showwarning if self._record: log = [] + def showwarning(*args, **kwargs): log.append(WarningMessage(*args, **kwargs)) self._module.showwarning = showwarning @@ -1565,14 +1669,32 @@ class WarningManager(object): self._module.showwarning = self._showwarning -def assert_warns(warning_class, func, *args, **kw): +@contextlib.contextmanager +def _assert_warns_context(warning_class, name=None): + __tracebackhide__ = True # Hide traceback for py.test + with suppress_warnings() as sup: + l = sup.record(warning_class) + yield + if not len(l) > 0: + name_str = " when calling %s" % name if name is not None else "" + raise AssertionError("No warning raised" + name_str) + + +def assert_warns(warning_class, *args, **kwargs): """ Fail unless the given callable throws the specified warning. A warning of class warning_class should be thrown by the callable when invoked with arguments args and keyword arguments kwargs. - If a different type of warning is thrown, it will not be caught, and the - test case will be deemed to have suffered an error. + If a different type of warning is thrown, it will not be caught. + + If called with all arguments other than the warning class omitted, may be + used as a context manager: + + with assert_warns(SomeWarning): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. .. versionadded:: 1.4.0 @@ -1592,21 +1714,37 @@ def assert_warns(warning_class, func, *args, **kw): The value returned by `func`. """ + if not args: + return _assert_warns_context(warning_class) + + func = args[0] + args = args[1:] + with _assert_warns_context(warning_class, name=func.__name__): + return func(*args, **kwargs) + + +@contextlib.contextmanager +def _assert_no_warnings_context(name=None): + __tracebackhide__ = True # Hide traceback for py.test with warnings.catch_warnings(record=True) as l: warnings.simplefilter('always') - result = func(*args, **kw) - if not len(l) > 0: - raise AssertionError("No warning raised when calling %s" - % func.__name__) - if not l[0].category is warning_class: - raise AssertionError("First warning for %s is not a " \ - "%s( is %s)" % (func.__name__, warning_class, l[0])) - return result - -def assert_no_warnings(func, *args, **kw): + yield + if len(l) > 0: + name_str = " when calling %s" % name if name is not None else "" + raise AssertionError("Got warnings%s: %s" % (name_str, l)) + + +def assert_no_warnings(*args, **kwargs): """ Fail if the given callable produces any warnings. + If called with all arguments omitted, may be used as a context manager: + + with assert_no_warnings(): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. + .. versionadded:: 1.7.0 Parameters @@ -1623,13 +1761,13 @@ def assert_no_warnings(func, *args, **kw): The value returned by `func`. """ - with warnings.catch_warnings(record=True) as l: - warnings.simplefilter('always') - result = func(*args, **kw) - if len(l) > 0: - raise AssertionError("Got warnings when calling %s: %s" - % (func.__name__, l)) - return result + if not args: + return _assert_no_warnings_context() + + func = args[0] + args = args[1:] + with _assert_no_warnings_context(name=func.__name__): + return func(*args, **kwargs) def _gen_alignment_data(dtype=float32, type='binary', max_size=24): @@ -1662,10 +1800,11 @@ def _gen_alignment_data(dtype=float32, type='binary', max_size=24): for o in range(3): for s in range(o + 2, max(o + 3, max_size)): if type == 'unary': - inp = lambda : arange(s, dtype=dtype)[o:] + inp = lambda: arange(s, dtype=dtype)[o:] out = empty((s,), dtype=dtype)[o:] yield out, inp(), ufmt % (o, o, s, dtype, 'out of place') - yield inp(), inp(), ufmt % (o, o, s, dtype, 'in place') + d = inp() + yield d, d, ufmt % (o, o, s, dtype, 'in place') yield out[1:], inp()[:-1], ufmt % \ (o + 1, o, s - 1, dtype, 'out of place') yield out[:-1], inp()[1:], ufmt % \ @@ -1675,14 +1814,16 @@ def _gen_alignment_data(dtype=float32, type='binary', max_size=24): yield inp()[1:], inp()[:-1], ufmt % \ (o + 1, o, s - 1, dtype, 'aliased') if type == 'binary': - inp1 = lambda :arange(s, dtype=dtype)[o:] - inp2 = lambda :arange(s, dtype=dtype)[o:] + inp1 = lambda: arange(s, dtype=dtype)[o:] + inp2 = lambda: arange(s, dtype=dtype)[o:] out = empty((s,), dtype=dtype)[o:] yield out, inp1(), inp2(), bfmt % \ (o, o, o, s, dtype, 'out of place') - yield inp1(), inp1(), inp2(), bfmt % \ + d = inp1() + yield d, d, inp2(), bfmt % \ (o, o, o, s, dtype, 'in place1') - yield inp2(), inp1(), inp2(), bfmt % \ + d = inp2() + yield d, inp1(), d, bfmt % \ (o, o, o, s, dtype, 'in place2') yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % \ (o + 1, o, o, s - 1, dtype, 'out of place') @@ -1711,5 +1852,380 @@ def tempdir(*args, **kwargs): """ tmpdir = mkdtemp(*args, **kwargs) - yield tmpdir - shutil.rmtree(tmpdir) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir) + + +@contextlib.contextmanager +def temppath(*args, **kwargs): + """Context manager for temporary files. + + Context manager that returns the path to a closed temporary file. Its + parameters are the same as for tempfile.mkstemp and are passed directly + to that function. The underlying file is removed when the context is + exited, so it should be closed at that time. + + Windows does not allow a temporary file to be opened if it is already + open, so the underlying file must be closed after opening before it + can be opened again. + + """ + fd, path = mkstemp(*args, **kwargs) + os.close(fd) + try: + yield path + finally: + os.remove(path) + + +class clear_and_catch_warnings(warnings.catch_warnings): + """ Context manager that resets warning registry for catching warnings + + Warnings can be slippery, because, whenever a warning is triggered, Python + adds a ``__warningregistry__`` member to the *calling* module. This makes + it impossible to retrigger the warning in this module, whatever you put in + the warnings filters. This context manager accepts a sequence of `modules` + as a keyword argument to its constructor and: + + * stores and removes any ``__warningregistry__`` entries in given `modules` + on entry; + * resets ``__warningregistry__`` to its previous state on exit. + + This makes it possible to trigger any warning afresh inside the context + manager without disturbing the state of warnings outside. + + For compatibility with Python 3.0, please consider all arguments to be + keyword-only. + + Parameters + ---------- + record : bool, optional + Specifies whether warnings should be captured by a custom + implementation of ``warnings.showwarning()`` and be appended to a list + returned by the context manager. Otherwise None is returned by the + context manager. The objects appended to the list are arguments whose + attributes mirror the arguments to ``showwarning()``. + modules : sequence, optional + Sequence of modules for which to reset warnings registry on entry and + restore on exit. To work correctly, all 'ignore' filters should + filter by one of these modules. + + Examples + -------- + >>> import warnings + >>> with clear_and_catch_warnings(modules=[np.core.fromnumeric]): + ... warnings.simplefilter('always') + ... warnings.filterwarnings('ignore', module='np.core.fromnumeric') + ... # do something that raises a warning but ignore those in + ... # np.core.fromnumeric + """ + class_modules = () + + def __init__(self, record=False, modules=()): + self.modules = set(modules).union(self.class_modules) + self._warnreg_copies = {} + super(clear_and_catch_warnings, self).__init__(record=record) + + def __enter__(self): + for mod in self.modules: + if hasattr(mod, '__warningregistry__'): + mod_reg = mod.__warningregistry__ + self._warnreg_copies[mod] = mod_reg.copy() + mod_reg.clear() + return super(clear_and_catch_warnings, self).__enter__() + + def __exit__(self, *exc_info): + super(clear_and_catch_warnings, self).__exit__(*exc_info) + for mod in self.modules: + if hasattr(mod, '__warningregistry__'): + mod.__warningregistry__.clear() + if mod in self._warnreg_copies: + mod.__warningregistry__.update(self._warnreg_copies[mod]) + + +class suppress_warnings(object): + """ + Context manager and decorator doing much the same as + ``warnings.catch_warnings``. + + However, it also provides a filter mechanism to work around + http://bugs.python.org/issue4180. + + This bug causes Python before 3.4 to not reliably show warnings again + after they have been ignored once (even within catch_warnings). It + means that no "ignore" filter can be used easily, since following + tests might need to see the warning. Additionally it allows easier + specificity for testing warnings and can be nested. + + Parameters + ---------- + forwarding_rule : str, optional + One of "always", "once", "module", or "location". Analogous to + the usual warnings module filter mode, it is useful to reduce + noise mostly on the outmost level. Unsuppressed and unrecorded + warnings will be forwarded based on this rule. Defaults to "always". + "location" is equivalent to the warnings "default", match by exact + location the warning warning originated from. + + Notes + ----- + Filters added inside the context manager will be discarded again + when leaving it. Upon entering all filters defined outside a + context will be applied automatically. + + When a recording filter is added, matching warnings are stored in the + ``log`` attribute as well as in the list returned by ``record``. + + If filters are added and the ``module`` keyword is given, the + warning registry of this module will additionally be cleared when + applying it, entering the context, or exiting it. This could cause + warnings to appear a second time after leaving the context if they + were configured to be printed once (default) and were already + printed before the context was entered. + + Nesting this context manager will work as expected when the + forwarding rule is "always" (default). Unfiltered and unrecorded + warnings will be passed out and be matched by the outer level. + On the outmost level they will be printed (or caught by another + warnings context). The forwarding rule argument can modify this + behaviour. + + Like ``catch_warnings`` this context manager is not threadsafe. + + Examples + -------- + >>> with suppress_warnings() as sup: + ... sup.filter(DeprecationWarning, "Some text") + ... sup.filter(module=np.ma.core) + ... log = sup.record(FutureWarning, "Does this occur?") + ... command_giving_warnings() + ... # The FutureWarning was given once, the filtered warnings were + ... # ignored. All other warnings abide outside settings (may be + ... # printed/error) + ... assert_(len(log) == 1) + ... assert_(len(sup.log) == 1) # also stored in log attribute + + Or as a decorator: + + >>> sup = suppress_warnings() + >>> sup.filter(module=np.ma.core) # module must match exact + >>> @sup + >>> def some_function(): + ... # do something which causes a warning in np.ma.core + ... pass + """ + def __init__(self, forwarding_rule="always"): + self._entered = False + + # Suppressions are either instance or defined inside one with block: + self._suppressions = [] + + if forwarding_rule not in {"always", "module", "once", "location"}: + raise ValueError("unsupported forwarding rule.") + self._forwarding_rule = forwarding_rule + + def _clear_registries(self): + if hasattr(warnings, "_filters_mutated"): + # clearing the registry should not be necessary on new pythons, + # instead the filters should be mutated. + warnings._filters_mutated() + return + # Simply clear the registry, this should normally be harmless, + # note that on new pythons it would be invalidated anyway. + for module in self._tmp_modules: + if hasattr(module, "__warningregistry__"): + module.__warningregistry__.clear() + + def _filter(self, category=Warning, message="", module=None, record=False): + if record: + record = [] # The log where to store warnings + else: + record = None + if self._entered: + if module is None: + warnings.filterwarnings( + "always", category=category, message=message) + else: + module_regex = module.__name__.replace('.', '\.') + '$' + warnings.filterwarnings( + "always", category=category, message=message, + module=module_regex) + self._tmp_modules.add(module) + self._clear_registries() + + self._tmp_suppressions.append( + (category, message, re.compile(message, re.I), module, record)) + else: + self._suppressions.append( + (category, message, re.compile(message, re.I), module, record)) + + return record + + def filter(self, category=Warning, message="", module=None): + """ + Add a new suppressing filter or apply it if the state is entered. + + Parameters + ---------- + category : class, optional + Warning class to filter + message : string, optional + Regular expression matching the warning message. + module : module, optional + Module to filter for. Note that the module (and its file) + must match exactly and cannot be a submodule. This may make + it unreliable for external modules. + + Notes + ----- + When added within a context, filters are only added inside + the context and will be forgotten when the context is exited. + """ + self._filter(category=category, message=message, module=module, + record=False) + + def record(self, category=Warning, message="", module=None): + """ + Append a new recording filter or apply it if the state is entered. + + All warnings matching will be appended to the ``log`` attribute. + + Parameters + ---------- + category : class, optional + Warning class to filter + message : string, optional + Regular expression matching the warning message. + module : module, optional + Module to filter for. Note that the module (and its file) + must match exactly and cannot be a submodule. This may make + it unreliable for external modules. + + Returns + ------- + log : list + A list which will be filled with all matched warnings. + + Notes + ----- + When added within a context, filters are only added inside + the context and will be forgotten when the context is exited. + """ + return self._filter(category=category, message=message, module=module, + record=True) + + def __enter__(self): + if self._entered: + raise RuntimeError("cannot enter suppress_warnings twice.") + + self._orig_show = warnings.showwarning + if hasattr(warnings, "_showwarnmsg"): + self._orig_showmsg = warnings._showwarnmsg + self._filters = warnings.filters + warnings.filters = self._filters[:] + + self._entered = True + self._tmp_suppressions = [] + self._tmp_modules = set() + self._forwarded = set() + + self.log = [] # reset global log (no need to keep same list) + + for cat, mess, _, mod, log in self._suppressions: + if log is not None: + del log[:] # clear the log + if mod is None: + warnings.filterwarnings( + "always", category=cat, message=mess) + else: + module_regex = mod.__name__.replace('.', '\.') + '$' + warnings.filterwarnings( + "always", category=cat, message=mess, + module=module_regex) + self._tmp_modules.add(mod) + warnings.showwarning = self._showwarning + if hasattr(warnings, "_showwarnmsg"): + warnings._showwarnmsg = self._showwarnmsg + self._clear_registries() + + return self + + def __exit__(self, *exc_info): + warnings.showwarning = self._orig_show + if hasattr(warnings, "_showwarnmsg"): + warnings._showwarnmsg = self._orig_showmsg + warnings.filters = self._filters + self._clear_registries() + self._entered = False + del self._orig_show + del self._filters + + def _showwarnmsg(self, msg): + self._showwarning(msg.message, msg.category, msg.filename, msg.lineno, + msg.file, msg.line, use_warnmsg=msg) + + def _showwarning(self, message, category, filename, lineno, + *args, **kwargs): + use_warnmsg = kwargs.pop("use_warnmsg", None) + for cat, _, pattern, mod, rec in ( + self._suppressions + self._tmp_suppressions)[::-1]: + if (issubclass(category, cat) and + pattern.match(message.args[0]) is not None): + if mod is None: + # Message and category match, either recorded or ignored + if rec is not None: + msg = WarningMessage(message, category, filename, + lineno, **kwargs) + self.log.append(msg) + rec.append(msg) + return + # Use startswith, because warnings strips the c or o from + # .pyc/.pyo files. + elif mod.__file__.startswith(filename): + # The message and module (filename) match + if rec is not None: + msg = WarningMessage(message, category, filename, + lineno, **kwargs) + self.log.append(msg) + rec.append(msg) + return + + # There is no filter in place, so pass to the outside handler + # unless we should only pass it once + if self._forwarding_rule == "always": + if use_warnmsg is None: + self._orig_show(message, category, filename, lineno, + *args, **kwargs) + else: + self._orig_showmsg(use_warnmsg) + return + + if self._forwarding_rule == "once": + signature = (message.args, category) + elif self._forwarding_rule == "module": + signature = (message.args, category, filename) + elif self._forwarding_rule == "location": + signature = (message.args, category, filename, lineno) + + if signature in self._forwarded: + return + self._forwarded.add(signature) + if use_warnmsg is None: + self._orig_show(message, category, filename, lineno, *args, + **kwargs) + else: + self._orig_showmsg(use_warnmsg) + + def __call__(self, func): + """ + Function decorator to apply certain suppressions to a whole + function. + """ + @wraps(func) + def new_func(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return new_func |