diff options
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r-- | numpy/testing/utils.py | 288 |
1 files changed, 174 insertions, 114 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index c6d863f94..176d87800 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -12,9 +12,12 @@ import warnings from functools import partial import shutil import contextlib -from tempfile import mkdtemp +from tempfile import mkdtemp, mkstemp +from unittest.case import SkipTest + from .nosetester import import_nose from numpy.core import float32, empty, arange, array_repr, ndarray +from numpy.lib.utils import deprecate if sys.version_info[0] >= 3: from io import StringIO @@ -28,11 +31,21 @@ __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal', 'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure', 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex', 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings', - 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings'] + 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings', + 'SkipTest', 'KnownFailureException', 'temppath', 'tempdir', 'IS_PYPY', + 'HAS_REFCOUNT'] + + +class KnownFailureException(Exception): + '''Raise this exception to mark a test as a known failing test.''' + pass +KnownFailureTest = KnownFailureException # backwards compat verbose = 0 +IS_PYPY = '__pypy__' in sys.modules +HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None def assert_(val, msg=''): """ @@ -52,6 +65,7 @@ def assert_(val, msg=''): smsg = msg raise AssertionError(smsg) + def gisnan(x): """like isnan, but always raise an error if type not supported instead of returning a TypeError object. @@ -69,6 +83,7 @@ def gisnan(x): raise TypeError("isnan not supported for this type") return st + def gisfinite(x): """like isfinite, but always raise an error if type not supported instead of returning a TypeError object. @@ -87,6 +102,7 @@ def gisfinite(x): raise TypeError("isfinite not supported for this type") return st + def gisinf(x): """like isinf, but always raise an error if type not supported instead of returning a TypeError object. @@ -105,6 +121,9 @@ def gisinf(x): raise TypeError("isinf not supported for this type") return st + +@deprecate(message="numpy.testing.rand is deprecated in numpy 1.11. " + "Use numpy.random.rand instead.") def rand(*args): """Returns an array of random numbers with the given shape. @@ -118,6 +137,7 @@ def rand(*args): f[i] = random.random() return results + if os.name == 'nt': # Code "stolen" from enthought/debug/memusage.py def GetPerformanceAttributes(object, counter, instance=None, @@ -232,14 +252,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:', try: r = r_func(a) - except: - r = '[repr failed]' + except Exception as exc: + r = '[repr failed for <{}>: {}]'.format(type(a).__name__, exc) if r.count('\n') > 3: r = '\n'.join(r.splitlines()[:3]) r += '...' msg.append(' %s: %s' % (names[i], r)) return '\n'.join(msg) + def assert_equal(actual,desired,err_msg='',verbose=True): """ Raises an AssertionError if two objects are not equal. @@ -354,6 +375,7 @@ def assert_equal(actual,desired,err_msg='',verbose=True): if not (desired == actual): raise AssertionError(msg) + def print_assert_equal(test_string, actual, desired): """ Test if two objects are equal, and print an error message if test fails. @@ -394,6 +416,7 @@ def print_assert_equal(test_string, actual, desired): pprint.pprint(desired, msg) raise AssertionError(msg.getvalue()) + def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): """ Raises an AssertionError if two items are not equal up to desired @@ -404,11 +427,14 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): instead of this function for more consistent floating point comparisons. - The test is equivalent to ``abs(desired-actual) < 0.5 * 10**(-decimal)``. + The test verifies that the elements of ``actual`` and ``desired`` satisfy. - Given two objects (numbers or ndarrays), check that all elements of these - objects are almost equal. An exception is raised at conflicting values. - For ndarrays this delegates to assert_array_almost_equal + ``abs(desired-actual) < 1.5 * 10**(-decimal)`` + + That is a looser test than originally documented, but agrees with what the + actual implementation in `assert_array_almost_equal` did up to rounding + vagaries. An exception is raised at conflicting values. For ndarrays this + delegates to assert_array_almost_equal Parameters ---------- @@ -509,7 +535,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): return except (NotImplementedError, TypeError): pass - if round(abs(desired - actual), decimal) != 0: + if abs(desired - actual) >= 1.5 * 10.0**(-decimal): raise AssertionError(_build_err_msg()) @@ -610,6 +636,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)): raise AssertionError(msg) + def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', precision=6): __tracebackhide__ = True # Hide traceback for py.test @@ -720,6 +747,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, names=('x', 'y'), precision=precision) raise ValueError(msg) + def assert_array_equal(x, y, err_msg='', verbose=True): """ Raises an AssertionError if two array_like objects are not equal. @@ -786,6 +814,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True): assert_array_compare(operator.__eq__, x, y, err_msg=err_msg, verbose=verbose, header='Arrays are not equal') + def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): """ Raises an AssertionError if two objects are not equal up to desired @@ -796,14 +825,16 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): instead of this function for more consistent floating point comparisons. - The test verifies identical shapes and verifies values with - ``abs(desired-actual) < 0.5 * 10**(-decimal)``. + The test verifies identical shapes and that the elements of ``actual`` and + ``desired`` satisfy. - Given two array_like objects, check that the shape is equal and all - elements of these objects are almost equal. An exception is raised at - shape mismatch or conflicting values. In contrast to the standard usage - in numpy, NaNs are compared like numbers, no assertion is raised if - both objects have NaNs in the same positions. + ``abs(desired-actual) < 1.5 * 10**(-decimal)`` + + That is a looser test than originally documented, but agrees with what the + actual implementation did up to rounding vagaries. An exception is raised + at shape mismatch or conflicting values. In contrast to the standard usage + in numpy, NaNs are compared like numbers, no assertion is raised if both + objects have NaNs in the same positions. Parameters ---------- @@ -880,12 +911,12 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): # casting of x later. dtype = result_type(y, 1.) y = array(y, dtype=dtype, copy=False, subok=True) - z = abs(x-y) + z = abs(x - y) if not issubdtype(z.dtype, number): z = z.astype(float_) # handle object arrays - return around(z, decimal) <= 10.0**(-decimal) + return z < 1.5 * 10.0**(-decimal) assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, header=('Arrays are not almost equal to %d decimals' % decimal), @@ -963,9 +994,11 @@ def assert_array_less(x, y, err_msg='', verbose=True): verbose=verbose, header='Arrays are not less-ordered') + def runstring(astr, dict): exec(astr, dict) + def assert_string_equal(actual, desired): """ Test if two strings are equal. @@ -1018,11 +1051,12 @@ def assert_string_equal(actual, desired): if not d2.startswith('+ '): raise AssertionError(repr(d2)) l.append(d2) - d3 = diff.pop(0) - if d3.startswith('? '): - l.append(d3) - else: - diff.insert(0, d3) + if diff: + d3 = diff.pop(0) + if d3.startswith('? '): + l.append(d3) + else: + diff.insert(0, d3) if re.match(r'\A'+d2[2:]+r'\Z', d1[2:]): continue diff_list.extend(l) @@ -1057,18 +1091,13 @@ def rundocs(filename=None, raise_on_error=True): >>> np.lib.test(doctests=True) #doctest: +SKIP """ + from numpy.compat import npy_load_module import doctest - import imp if filename is None: f = sys._getframe(1) filename = f.f_globals['__file__'] name = os.path.splitext(os.path.basename(filename))[0] - path = [os.path.dirname(filename)] - file, pathname, description = imp.find_module(name, path) - try: - m = imp.load_module(name, file, pathname, description) - finally: - file.close() + m = npy_load_module(name, filename) tests = doctest.DocTestFinder().find(m) runner = doctest.DocTestRunner(verbose=False) @@ -1102,15 +1131,24 @@ def assert_raises(*args,**kwargs): deemed to have suffered an error, exactly as for an unexpected exception. + Alternatively, `assert_raises` can be used as a context manager: + + >>> from numpy.testing import assert_raises + >>> with assert_raises(ZeroDivisionError): + ... 1 / 0 + + is equivalent to + + >>> def div(x, y): + ... return x / y + >>> assert_raises(ZeroDivisionError, div, 1, 0) + """ __tracebackhide__ = True # Hide traceback for py.test nose = import_nose() return nose.tools.assert_raises(*args,**kwargs) -assert_raises_regex_impl = None - - def assert_raises_regex(exception_class, expected_regexp, callable_obj=None, *args, **kwargs): """ @@ -1121,70 +1159,22 @@ def assert_raises_regex(exception_class, expected_regexp, Name of this function adheres to Python 3.2+ reference, but should work in all versions down to 2.6. + Notes + ----- + .. versionadded:: 1.9.0 + """ __tracebackhide__ = True # Hide traceback for py.test nose = import_nose() - global assert_raises_regex_impl - if assert_raises_regex_impl is None: - try: - # Python 3.2+ - assert_raises_regex_impl = nose.tools.assert_raises_regex - except AttributeError: - try: - # 2.7+ - assert_raises_regex_impl = nose.tools.assert_raises_regexp - except AttributeError: - # 2.6 - - # This class is copied from Python2.7 stdlib almost verbatim - class _AssertRaisesContext(object): - """A context manager used to implement TestCase.assertRaises* methods.""" - - def __init__(self, expected, expected_regexp=None): - self.expected = expected - self.expected_regexp = expected_regexp - - def failureException(self, msg): - return AssertionError(msg) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, tb): - if exc_type is None: - try: - exc_name = self.expected.__name__ - except AttributeError: - exc_name = str(self.expected) - raise self.failureException( - "{0} not raised".format(exc_name)) - if not issubclass(exc_type, self.expected): - # let unexpected exceptions pass through - return False - self.exception = exc_value # store for later retrieval - if self.expected_regexp is None: - return True - - expected_regexp = self.expected_regexp - if isinstance(expected_regexp, basestring): - expected_regexp = re.compile(expected_regexp) - if not expected_regexp.search(str(exc_value)): - raise self.failureException( - '"%s" does not match "%s"' % - (expected_regexp.pattern, str(exc_value))) - return True - - def impl(cls, regex, callable_obj, *a, **kw): - mgr = _AssertRaisesContext(cls, regex) - if callable_obj is None: - return mgr - with mgr: - callable_obj(*a, **kw) - assert_raises_regex_impl = impl - - return assert_raises_regex_impl(exception_class, expected_regexp, - callable_obj, *args, **kwargs) + if sys.version_info.major >= 3: + funcname = nose.tools.assert_raises_regex + else: + # Only present in Python 2.7, missing from unittest in 2.6 + funcname = nose.tools.assert_raises_regexp + + return funcname(exception_class, expected_regexp, callable_obj, + *args, **kwargs) def decorate_methods(cls, decorator, testmatch=None): @@ -1263,7 +1253,7 @@ def measure(code_str,times=1,label=None): -------- >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)', ... times=times) - >>> print "Time for a single execution : ", etime / times, "s" + >>> print("Time for a single execution : ", etime / times, "s") Time for a single execution : 0.005 s """ @@ -1287,6 +1277,8 @@ def _assert_valid_refcount(op): Check that ufuncs don't mishandle refcount of object `1`. Used in a few regression tests. """ + if not HAS_REFCOUNT: + return True import numpy as np b = np.arange(100*100).reshape(100, 100) @@ -1357,6 +1349,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=False, assert_array_compare(compare, actual, desired, err_msg=str(err_msg), verbose=verbose, header=header) + def assert_array_almost_equal_nulp(x, y, nulp=1): """ Compare two arrays relatively to their spacing. @@ -1419,6 +1412,7 @@ def assert_array_almost_equal_nulp(x, y, nulp=1): msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp) raise AssertionError(msg) + def assert_array_max_ulp(a, b, maxulp=1, dtype=None): """ Check that all items of arrays differ in at most N Units in the Last Place. @@ -1463,6 +1457,7 @@ def assert_array_max_ulp(a, b, maxulp=1, dtype=None): maxulp) return ret + def nulp_diff(x, y, dtype=None): """For each item in x and y, return the number of representable floating points between them. @@ -1516,6 +1511,7 @@ def nulp_diff(x, y, dtype=None): ry = integer_repr(y) return _diff(rx, ry, t) + def _integer_repr(x, vdt, comp): # Reinterpret binary representation of the float as sign-magnitude: # take into account two-complement representation @@ -1530,6 +1526,7 @@ def _integer_repr(x, vdt, comp): return rx + def integer_repr(x): """Return the signed-magnitude interpretation of the binary representation of x.""" @@ -1541,6 +1538,7 @@ def integer_repr(x): else: raise ValueError("Unsupported dtype %s" % x.dtype) + # The following two classes are copied from python 2.6 warnings module (context # manager) class WarningMessage(object): @@ -1575,6 +1573,7 @@ class WarningMessage(object): "line : %r}" % (self.message, self._category_name, self.filename, self.lineno, self.line)) + class WarningManager(object): """ A context manager that copies and restores the warnings filter upon @@ -1632,7 +1631,22 @@ class WarningManager(object): self._module.showwarning = self._showwarning -def assert_warns(warning_class, func, *args, **kw): +@contextlib.contextmanager +def _assert_warns_context(warning_class, name=None): + __tracebackhide__ = True # Hide traceback for py.test + with warnings.catch_warnings(record=True) as l: + warnings.simplefilter('always') + yield + if not len(l) > 0: + name_str = " when calling %s" % name if name is not None else "" + raise AssertionError("No warning raised" + name_str) + if not l[0].category is warning_class: + name_str = "%s " % name if name is not None else "" + raise AssertionError("First warning %sis not a %s (is %s)" + % (name_str, warning_class, l[0])) + + +def assert_warns(warning_class, *args, **kwargs): """ Fail unless the given callable throws the specified warning. @@ -1641,6 +1655,14 @@ def assert_warns(warning_class, func, *args, **kw): If a different type of warning is thrown, it will not be caught, and the test case will be deemed to have suffered an error. + If called with all arguments other than the warning class omitted, may be + used as a context manager: + + with assert_warns(SomeWarning): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. + .. versionadded:: 1.4.0 Parameters @@ -1659,22 +1681,37 @@ def assert_warns(warning_class, func, *args, **kw): The value returned by `func`. """ + if not args: + return _assert_warns_context(warning_class) + + func = args[0] + args = args[1:] + with _assert_warns_context(warning_class, name=func.__name__): + return func(*args, **kwargs) + + +@contextlib.contextmanager +def _assert_no_warnings_context(name=None): __tracebackhide__ = True # Hide traceback for py.test with warnings.catch_warnings(record=True) as l: warnings.simplefilter('always') - result = func(*args, **kw) - if not len(l) > 0: - raise AssertionError("No warning raised when calling %s" - % func.__name__) - if not l[0].category is warning_class: - raise AssertionError("First warning for %s is not a " - "%s( is %s)" % (func.__name__, warning_class, l[0])) - return result + yield + if len(l) > 0: + name_str = " when calling %s" % name if name is not None else "" + raise AssertionError("Got warnings%s: %s" % (name_str, l)) + -def assert_no_warnings(func, *args, **kw): +def assert_no_warnings(*args, **kwargs): """ Fail if the given callable produces any warnings. + If called with all arguments omitted, may be used as a context manager: + + with assert_no_warnings(): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. + .. versionadded:: 1.7.0 Parameters @@ -1691,14 +1728,13 @@ def assert_no_warnings(func, *args, **kw): The value returned by `func`. """ - __tracebackhide__ = True # Hide traceback for py.test - with warnings.catch_warnings(record=True) as l: - warnings.simplefilter('always') - result = func(*args, **kw) - if len(l) > 0: - raise AssertionError("Got warnings when calling %s: %s" - % (func.__name__, l)) - return result + if not args: + return _assert_no_warnings_context() + + func = args[0] + args = args[1:] + with _assert_no_warnings_context(name=func.__name__): + return func(*args, **kwargs) def _gen_alignment_data(dtype=float32, type='binary', max_size=24): @@ -1780,8 +1816,32 @@ def tempdir(*args, **kwargs): """ tmpdir = mkdtemp(*args, **kwargs) - yield tmpdir - shutil.rmtree(tmpdir) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir) + + +@contextlib.contextmanager +def temppath(*args, **kwargs): + """Context manager for temporary files. + + Context manager that returns the path to a closed temporary file. Its + parameters are the same as for tempfile.mkstemp and are passed directly + to that function. The underlying file is removed when the context is + exited, so it should be closed at that time. + + Windows does not allow a temporary file to be opened if it is already + open, so the underlying file must be closed after opening before it + can be opened again. + + """ + fd, path = mkstemp(*args, **kwargs) + os.close(fd) + try: + yield path + finally: + os.remove(path) class clear_and_catch_warnings(warnings.catch_warnings): |