diff options
Diffstat (limited to 'numpy/testing')
-rw-r--r-- | numpy/testing/__init__.py | 1 | ||||
-rw-r--r-- | numpy/testing/nosetester.py | 49 | ||||
-rw-r--r-- | numpy/testing/numpytest.py | 36 | ||||
-rw-r--r-- | numpy/testing/tests/test_utils.py | 76 | ||||
-rw-r--r-- | numpy/testing/utils.py | 149 |
5 files changed, 258 insertions, 53 deletions
diff --git a/numpy/testing/__init__.py b/numpy/testing/__init__.py index aad448c96..258cbe928 100644 --- a/numpy/testing/__init__.py +++ b/numpy/testing/__init__.py @@ -11,7 +11,6 @@ from unittest import TestCase from . import decorators as dec from .utils import * -from .numpytest import importall # remove for numpy 1.9.0 from .nosetester import NoseTester as Tester from .nosetester import run_module_suite test = Tester().test diff --git a/numpy/testing/nosetester.py b/numpy/testing/nosetester.py index 77ac61074..a06d559e7 100644 --- a/numpy/testing/nosetester.py +++ b/numpy/testing/nosetester.py @@ -75,14 +75,55 @@ def import_nose(): return nose -def run_module_suite(file_to_run = None): +def run_module_suite(file_to_run=None, argv=None): + """ + Run a test module. + + Equivalent to calling ``$ nosetests <argv> <file_to_run>`` from + the command line + + Parameters + ---------- + file_to_run: str, optional + Path to test module, or None. + By default, run the module from which this function is called. + argv: list of strings + Arguments to be passed to the nose test runner. ``argv[0]`` is + ignored. All command line arguments accepted by ``nosetests`` + will work. + + .. versionadded:: 1.9.0 + + Examples + -------- + Adding the following:: + + if __name__ == "__main__" : + run_module_suite(argv=sys.argv) + + at the end of a test module will run the tests when that module is + called in the python interpreter. + + Alternatively, calling:: + + >>> run_module_suite(file_to_run="numpy/tests/test_matlib.py") + + from an interpreter will run all the test routine in 'test_matlib.py'. + """ if file_to_run is None: f = sys._getframe(1) file_to_run = f.f_locals.get('__file__', None) if file_to_run is None: raise AssertionError - import_nose().run(argv=['', file_to_run]) + if argv is None: + argv = ['', file_to_run] + else: + argv = argv + [file_to_run] + + nose = import_nose() + from .noseclasses import KnownFailure + nose.run(argv=argv, addplugins=[KnownFailure()]) class NoseTester(object): @@ -378,6 +419,10 @@ class NoseTester(object): warnings.filterwarnings("ignore", message="numpy.ufunc size changed") warnings.filterwarnings("ignore", category=ModuleDeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) + # Filter out boolean '-' deprecation messages. This allows + # older versions of scipy to test without a flood of messages. + warnings.filterwarnings("ignore", message=".*boolean negative.*") + warnings.filterwarnings("ignore", message=".*boolean subtract.*") from .noseclasses import NumpyTestProgram diff --git a/numpy/testing/numpytest.py b/numpy/testing/numpytest.py deleted file mode 100644 index d83a21d83..000000000 --- a/numpy/testing/numpytest.py +++ /dev/null @@ -1,36 +0,0 @@ -from __future__ import division, absolute_import, print_function - -import os -import warnings - -__all__ = ['importall'] - - -def importall(package): - """ - `importall` is DEPRECATED and will be removed in numpy 1.9.0 - - Try recursively to import all subpackages under package. - """ - warnings.warn("`importall is deprecated, and will be remobed in numpy 1.9.0", - DeprecationWarning) - - if isinstance(package, str): - package = __import__(package) - - package_name = package.__name__ - package_dir = os.path.dirname(package.__file__) - for subpackage_name in os.listdir(package_dir): - subdir = os.path.join(package_dir, subpackage_name) - if not os.path.isdir(subdir): - continue - if not os.path.isfile(os.path.join(subdir, '__init__.py')): - continue - name = package_name+'.'+subpackage_name - try: - exec('import %s as m' % (name)) - except Exception as msg: - print('Failed importing %s: %s' %(name, msg)) - continue - importall(m) - return diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 94fc4d655..aa0a2669f 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -53,6 +53,9 @@ class _GenericTest(object): a = np.array([1, 1], dtype=np.object) self._test_equal(a, 1) + def test_array_likes(self): + self._test_equal([1, 2, 3], (1, 2, 3)) + class TestArrayEqual(_GenericTest, unittest.TestCase): def setUp(self): self._assert_func = assert_array_equal @@ -131,6 +134,49 @@ class TestArrayEqual(_GenericTest, unittest.TestCase): self._test_not_equal(c, b) +class TestBuildErrorMessage(unittest.TestCase): + def test_build_err_msg_defaults(self): + x = np.array([1.00001, 2.00002, 3.00003]) + y = np.array([1.00002, 2.00003, 3.00004]) + err_msg = 'There is a mismatch' + + a = build_err_msg([x, y], err_msg) + b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ ' + '1.00001, 2.00002, 3.00003])\n DESIRED: array([ 1.00002, ' + '2.00003, 3.00004])') + self.assertEqual(a, b) + + def test_build_err_msg_no_verbose(self): + x = np.array([1.00001, 2.00002, 3.00003]) + y = np.array([1.00002, 2.00003, 3.00004]) + err_msg = 'There is a mismatch' + + a = build_err_msg([x, y], err_msg, verbose=False) + b = '\nItems are not equal: There is a mismatch' + self.assertEqual(a, b) + + def test_build_err_msg_custom_names(self): + x = np.array([1.00001, 2.00002, 3.00003]) + y = np.array([1.00002, 2.00003, 3.00004]) + err_msg = 'There is a mismatch' + + a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR')) + b = ('\nItems are not equal: There is a mismatch\n FOO: array([ ' + '1.00001, 2.00002, 3.00003])\n BAR: array([ 1.00002, 2.00003, ' + '3.00004])') + self.assertEqual(a, b) + + def test_build_err_msg_custom_precision(self): + x = np.array([1.000000001, 2.00002, 3.00003]) + y = np.array([1.000000002, 2.00003, 3.00004]) + err_msg = 'There is a mismatch' + + a = build_err_msg([x, y], err_msg, precision=10) + b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ ' + '1.000000001, 2.00002 , 3.00003 ])\n DESIRED: array([ ' + '1.000000002, 2.00003 , 3.00004 ])') + self.assertEqual(a, b) + class TestEqual(TestArrayEqual): def setUp(self): self._assert_func = assert_equal @@ -236,6 +282,31 @@ class TestAlmostEqual(_GenericTest, unittest.TestCase): self._test_not_equal(x, y) self._test_not_equal(x, z) + def test_error_message(self): + """Check the message is formatted correctly for the decimal value""" + x = np.array([1.00000000001, 2.00000000002, 3.00003]) + y = np.array([1.00000000002, 2.00000000003, 3.00004]) + + # test with a different amount of decimal digits + # note that we only check for the formatting of the arrays themselves + b = ('x: array([ 1.00000000001, 2.00000000002, 3.00003 ' + ' ])\n y: array([ 1.00000000002, 2.00000000003, 3.00004 ])') + try: + self._assert_func(x, y, decimal=12) + except AssertionError as e: + # remove anything that's not the array string + self.assertEqual(str(e).split('%)\n ')[1], b) + + # with the default value of decimal digits, only the 3rd element differs + # note that we only check for the formatting of the arrays themselves + b = ('x: array([ 1. , 2. , 3.00003])\n y: array([ 1. , ' + '2. , 3.00004])') + try: + self._assert_func(x, y) + except AssertionError as e: + # remove anything that's not the array string + self.assertEqual(str(e).split('%)\n ')[1], b) + class TestApproxEqual(unittest.TestCase): def setUp(self): self._assert_func = assert_approx_equal @@ -373,6 +444,11 @@ class TestAssertAllclose(unittest.TestCase): assert_allclose(6, 10, rtol=0.5) self.assertRaises(AssertionError, assert_allclose, 10, 6, rtol=0.5) + def test_min_int(self): + a = np.array([np.iinfo(np.int_).min], dtype=np.int_) + # Should not raise: + assert_allclose(a, a) + class TestArrayAlmostEqualNulp(unittest.TestCase): @dec.knownfailureif(True, "Github issue #347") diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 2a99fe5cb..13c3e4610 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -9,8 +9,12 @@ import sys import re import operator import warnings +from functools import partial +import shutil +import contextlib +from tempfile import mkdtemp from .nosetester import import_nose -from numpy.core import float32, empty, arange +from numpy.core import float32, empty, arange, array_repr, ndarray if sys.version_info[0] >= 3: from io import StringIO @@ -22,7 +26,7 @@ __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal', 'assert_array_almost_equal', 'assert_raises', 'build_err_msg', 'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal', 'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure', - 'assert_', 'assert_array_almost_equal_nulp', + 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex', 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings', 'assert_allclose', 'IgnoreException'] @@ -190,8 +194,7 @@ if os.name=='nt' and sys.version[:3] > '2.3': win32pdh.PDH_FMT_LONG, None) def build_err_msg(arrays, err_msg, header='Items are not equal:', - verbose=True, - names=('ACTUAL', 'DESIRED')): + verbose=True, names=('ACTUAL', 'DESIRED'), precision=8): msg = ['\n' + header] if err_msg: if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header): @@ -200,8 +203,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:', msg.append(err_msg) if verbose: for i, a in enumerate(arrays): + + if isinstance(a, ndarray): + # precision argument is only needed if the objects are ndarrays + r_func = partial(array_repr, precision=precision) + else: + r_func = repr + try: - r = repr(a) + r = r_func(a) except: r = '[repr failed]' if r.count('\n') > 3: @@ -318,7 +328,9 @@ def assert_equal(actual,desired,err_msg='',verbose=True): # as before except (TypeError, ValueError, NotImplementedError): pass - if desired != actual : + + # Explicitly use __eq__ for comparison, ticket #2552 + if not (desired == actual): raise AssertionError(msg) def print_assert_equal(test_string, actual, desired): @@ -573,7 +585,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): raise AssertionError(msg) def assert_array_compare(comparison, x, y, err_msg='', verbose=True, - header=''): + header='', precision=6): from numpy.core import array, isnan, isinf, any, all, inf x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) @@ -590,7 +602,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, msg = build_err_msg([x, y], err_msg + '\nx and y %s location mismatch:' \ % (hasval), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) raise AssertionError(msg) try: @@ -601,7 +613,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, + '\n(shapes %s, %s mismatch)' % (x.shape, y.shape), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) if not cond : raise AssertionError(msg) @@ -646,7 +658,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, err_msg + '\n(mismatch %s%%)' % (match,), verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) if not cond : raise AssertionError(msg) except ValueError as e: @@ -655,7 +667,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header) msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header, - names=('x', 'y')) + names=('x', 'y'), precision=precision) raise ValueError(msg) def assert_array_equal(x, y, err_msg='', verbose=True): @@ -793,7 +805,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): y: array([ 1. , 2.33333, 5. ]) """ - from numpy.core import around, number, float_ + from numpy.core import around, number, float_, result_type, array from numpy.core.numerictypes import issubdtype from numpy.core.fromnumeric import any as npany def compare(x, y): @@ -810,12 +822,22 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): y = y[~yinfid] except (TypeError, NotImplementedError): pass + + # make sure y is an inexact type to avoid abs(MIN_INT); will cause + # casting of x later. + dtype = result_type(y, 1.) + y = array(y, dtype=dtype, copy=False) z = abs(x-y) + if not issubdtype(z.dtype, number): z = z.astype(float_) # handle object arrays + return around(z, decimal) <= 10.0**(-decimal) + assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, - header=('Arrays are not almost equal to %d decimals' % decimal)) + header=('Arrays are not almost equal to %d decimals' % decimal), + precision=decimal) + def assert_array_less(x, y, err_msg='', verbose=True): """ @@ -1011,6 +1033,7 @@ def raises(*args,**kwargs): nose = import_nose() return nose.tools.raises(*args,**kwargs) + def assert_raises(*args,**kwargs): """ assert_raises(exception_class, callable, *args, **kwargs) @@ -1026,6 +1049,85 @@ def assert_raises(*args,**kwargs): nose = import_nose() return nose.tools.assert_raises(*args,**kwargs) + +assert_raises_regex_impl = None + + +def assert_raises_regex(exception_class, expected_regexp, + callable_obj=None, *args, **kwargs): + """ + Fail unless an exception of class exception_class and with message that + matches expected_regexp is thrown by callable when invoked with arguments + args and keyword arguments kwargs. + + Name of this function adheres to Python 3.2+ reference, but should work in + all versions down to 2.6. + + """ + nose = import_nose() + + global assert_raises_regex_impl + if assert_raises_regex_impl is None: + try: + # Python 3.2+ + assert_raises_regex_impl = nose.tools.assert_raises_regex + except AttributeError: + try: + # 2.7+ + assert_raises_regex_impl = nose.tools.assert_raises_regexp + except AttributeError: + # 2.6 + + # This class is copied from Python2.7 stdlib almost verbatim + class _AssertRaisesContext(object): + """A context manager used to implement TestCase.assertRaises* methods.""" + + def __init__(self, expected, expected_regexp=None): + self.expected = expected + self.expected_regexp = expected_regexp + + def failureException(self, msg): + return AssertionError(msg) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, tb): + if exc_type is None: + try: + exc_name = self.expected.__name__ + except AttributeError: + exc_name = str(self.expected) + raise self.failureException( + "{0} not raised".format(exc_name)) + if not issubclass(exc_type, self.expected): + # let unexpected exceptions pass through + return False + self.exception = exc_value # store for later retrieval + if self.expected_regexp is None: + return True + + expected_regexp = self.expected_regexp + if isinstance(expected_regexp, basestring): + expected_regexp = re.compile(expected_regexp) + if not expected_regexp.search(str(exc_value)): + raise self.failureException( + '"%s" does not match "%s"' % + (expected_regexp.pattern, str(exc_value))) + return True + + def impl(cls, regex, callable_obj, *a, **kw): + mgr = _AssertRaisesContext(cls, regex) + if callable_obj is None: + return mgr + with mgr: + callable_obj(*a, **kw) + assert_raises_regex_impl = impl + + return assert_raises_regex_impl(exception_class, expected_regexp, + callable_obj, *args, **kwargs) + + def decorate_methods(cls, decorator, testmatch=None): """ Apply a decorator to all methods in a class matching a regular expression. @@ -1147,6 +1249,8 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, It compares the difference between `actual` and `desired` to ``atol + rtol * abs(desired)``. + .. versionadded:: 1.5.0 + Parameters ---------- actual : array_like @@ -1231,7 +1335,6 @@ def assert_array_almost_equal_nulp(x, y, nulp=1): >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x) >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x) - ------------------------------------------------------------ Traceback (most recent call last): ... AssertionError: X and Y are not equal to 1 ULP (max is 2) @@ -1456,6 +1559,7 @@ class WarningManager(object): self._module.filters = self._filters self._module.showwarning = self._showwarning + def assert_warns(warning_class, func, *args, **kw): """ Fail unless the given callable throws the specified warning. @@ -1465,6 +1569,8 @@ def assert_warns(warning_class, func, *args, **kw): If a different type of warning is thrown, it will not be caught, and the test case will be deemed to have suffered an error. + .. versionadded:: 1.4.0 + Parameters ---------- warning_class : class @@ -1496,6 +1602,8 @@ def assert_no_warnings(func, *args, **kw): """ Fail if the given callable produces any warnings. + .. versionadded:: 1.7.0 + Parameters ---------- func : callable @@ -1587,3 +1695,16 @@ def _gen_alignment_data(dtype=float32, type='binary', max_size=24): class IgnoreException(Exception): "Ignoring this exception due to disabled feature" + + +@contextlib.contextmanager +def tempdir(*args, **kwargs): + """Context manager to provide a temporary test folder. + + All arguments are passed as this to the underlying tempfile.mkdtemp + function. + + """ + tmpdir = mkdtemp(*args, **kwargs) + yield tmpdir + shutil.rmtree(tmpdir) |