diff options
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r-- | numpy/testing/utils.py | 149 |
1 files changed, 135 insertions, 14 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 2a99fe5cb..13c3e4610 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -9,8 +9,12 @@ import sys import re import operator import warnings +from functools import partial +import shutil +import contextlib +from tempfile import mkdtemp from .nosetester import import_nose -from numpy.core import float32, empty, arange +from numpy.core import float32, empty, arange, array_repr, ndarray if sys.version_info[0] >= 3: from io import StringIO @@ -22,7 +26,7 @@ __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal', 'assert_array_almost_equal', 'assert_raises', 'build_err_msg', 'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal', 'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure', - 'assert_', 'assert_array_almost_equal_nulp', + 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex', 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings', 'assert_allclose', 'IgnoreException'] @@ -190,8 +194,7 @@ if os.name=='nt' and sys.version[:3] > '2.3': win32pdh.PDH_FMT_LONG, None) def build_err_msg(arrays, err_msg, header='Items are not equal:', - verbose=True, - names=('ACTUAL', 'DESIRED')): + verbose=True, names=('ACTUAL', 'DESIRED'), precision=8): msg = ['\n' + header] if err_msg: if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header): @@ -200,8 +203,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:', msg.append(err_msg) if verbose: for i, a in enumerate(arrays): + + if isinstance(a, ndarray): + # precision argument is only needed if the objects are ndarrays + r_func = partial(array_repr, precision=precision) + else: + r_func = repr + try: - r = repr(a) + r = r_func(a) except: r = '[repr failed]' if r.count('\n') > 3: @@ -318,7 +328,9 @@ def assert_equal(actual,desired,err_msg='',verbose=True): # as before except (TypeError, ValueError, NotImplementedError): pass - if desired != actual : + + # Explicitly use __eq__ for comparison, ticket #2552 + if not (desired == actual): raise AssertionError(msg) def print_assert_equal(test_string, actual, desired): @@ -573,7 +585,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): raise AssertionError(msg) def assert_array_compare(comparison, x, y, err_msg='', verbose=True, - header=''): + header='', precision=6): from numpy.core import array, isnan, isinf, any, all, inf x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) @@ -590,7 +602,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, msg = build_err_msg([x, y], err_msg + '\nx and y %s location mismatch:' \ % (hasval), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) raise AssertionError(msg) try: @@ -601,7 +613,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, + '\n(shapes %s, %s mismatch)' % (x.shape, y.shape), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) if not cond : raise AssertionError(msg) @@ -646,7 +658,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, err_msg + '\n(mismatch %s%%)' % (match,), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) if not cond : raise AssertionError(msg) except ValueError as e: @@ -655,7 +667,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header) msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) raise ValueError(msg) def assert_array_equal(x, y, err_msg='', verbose=True): @@ -793,7 +805,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): y: array([ 1. , 2.33333, 5. ]) """ - from numpy.core import around, number, float_ + from numpy.core import around, number, float_, result_type, array from numpy.core.numerictypes import issubdtype from numpy.core.fromnumeric import any as npany def compare(x, y): @@ -810,12 +822,22 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): y = y[~yinfid] except (TypeError, NotImplementedError): pass + + # make sure y is an inexact type to avoid abs(MIN_INT); will cause + # casting of x later. + dtype = result_type(y, 1.) + y = array(y, dtype=dtype, copy=False) z = abs(x-y) + if not issubdtype(z.dtype, number): z = z.astype(float_) # handle object arrays + return around(z, decimal) <= 10.0**(-decimal) + assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, - header=('Arrays are not almost equal to %d decimals' % decimal)) + header=('Arrays are not almost equal to %d decimals' % decimal), + precision=decimal) + def assert_array_less(x, y, err_msg='', verbose=True): """ @@ -1011,6 +1033,7 @@ def raises(*args,**kwargs): nose = import_nose() return nose.tools.raises(*args,**kwargs) + def assert_raises(*args,**kwargs): """ assert_raises(exception_class, callable, *args, **kwargs) @@ -1026,6 +1049,85 @@ def assert_raises(*args,**kwargs): nose = import_nose() return nose.tools.assert_raises(*args,**kwargs) + +assert_raises_regex_impl = None + + +def assert_raises_regex(exception_class, expected_regexp, + callable_obj=None, *args, **kwargs): + """ + Fail unless an exception of class exception_class and with message that + matches expected_regexp is thrown by callable when invoked with arguments + args and keyword arguments kwargs. + + Name of this function adheres to Python 3.2+ reference, but should work in + all versions down to 2.6. + + """ + nose = import_nose() + + global assert_raises_regex_impl + if assert_raises_regex_impl is None: + try: + # Python 3.2+ + assert_raises_regex_impl = nose.tools.assert_raises_regex + except AttributeError: + try: + # 2.7+ + assert_raises_regex_impl = nose.tools.assert_raises_regexp + except AttributeError: + # 2.6 + + # This class is copied from Python2.7 stdlib almost verbatim + class _AssertRaisesContext(object): + """A context manager used to implement TestCase.assertRaises* methods.""" + + def __init__(self, expected, expected_regexp=None): + self.expected = expected + self.expected_regexp = expected_regexp + + def failureException(self, msg): + return AssertionError(msg) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, tb): + if exc_type is None: + try: + exc_name = self.expected.__name__ + except AttributeError: + exc_name = str(self.expected) + raise self.failureException( + "{0} not raised".format(exc_name)) + if not issubclass(exc_type, self.expected): + # let unexpected exceptions pass through + return False + self.exception = exc_value # store for later retrieval + if self.expected_regexp is None: + return True + + expected_regexp = self.expected_regexp + if isinstance(expected_regexp, basestring): + expected_regexp = re.compile(expected_regexp) + if not expected_regexp.search(str(exc_value)): + raise self.failureException( + '"%s" does not match "%s"' % + (expected_regexp.pattern, str(exc_value))) + return True + + def impl(cls, regex, callable_obj, *a, **kw): + mgr = _AssertRaisesContext(cls, regex) + if callable_obj is None: + return mgr + with mgr: + callable_obj(*a, **kw) + assert_raises_regex_impl = impl + + return assert_raises_regex_impl(exception_class, expected_regexp, + callable_obj, *args, **kwargs) + + def decorate_methods(cls, decorator, testmatch=None): """ Apply a decorator to all methods in a class matching a regular expression. @@ -1147,6 +1249,8 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, It compares the difference between `actual` and `desired` to ``atol + rtol * abs(desired)``. + .. versionadded:: 1.5.0 + Parameters ---------- actual : array_like @@ -1231,7 +1335,6 @@ def assert_array_almost_equal_nulp(x, y, nulp=1): >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x) >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x) - ------------------------------------------------------------ Traceback (most recent call last): ... AssertionError: X and Y are not equal to 1 ULP (max is 2) @@ -1456,6 +1559,7 @@ class WarningManager(object): self._module.filters = self._filters self._module.showwarning = self._showwarning + def assert_warns(warning_class, func, *args, **kw): """ Fail unless the given callable throws the specified warning. @@ -1465,6 +1569,8 @@ def assert_warns(warning_class, func, *args, **kw): If a different type of warning is thrown, it will not be caught, and the test case will be deemed to have suffered an error. + .. versionadded:: 1.4.0 + Parameters ---------- warning_class : class @@ -1496,6 +1602,8 @@ def assert_no_warnings(func, *args, **kw): """ Fail if the given callable produces any warnings. + .. versionadded:: 1.7.0 + Parameters ---------- func : callable @@ -1587,3 +1695,16 @@ def _gen_alignment_data(dtype=float32, type='binary', max_size=24): class IgnoreException(Exception): "Ignoring this exception due to disabled feature" + + +@contextlib.contextmanager +def tempdir(*args, **kwargs): + """Context manager to provide a temporary test folder. + + All arguments are passed as this to the underlying tempfile.mkdtemp + function. + + """ + tmpdir = mkdtemp(*args, **kwargs) + yield tmpdir + shutil.rmtree(tmpdir) |