summaryrefslogtreecommitdiff
path: root/numpy/testing/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r--numpy/testing/utils.py149
1 files changed, 135 insertions, 14 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 2a99fe5cb..13c3e4610 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -9,8 +9,12 @@ import sys
import re
import operator
import warnings
+from functools import partial
+import shutil
+import contextlib
+from tempfile import mkdtemp
from .nosetester import import_nose
-from numpy.core import float32, empty, arange
+from numpy.core import float32, empty, arange, array_repr, ndarray
if sys.version_info[0] >= 3:
from io import StringIO
@@ -22,7 +26,7 @@ __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal',
'assert_array_almost_equal', 'assert_raises', 'build_err_msg',
'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal',
'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure',
- 'assert_', 'assert_array_almost_equal_nulp',
+ 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex',
'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings',
'assert_allclose', 'IgnoreException']
@@ -190,8 +194,7 @@ if os.name=='nt' and sys.version[:3] > '2.3':
win32pdh.PDH_FMT_LONG, None)
def build_err_msg(arrays, err_msg, header='Items are not equal:',
- verbose=True,
- names=('ACTUAL', 'DESIRED')):
+ verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
msg = ['\n' + header]
if err_msg:
if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header):
@@ -200,8 +203,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
msg.append(err_msg)
if verbose:
for i, a in enumerate(arrays):
+
+ if isinstance(a, ndarray):
+ # precision argument is only needed if the objects are ndarrays
+ r_func = partial(array_repr, precision=precision)
+ else:
+ r_func = repr
+
try:
- r = repr(a)
+ r = r_func(a)
except:
r = '[repr failed]'
if r.count('\n') > 3:
@@ -318,7 +328,9 @@ def assert_equal(actual,desired,err_msg='',verbose=True):
# as before
except (TypeError, ValueError, NotImplementedError):
pass
- if desired != actual :
+
+ # Explicitly use __eq__ for comparison, ticket #2552
+ if not (desired == actual):
raise AssertionError(msg)
def print_assert_equal(test_string, actual, desired):
@@ -573,7 +585,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
raise AssertionError(msg)
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
- header=''):
+ header='', precision=6):
from numpy.core import array, isnan, isinf, any, all, inf
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -590,7 +602,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
msg = build_err_msg([x, y],
err_msg + '\nx and y %s location mismatch:' \
% (hasval), verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
raise AssertionError(msg)
try:
@@ -601,7 +613,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
+ '\n(shapes %s, %s mismatch)' % (x.shape,
y.shape),
verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
if not cond :
raise AssertionError(msg)
@@ -646,7 +658,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
err_msg
+ '\n(mismatch %s%%)' % (match,),
verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
if not cond :
raise AssertionError(msg)
except ValueError as e:
@@ -655,7 +667,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header)
msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
raise ValueError(msg)
def assert_array_equal(x, y, err_msg='', verbose=True):
@@ -793,7 +805,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
y: array([ 1. , 2.33333, 5. ])
"""
- from numpy.core import around, number, float_
+ from numpy.core import around, number, float_, result_type, array
from numpy.core.numerictypes import issubdtype
from numpy.core.fromnumeric import any as npany
def compare(x, y):
@@ -810,12 +822,22 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
y = y[~yinfid]
except (TypeError, NotImplementedError):
pass
+
+ # make sure y is an inexact type to avoid abs(MIN_INT); will cause
+ # casting of x later.
+ dtype = result_type(y, 1.)
+ y = array(y, dtype=dtype, copy=False)
z = abs(x-y)
+
if not issubdtype(z.dtype, number):
z = z.astype(float_) # handle object arrays
+
return around(z, decimal) <= 10.0**(-decimal)
+
assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
- header=('Arrays are not almost equal to %d decimals' % decimal))
+ header=('Arrays are not almost equal to %d decimals' % decimal),
+ precision=decimal)
+
def assert_array_less(x, y, err_msg='', verbose=True):
"""
@@ -1011,6 +1033,7 @@ def raises(*args,**kwargs):
nose = import_nose()
return nose.tools.raises(*args,**kwargs)
+
def assert_raises(*args,**kwargs):
"""
assert_raises(exception_class, callable, *args, **kwargs)
@@ -1026,6 +1049,85 @@ def assert_raises(*args,**kwargs):
nose = import_nose()
return nose.tools.assert_raises(*args,**kwargs)
+
+assert_raises_regex_impl = None
+
+
+def assert_raises_regex(exception_class, expected_regexp,
+ callable_obj=None, *args, **kwargs):
+ """
+ Fail unless an exception of class exception_class and with message that
+ matches expected_regexp is thrown by callable when invoked with arguments
+ args and keyword arguments kwargs.
+
+ Name of this function adheres to Python 3.2+ reference, but should work in
+ all versions down to 2.6.
+
+ """
+ nose = import_nose()
+
+ global assert_raises_regex_impl
+ if assert_raises_regex_impl is None:
+ try:
+ # Python 3.2+
+ assert_raises_regex_impl = nose.tools.assert_raises_regex
+ except AttributeError:
+ try:
+ # 2.7+
+ assert_raises_regex_impl = nose.tools.assert_raises_regexp
+ except AttributeError:
+ # 2.6
+
+ # This class is copied from Python2.7 stdlib almost verbatim
+ class _AssertRaisesContext(object):
+ """A context manager used to implement TestCase.assertRaises* methods."""
+
+ def __init__(self, expected, expected_regexp=None):
+ self.expected = expected
+ self.expected_regexp = expected_regexp
+
+ def failureException(self, msg):
+ return AssertionError(msg)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ if exc_type is None:
+ try:
+ exc_name = self.expected.__name__
+ except AttributeError:
+ exc_name = str(self.expected)
+ raise self.failureException(
+ "{0} not raised".format(exc_name))
+ if not issubclass(exc_type, self.expected):
+ # let unexpected exceptions pass through
+ return False
+ self.exception = exc_value # store for later retrieval
+ if self.expected_regexp is None:
+ return True
+
+ expected_regexp = self.expected_regexp
+ if isinstance(expected_regexp, basestring):
+ expected_regexp = re.compile(expected_regexp)
+ if not expected_regexp.search(str(exc_value)):
+ raise self.failureException(
+ '"%s" does not match "%s"' %
+ (expected_regexp.pattern, str(exc_value)))
+ return True
+
+ def impl(cls, regex, callable_obj, *a, **kw):
+ mgr = _AssertRaisesContext(cls, regex)
+ if callable_obj is None:
+ return mgr
+ with mgr:
+ callable_obj(*a, **kw)
+ assert_raises_regex_impl = impl
+
+ return assert_raises_regex_impl(exception_class, expected_regexp,
+ callable_obj, *args, **kwargs)
+
+
def decorate_methods(cls, decorator, testmatch=None):
"""
Apply a decorator to all methods in a class matching a regular expression.
@@ -1147,6 +1249,8 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0,
It compares the difference between `actual` and `desired` to
``atol + rtol * abs(desired)``.
+ .. versionadded:: 1.5.0
+
Parameters
----------
actual : array_like
@@ -1231,7 +1335,6 @@ def assert_array_almost_equal_nulp(x, y, nulp=1):
>>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x)
>>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x)
- ------------------------------------------------------------
Traceback (most recent call last):
...
AssertionError: X and Y are not equal to 1 ULP (max is 2)
@@ -1456,6 +1559,7 @@ class WarningManager(object):
self._module.filters = self._filters
self._module.showwarning = self._showwarning
+
def assert_warns(warning_class, func, *args, **kw):
"""
Fail unless the given callable throws the specified warning.
@@ -1465,6 +1569,8 @@ def assert_warns(warning_class, func, *args, **kw):
If a different type of warning is thrown, it will not be caught, and the
test case will be deemed to have suffered an error.
+ .. versionadded:: 1.4.0
+
Parameters
----------
warning_class : class
@@ -1496,6 +1602,8 @@ def assert_no_warnings(func, *args, **kw):
"""
Fail if the given callable produces any warnings.
+ .. versionadded:: 1.7.0
+
Parameters
----------
func : callable
@@ -1587,3 +1695,16 @@ def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
class IgnoreException(Exception):
"Ignoring this exception due to disabled feature"
+
+
+@contextlib.contextmanager
+def tempdir(*args, **kwargs):
+ """Context manager to provide a temporary test folder.
+
+ All arguments are passed as this to the underlying tempfile.mkdtemp
+ function.
+
+ """
+ tmpdir = mkdtemp(*args, **kwargs)
+ yield tmpdir
+ shutil.rmtree(tmpdir)