summaryrefslogtreecommitdiff
path: root/numpy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing')
-rw-r--r--numpy/testing/__init__.py1
-rw-r--r--numpy/testing/nosetester.py49
-rw-r--r--numpy/testing/numpytest.py36
-rw-r--r--numpy/testing/tests/test_utils.py76
-rw-r--r--numpy/testing/utils.py149
5 files changed, 258 insertions, 53 deletions
diff --git a/numpy/testing/__init__.py b/numpy/testing/__init__.py
index aad448c96..258cbe928 100644
--- a/numpy/testing/__init__.py
+++ b/numpy/testing/__init__.py
@@ -11,7 +11,6 @@ from unittest import TestCase
from . import decorators as dec
from .utils import *
-from .numpytest import importall # remove for numpy 1.9.0
from .nosetester import NoseTester as Tester
from .nosetester import run_module_suite
test = Tester().test
diff --git a/numpy/testing/nosetester.py b/numpy/testing/nosetester.py
index 77ac61074..a06d559e7 100644
--- a/numpy/testing/nosetester.py
+++ b/numpy/testing/nosetester.py
@@ -75,14 +75,55 @@ def import_nose():
return nose
-def run_module_suite(file_to_run = None):
+def run_module_suite(file_to_run=None, argv=None):
+ """
+ Run a test module.
+
+ Equivalent to calling ``$ nosetests <argv> <file_to_run>`` from
+ the command line
+
+ Parameters
+ ----------
+ file_to_run: str, optional
+ Path to test module, or None.
+ By default, run the module from which this function is called.
+ argv: list of strings
+ Arguments to be passed to the nose test runner. ``argv[0]`` is
+ ignored. All command line arguments accepted by ``nosetests``
+ will work.
+
+ .. versionadded:: 1.9.0
+
+ Examples
+ --------
+ Adding the following::
+
+ if __name__ == "__main__" :
+ run_module_suite(argv=sys.argv)
+
+ at the end of a test module will run the tests when that module is
+ called in the python interpreter.
+
+ Alternatively, calling::
+
+ >>> run_module_suite(file_to_run="numpy/tests/test_matlib.py")
+
+ from an interpreter will run all the test routine in 'test_matlib.py'.
+ """
if file_to_run is None:
f = sys._getframe(1)
file_to_run = f.f_locals.get('__file__', None)
if file_to_run is None:
raise AssertionError
- import_nose().run(argv=['', file_to_run])
+ if argv is None:
+ argv = ['', file_to_run]
+ else:
+ argv = argv + [file_to_run]
+
+ nose = import_nose()
+ from .noseclasses import KnownFailure
+ nose.run(argv=argv, addplugins=[KnownFailure()])
class NoseTester(object):
@@ -378,6 +419,10 @@ class NoseTester(object):
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
warnings.filterwarnings("ignore", category=ModuleDeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
+ # Filter out boolean '-' deprecation messages. This allows
+ # older versions of scipy to test without a flood of messages.
+ warnings.filterwarnings("ignore", message=".*boolean negative.*")
+ warnings.filterwarnings("ignore", message=".*boolean subtract.*")
from .noseclasses import NumpyTestProgram
diff --git a/numpy/testing/numpytest.py b/numpy/testing/numpytest.py
deleted file mode 100644
index d83a21d83..000000000
--- a/numpy/testing/numpytest.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from __future__ import division, absolute_import, print_function
-
-import os
-import warnings
-
-__all__ = ['importall']
-
-
-def importall(package):
- """
- `importall` is DEPRECATED and will be removed in numpy 1.9.0
-
- Try recursively to import all subpackages under package.
- """
- warnings.warn("`importall is deprecated, and will be remobed in numpy 1.9.0",
- DeprecationWarning)
-
- if isinstance(package, str):
- package = __import__(package)
-
- package_name = package.__name__
- package_dir = os.path.dirname(package.__file__)
- for subpackage_name in os.listdir(package_dir):
- subdir = os.path.join(package_dir, subpackage_name)
- if not os.path.isdir(subdir):
- continue
- if not os.path.isfile(os.path.join(subdir, '__init__.py')):
- continue
- name = package_name+'.'+subpackage_name
- try:
- exec('import %s as m' % (name))
- except Exception as msg:
- print('Failed importing %s: %s' %(name, msg))
- continue
- importall(m)
- return
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 94fc4d655..aa0a2669f 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -53,6 +53,9 @@ class _GenericTest(object):
a = np.array([1, 1], dtype=np.object)
self._test_equal(a, 1)
+ def test_array_likes(self):
+ self._test_equal([1, 2, 3], (1, 2, 3))
+
class TestArrayEqual(_GenericTest, unittest.TestCase):
def setUp(self):
self._assert_func = assert_array_equal
@@ -131,6 +134,49 @@ class TestArrayEqual(_GenericTest, unittest.TestCase):
self._test_not_equal(c, b)
+class TestBuildErrorMessage(unittest.TestCase):
+ def test_build_err_msg_defaults(self):
+ x = np.array([1.00001, 2.00002, 3.00003])
+ y = np.array([1.00002, 2.00003, 3.00004])
+ err_msg = 'There is a mismatch'
+
+ a = build_err_msg([x, y], err_msg)
+ b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ '
+ '1.00001, 2.00002, 3.00003])\n DESIRED: array([ 1.00002, '
+ '2.00003, 3.00004])')
+ self.assertEqual(a, b)
+
+ def test_build_err_msg_no_verbose(self):
+ x = np.array([1.00001, 2.00002, 3.00003])
+ y = np.array([1.00002, 2.00003, 3.00004])
+ err_msg = 'There is a mismatch'
+
+ a = build_err_msg([x, y], err_msg, verbose=False)
+ b = '\nItems are not equal: There is a mismatch'
+ self.assertEqual(a, b)
+
+ def test_build_err_msg_custom_names(self):
+ x = np.array([1.00001, 2.00002, 3.00003])
+ y = np.array([1.00002, 2.00003, 3.00004])
+ err_msg = 'There is a mismatch'
+
+ a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR'))
+ b = ('\nItems are not equal: There is a mismatch\n FOO: array([ '
+ '1.00001, 2.00002, 3.00003])\n BAR: array([ 1.00002, 2.00003, '
+ '3.00004])')
+ self.assertEqual(a, b)
+
+ def test_build_err_msg_custom_precision(self):
+ x = np.array([1.000000001, 2.00002, 3.00003])
+ y = np.array([1.000000002, 2.00003, 3.00004])
+ err_msg = 'There is a mismatch'
+
+ a = build_err_msg([x, y], err_msg, precision=10)
+ b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ '
+ '1.000000001, 2.00002 , 3.00003 ])\n DESIRED: array([ '
+ '1.000000002, 2.00003 , 3.00004 ])')
+ self.assertEqual(a, b)
+
class TestEqual(TestArrayEqual):
def setUp(self):
self._assert_func = assert_equal
@@ -236,6 +282,31 @@ class TestAlmostEqual(_GenericTest, unittest.TestCase):
self._test_not_equal(x, y)
self._test_not_equal(x, z)
+ def test_error_message(self):
+ """Check the message is formatted correctly for the decimal value"""
+ x = np.array([1.00000000001, 2.00000000002, 3.00003])
+ y = np.array([1.00000000002, 2.00000000003, 3.00004])
+
+ # test with a different amount of decimal digits
+ # note that we only check for the formatting of the arrays themselves
+ b = ('x: array([ 1.00000000001, 2.00000000002, 3.00003 '
+ ' ])\n y: array([ 1.00000000002, 2.00000000003, 3.00004 ])')
+ try:
+ self._assert_func(x, y, decimal=12)
+ except AssertionError as e:
+ # remove anything that's not the array string
+ self.assertEqual(str(e).split('%)\n ')[1], b)
+
+ # with the default value of decimal digits, only the 3rd element differs
+ # note that we only check for the formatting of the arrays themselves
+ b = ('x: array([ 1. , 2. , 3.00003])\n y: array([ 1. , '
+ '2. , 3.00004])')
+ try:
+ self._assert_func(x, y)
+ except AssertionError as e:
+ # remove anything that's not the array string
+ self.assertEqual(str(e).split('%)\n ')[1], b)
+
class TestApproxEqual(unittest.TestCase):
def setUp(self):
self._assert_func = assert_approx_equal
@@ -373,6 +444,11 @@ class TestAssertAllclose(unittest.TestCase):
assert_allclose(6, 10, rtol=0.5)
self.assertRaises(AssertionError, assert_allclose, 10, 6, rtol=0.5)
+ def test_min_int(self):
+ a = np.array([np.iinfo(np.int_).min], dtype=np.int_)
+ # Should not raise:
+ assert_allclose(a, a)
+
class TestArrayAlmostEqualNulp(unittest.TestCase):
@dec.knownfailureif(True, "Github issue #347")
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 2a99fe5cb..13c3e4610 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -9,8 +9,12 @@ import sys
import re
import operator
import warnings
+from functools import partial
+import shutil
+import contextlib
+from tempfile import mkdtemp
from .nosetester import import_nose
-from numpy.core import float32, empty, arange
+from numpy.core import float32, empty, arange, array_repr, ndarray
if sys.version_info[0] >= 3:
from io import StringIO
@@ -22,7 +26,7 @@ __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal',
'assert_array_almost_equal', 'assert_raises', 'build_err_msg',
'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal',
'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure',
- 'assert_', 'assert_array_almost_equal_nulp',
+ 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex',
'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings',
'assert_allclose', 'IgnoreException']
@@ -190,8 +194,7 @@ if os.name=='nt' and sys.version[:3] > '2.3':
win32pdh.PDH_FMT_LONG, None)
def build_err_msg(arrays, err_msg, header='Items are not equal:',
- verbose=True,
- names=('ACTUAL', 'DESIRED')):
+ verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
msg = ['\n' + header]
if err_msg:
if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header):
@@ -200,8 +203,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
msg.append(err_msg)
if verbose:
for i, a in enumerate(arrays):
+
+ if isinstance(a, ndarray):
+ # precision argument is only needed if the objects are ndarrays
+ r_func = partial(array_repr, precision=precision)
+ else:
+ r_func = repr
+
try:
- r = repr(a)
+ r = r_func(a)
except:
r = '[repr failed]'
if r.count('\n') > 3:
@@ -318,7 +328,9 @@ def assert_equal(actual,desired,err_msg='',verbose=True):
# as before
except (TypeError, ValueError, NotImplementedError):
pass
- if desired != actual :
+
+ # Explicitly use __eq__ for comparison, ticket #2552
+ if not (desired == actual):
raise AssertionError(msg)
def print_assert_equal(test_string, actual, desired):
@@ -573,7 +585,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
raise AssertionError(msg)
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
- header=''):
+ header='', precision=6):
from numpy.core import array, isnan, isinf, any, all, inf
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -590,7 +602,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
msg = build_err_msg([x, y],
err_msg + '\nx and y %s location mismatch:' \
% (hasval), verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
raise AssertionError(msg)
try:
@@ -601,7 +613,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
+ '\n(shapes %s, %s mismatch)' % (x.shape,
y.shape),
verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
if not cond :
raise AssertionError(msg)
@@ -646,7 +658,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
err_msg
+ '\n(mismatch %s%%)' % (match,),
verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
if not cond :
raise AssertionError(msg)
except ValueError as e:
@@ -655,7 +667,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header)
msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
raise ValueError(msg)
def assert_array_equal(x, y, err_msg='', verbose=True):
@@ -793,7 +805,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
y: array([ 1. , 2.33333, 5. ])
"""
- from numpy.core import around, number, float_
+ from numpy.core import around, number, float_, result_type, array
from numpy.core.numerictypes import issubdtype
from numpy.core.fromnumeric import any as npany
def compare(x, y):
@@ -810,12 +822,22 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
y = y[~yinfid]
except (TypeError, NotImplementedError):
pass
+
+ # make sure y is an inexact type to avoid abs(MIN_INT); will cause
+ # casting of x later.
+ dtype = result_type(y, 1.)
+ y = array(y, dtype=dtype, copy=False)
z = abs(x-y)
+
if not issubdtype(z.dtype, number):
z = z.astype(float_) # handle object arrays
+
return around(z, decimal) <= 10.0**(-decimal)
+
assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
- header=('Arrays are not almost equal to %d decimals' % decimal))
+ header=('Arrays are not almost equal to %d decimals' % decimal),
+ precision=decimal)
+
def assert_array_less(x, y, err_msg='', verbose=True):
"""
@@ -1011,6 +1033,7 @@ def raises(*args,**kwargs):
nose = import_nose()
return nose.tools.raises(*args,**kwargs)
+
def assert_raises(*args,**kwargs):
"""
assert_raises(exception_class, callable, *args, **kwargs)
@@ -1026,6 +1049,85 @@ def assert_raises(*args,**kwargs):
nose = import_nose()
return nose.tools.assert_raises(*args,**kwargs)
+
+assert_raises_regex_impl = None
+
+
+def assert_raises_regex(exception_class, expected_regexp,
+ callable_obj=None, *args, **kwargs):
+ """
+ Fail unless an exception of class exception_class and with message that
+ matches expected_regexp is thrown by callable when invoked with arguments
+ args and keyword arguments kwargs.
+
+ Name of this function adheres to Python 3.2+ reference, but should work in
+ all versions down to 2.6.
+
+ """
+ nose = import_nose()
+
+ global assert_raises_regex_impl
+ if assert_raises_regex_impl is None:
+ try:
+ # Python 3.2+
+ assert_raises_regex_impl = nose.tools.assert_raises_regex
+ except AttributeError:
+ try:
+ # 2.7+
+ assert_raises_regex_impl = nose.tools.assert_raises_regexp
+ except AttributeError:
+ # 2.6
+
+ # This class is copied from Python2.7 stdlib almost verbatim
+ class _AssertRaisesContext(object):
+ """A context manager used to implement TestCase.assertRaises* methods."""
+
+ def __init__(self, expected, expected_regexp=None):
+ self.expected = expected
+ self.expected_regexp = expected_regexp
+
+ def failureException(self, msg):
+ return AssertionError(msg)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ if exc_type is None:
+ try:
+ exc_name = self.expected.__name__
+ except AttributeError:
+ exc_name = str(self.expected)
+ raise self.failureException(
+ "{0} not raised".format(exc_name))
+ if not issubclass(exc_type, self.expected):
+ # let unexpected exceptions pass through
+ return False
+ self.exception = exc_value # store for later retrieval
+ if self.expected_regexp is None:
+ return True
+
+ expected_regexp = self.expected_regexp
+ if isinstance(expected_regexp, basestring):
+ expected_regexp = re.compile(expected_regexp)
+ if not expected_regexp.search(str(exc_value)):
+ raise self.failureException(
+ '"%s" does not match "%s"' %
+ (expected_regexp.pattern, str(exc_value)))
+ return True
+
+ def impl(cls, regex, callable_obj, *a, **kw):
+ mgr = _AssertRaisesContext(cls, regex)
+ if callable_obj is None:
+ return mgr
+ with mgr:
+ callable_obj(*a, **kw)
+ assert_raises_regex_impl = impl
+
+ return assert_raises_regex_impl(exception_class, expected_regexp,
+ callable_obj, *args, **kwargs)
+
+
def decorate_methods(cls, decorator, testmatch=None):
"""
Apply a decorator to all methods in a class matching a regular expression.
@@ -1147,6 +1249,8 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0,
It compares the difference between `actual` and `desired` to
``atol + rtol * abs(desired)``.
+ .. versionadded:: 1.5.0
+
Parameters
----------
actual : array_like
@@ -1231,7 +1335,6 @@ def assert_array_almost_equal_nulp(x, y, nulp=1):
>>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x)
>>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x)
- ------------------------------------------------------------
Traceback (most recent call last):
...
AssertionError: X and Y are not equal to 1 ULP (max is 2)
@@ -1456,6 +1559,7 @@ class WarningManager(object):
self._module.filters = self._filters
self._module.showwarning = self._showwarning
+
def assert_warns(warning_class, func, *args, **kw):
"""
Fail unless the given callable throws the specified warning.
@@ -1465,6 +1569,8 @@ def assert_warns(warning_class, func, *args, **kw):
If a different type of warning is thrown, it will not be caught, and the
test case will be deemed to have suffered an error.
+ .. versionadded:: 1.4.0
+
Parameters
----------
warning_class : class
@@ -1496,6 +1602,8 @@ def assert_no_warnings(func, *args, **kw):
"""
Fail if the given callable produces any warnings.
+ .. versionadded:: 1.7.0
+
Parameters
----------
func : callable
@@ -1587,3 +1695,16 @@ def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
class IgnoreException(Exception):
"Ignoring this exception due to disabled feature"
+
+
+@contextlib.contextmanager
+def tempdir(*args, **kwargs):
+ """Context manager to provide a temporary test folder.
+
+ All arguments are passed as this to the underlying tempfile.mkdtemp
+ function.
+
+ """
+ tmpdir = mkdtemp(*args, **kwargs)
+ yield tmpdir
+ shutil.rmtree(tmpdir)