diff options
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r-- | numpy/testing/utils.py | 114 |
1 files changed, 102 insertions, 12 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 4905898d2..70357d835 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -9,8 +9,9 @@ import sys import re import operator import warnings +from functools import partial from .nosetester import import_nose -from numpy.core import float32, empty, arange +from numpy.core import float32, empty, arange, array_repr, ndarray if sys.version_info[0] >= 3: from io import StringIO @@ -22,7 +23,7 @@ __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal', 'assert_array_almost_equal', 'assert_raises', 'build_err_msg', 'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal', 'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure', - 'assert_', 'assert_array_almost_equal_nulp', + 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex', 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings', 'assert_allclose', 'IgnoreException'] @@ -190,8 +191,7 @@ if os.name=='nt' and sys.version[:3] > '2.3': win32pdh.PDH_FMT_LONG, None) def build_err_msg(arrays, err_msg, header='Items are not equal:', - verbose=True, - names=('ACTUAL', 'DESIRED')): + verbose=True, names=('ACTUAL', 'DESIRED'), precision=8): msg = ['\n' + header] if err_msg: if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header): @@ -200,8 +200,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:', msg.append(err_msg) if verbose: for i, a in enumerate(arrays): + + if isinstance(a, ndarray): + # precision argument is only needed if the objects are ndarrays + r_func = partial(array_repr, precision=precision) + else: + r_func = repr + try: - r = repr(a) + r = r_func(a) except: r = '[repr failed]' if r.count('\n') > 3: @@ -318,7 +325,9 @@ def assert_equal(actual,desired,err_msg='',verbose=True): # as before except (TypeError, ValueError, NotImplementedError): pass - if desired != actual : + + # Explicitly use __eq__ for comparison, ticket #2552 + if not (desired == actual): raise AssertionError(msg) def print_assert_equal(test_string, actual, desired): @@ -573,7 +582,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): raise AssertionError(msg) def assert_array_compare(comparison, x, y, err_msg='', verbose=True, - header=''): + header='', precision=6): from numpy.core import array, isnan, isinf, any, all, inf x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) @@ -590,7 +599,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, msg = build_err_msg([x, y], err_msg + '\nx and y %s location mismatch:' \ % (hasval), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) raise AssertionError(msg) try: @@ -601,7 +610,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, + '\n(shapes %s, %s mismatch)' % (x.shape, y.shape), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) if not cond : raise AssertionError(msg) @@ -646,7 +655,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, err_msg + '\n(mismatch %s%%)' % (match,), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) if not cond : raise AssertionError(msg) except ValueError as e: @@ -655,7 +664,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header) msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) raise ValueError(msg) def assert_array_equal(x, y, err_msg='', verbose=True): @@ -823,7 +832,8 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): return around(z, decimal) <= 10.0**(-decimal) assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, - header=('Arrays are not almost equal to %d decimals' % decimal)) + header=('Arrays are not almost equal to %d decimals' % decimal), + precision=decimal) def assert_array_less(x, y, err_msg='', verbose=True): @@ -1020,6 +1030,7 @@ def raises(*args,**kwargs): nose = import_nose() return nose.tools.raises(*args,**kwargs) + def assert_raises(*args,**kwargs): """ assert_raises(exception_class, callable, *args, **kwargs) @@ -1035,6 +1046,85 @@ def assert_raises(*args,**kwargs): nose = import_nose() return nose.tools.assert_raises(*args,**kwargs) + +assert_raises_regex_impl = None + + +def assert_raises_regex(exception_class, expected_regexp, + callable_obj=None, *args, **kwargs): + """ + Fail unless an exception of class exception_class and with message that + matches expected_regexp is thrown by callable when invoked with arguments + args and keyword arguments kwargs. + + Name of this function adheres to Python 3.2+ reference, but should work in + all versions down to 2.6. + + """ + nose = import_nose() + + global assert_raises_regex_impl + if assert_raises_regex_impl is None: + try: + # Python 3.2+ + assert_raises_regex_impl = nose.tools.assert_raises_regex + except AttributeError: + try: + # 2.7+ + assert_raises_regex_impl = nose.tools.assert_raises_regexp + except AttributeError: + # 2.6 + + # This class is copied from Python2.7 stdlib almost verbatim + class _AssertRaisesContext(object): + """A context manager used to implement TestCase.assertRaises* methods.""" + + def __init__(self, expected, expected_regexp=None): + self.expected = expected + self.expected_regexp = expected_regexp + + def failureException(self, msg): + return AssertionError(msg) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, tb): + if exc_type is None: + try: + exc_name = self.expected.__name__ + except AttributeError: + exc_name = str(self.expected) + raise self.failureException( + "{0} not raised".format(exc_name)) + if not issubclass(exc_type, self.expected): + # let unexpected exceptions pass through + return False + self.exception = exc_value # store for later retrieval + if self.expected_regexp is None: + return True + + expected_regexp = self.expected_regexp + if isinstance(expected_regexp, basestring): + expected_regexp = re.compile(expected_regexp) + if not expected_regexp.search(str(exc_value)): + raise self.failureException( + '"%s" does not match "%s"' % + (expected_regexp.pattern, str(exc_value))) + return True + + def impl(cls, regex, callable_obj, *a, **kw): + mgr = _AssertRaisesContext(cls, regex) + if callable_obj is None: + return mgr + with mgr: + callable_obj(*a, **kw) + assert_raises_regex_impl = impl + + return assert_raises_regex_impl(exception_class, expected_regexp, + callable_obj, *args, **kwargs) + + def decorate_methods(cls, decorator, testmatch=None): """ Apply a decorator to all methods in a class matching a regular expression. |