summaryrefslogtreecommitdiff
path: root/numpy/testing/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r--numpy/testing/utils.py288
1 files changed, 174 insertions, 114 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index c6d863f94..176d87800 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -12,9 +12,12 @@ import warnings
from functools import partial
import shutil
import contextlib
-from tempfile import mkdtemp
+from tempfile import mkdtemp, mkstemp
+from unittest.case import SkipTest
+
from .nosetester import import_nose
from numpy.core import float32, empty, arange, array_repr, ndarray
+from numpy.lib.utils import deprecate
if sys.version_info[0] >= 3:
from io import StringIO
@@ -28,11 +31,21 @@ __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal',
'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure',
'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex',
'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings',
- 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings']
+ 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings',
+ 'SkipTest', 'KnownFailureException', 'temppath', 'tempdir', 'IS_PYPY',
+ 'HAS_REFCOUNT']
+
+
+class KnownFailureException(Exception):
+ '''Raise this exception to mark a test as a known failing test.'''
+ pass
+KnownFailureTest = KnownFailureException # backwards compat
verbose = 0
+IS_PYPY = '__pypy__' in sys.modules
+HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None
def assert_(val, msg=''):
"""
@@ -52,6 +65,7 @@ def assert_(val, msg=''):
smsg = msg
raise AssertionError(smsg)
+
def gisnan(x):
"""like isnan, but always raise an error if type not supported instead of
returning a TypeError object.
@@ -69,6 +83,7 @@ def gisnan(x):
raise TypeError("isnan not supported for this type")
return st
+
def gisfinite(x):
"""like isfinite, but always raise an error if type not supported instead of
returning a TypeError object.
@@ -87,6 +102,7 @@ def gisfinite(x):
raise TypeError("isfinite not supported for this type")
return st
+
def gisinf(x):
"""like isinf, but always raise an error if type not supported instead of
returning a TypeError object.
@@ -105,6 +121,9 @@ def gisinf(x):
raise TypeError("isinf not supported for this type")
return st
+
+@deprecate(message="numpy.testing.rand is deprecated in numpy 1.11. "
+ "Use numpy.random.rand instead.")
def rand(*args):
"""Returns an array of random numbers with the given shape.
@@ -118,6 +137,7 @@ def rand(*args):
f[i] = random.random()
return results
+
if os.name == 'nt':
# Code "stolen" from enthought/debug/memusage.py
def GetPerformanceAttributes(object, counter, instance=None,
@@ -232,14 +252,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
try:
r = r_func(a)
- except:
- r = '[repr failed]'
+ except Exception as exc:
+ r = '[repr failed for <{}>: {}]'.format(type(a).__name__, exc)
if r.count('\n') > 3:
r = '\n'.join(r.splitlines()[:3])
r += '...'
msg.append(' %s: %s' % (names[i], r))
return '\n'.join(msg)
+
def assert_equal(actual,desired,err_msg='',verbose=True):
"""
Raises an AssertionError if two objects are not equal.
@@ -354,6 +375,7 @@ def assert_equal(actual,desired,err_msg='',verbose=True):
if not (desired == actual):
raise AssertionError(msg)
+
def print_assert_equal(test_string, actual, desired):
"""
Test if two objects are equal, and print an error message if test fails.
@@ -394,6 +416,7 @@ def print_assert_equal(test_string, actual, desired):
pprint.pprint(desired, msg)
raise AssertionError(msg.getvalue())
+
def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
"""
Raises an AssertionError if two items are not equal up to desired
@@ -404,11 +427,14 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
instead of this function for more consistent floating point
comparisons.
- The test is equivalent to ``abs(desired-actual) < 0.5 * 10**(-decimal)``.
+ The test verifies that the elements of ``actual`` and ``desired`` satisfy.
- Given two objects (numbers or ndarrays), check that all elements of these
- objects are almost equal. An exception is raised at conflicting values.
- For ndarrays this delegates to assert_array_almost_equal
+ ``abs(desired-actual) < 1.5 * 10**(-decimal)``
+
+ That is a looser test than originally documented, but agrees with what the
+ actual implementation in `assert_array_almost_equal` did up to rounding
+ vagaries. An exception is raised at conflicting values. For ndarrays this
+ delegates to assert_array_almost_equal
Parameters
----------
@@ -509,7 +535,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
return
except (NotImplementedError, TypeError):
pass
- if round(abs(desired - actual), decimal) != 0:
+ if abs(desired - actual) >= 1.5 * 10.0**(-decimal):
raise AssertionError(_build_err_msg())
@@ -610,6 +636,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)):
raise AssertionError(msg)
+
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header='', precision=6):
__tracebackhide__ = True # Hide traceback for py.test
@@ -720,6 +747,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
names=('x', 'y'), precision=precision)
raise ValueError(msg)
+
def assert_array_equal(x, y, err_msg='', verbose=True):
"""
Raises an AssertionError if two array_like objects are not equal.
@@ -786,6 +814,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
verbose=verbose, header='Arrays are not equal')
+
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
"""
Raises an AssertionError if two objects are not equal up to desired
@@ -796,14 +825,16 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
instead of this function for more consistent floating point
comparisons.
- The test verifies identical shapes and verifies values with
- ``abs(desired-actual) < 0.5 * 10**(-decimal)``.
+ The test verifies identical shapes and that the elements of ``actual`` and
+ ``desired`` satisfy.
- Given two array_like objects, check that the shape is equal and all
- elements of these objects are almost equal. An exception is raised at
- shape mismatch or conflicting values. In contrast to the standard usage
- in numpy, NaNs are compared like numbers, no assertion is raised if
- both objects have NaNs in the same positions.
+ ``abs(desired-actual) < 1.5 * 10**(-decimal)``
+
+ That is a looser test than originally documented, but agrees with what the
+ actual implementation did up to rounding vagaries. An exception is raised
+ at shape mismatch or conflicting values. In contrast to the standard usage
+ in numpy, NaNs are compared like numbers, no assertion is raised if both
+ objects have NaNs in the same positions.
Parameters
----------
@@ -880,12 +911,12 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
# casting of x later.
dtype = result_type(y, 1.)
y = array(y, dtype=dtype, copy=False, subok=True)
- z = abs(x-y)
+ z = abs(x - y)
if not issubdtype(z.dtype, number):
z = z.astype(float_) # handle object arrays
- return around(z, decimal) <= 10.0**(-decimal)
+ return z < 1.5 * 10.0**(-decimal)
assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
header=('Arrays are not almost equal to %d decimals' % decimal),
@@ -963,9 +994,11 @@ def assert_array_less(x, y, err_msg='', verbose=True):
verbose=verbose,
header='Arrays are not less-ordered')
+
def runstring(astr, dict):
exec(astr, dict)
+
def assert_string_equal(actual, desired):
"""
Test if two strings are equal.
@@ -1018,11 +1051,12 @@ def assert_string_equal(actual, desired):
if not d2.startswith('+ '):
raise AssertionError(repr(d2))
l.append(d2)
- d3 = diff.pop(0)
- if d3.startswith('? '):
- l.append(d3)
- else:
- diff.insert(0, d3)
+ if diff:
+ d3 = diff.pop(0)
+ if d3.startswith('? '):
+ l.append(d3)
+ else:
+ diff.insert(0, d3)
if re.match(r'\A'+d2[2:]+r'\Z', d1[2:]):
continue
diff_list.extend(l)
@@ -1057,18 +1091,13 @@ def rundocs(filename=None, raise_on_error=True):
>>> np.lib.test(doctests=True) #doctest: +SKIP
"""
+ from numpy.compat import npy_load_module
import doctest
- import imp
if filename is None:
f = sys._getframe(1)
filename = f.f_globals['__file__']
name = os.path.splitext(os.path.basename(filename))[0]
- path = [os.path.dirname(filename)]
- file, pathname, description = imp.find_module(name, path)
- try:
- m = imp.load_module(name, file, pathname, description)
- finally:
- file.close()
+ m = npy_load_module(name, filename)
tests = doctest.DocTestFinder().find(m)
runner = doctest.DocTestRunner(verbose=False)
@@ -1102,15 +1131,24 @@ def assert_raises(*args,**kwargs):
deemed to have suffered an error, exactly as for an
unexpected exception.
+ Alternatively, `assert_raises` can be used as a context manager:
+
+ >>> from numpy.testing import assert_raises
+ >>> with assert_raises(ZeroDivisionError):
+ ... 1 / 0
+
+ is equivalent to
+
+ >>> def div(x, y):
+ ... return x / y
+ >>> assert_raises(ZeroDivisionError, div, 1, 0)
+
"""
__tracebackhide__ = True # Hide traceback for py.test
nose = import_nose()
return nose.tools.assert_raises(*args,**kwargs)
-assert_raises_regex_impl = None
-
-
def assert_raises_regex(exception_class, expected_regexp,
callable_obj=None, *args, **kwargs):
"""
@@ -1121,70 +1159,22 @@ def assert_raises_regex(exception_class, expected_regexp,
Name of this function adheres to Python 3.2+ reference, but should work in
all versions down to 2.6.
+ Notes
+ -----
+ .. versionadded:: 1.9.0
+
"""
__tracebackhide__ = True # Hide traceback for py.test
nose = import_nose()
- global assert_raises_regex_impl
- if assert_raises_regex_impl is None:
- try:
- # Python 3.2+
- assert_raises_regex_impl = nose.tools.assert_raises_regex
- except AttributeError:
- try:
- # 2.7+
- assert_raises_regex_impl = nose.tools.assert_raises_regexp
- except AttributeError:
- # 2.6
-
- # This class is copied from Python2.7 stdlib almost verbatim
- class _AssertRaisesContext(object):
- """A context manager used to implement TestCase.assertRaises* methods."""
-
- def __init__(self, expected, expected_regexp=None):
- self.expected = expected
- self.expected_regexp = expected_regexp
-
- def failureException(self, msg):
- return AssertionError(msg)
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_value, tb):
- if exc_type is None:
- try:
- exc_name = self.expected.__name__
- except AttributeError:
- exc_name = str(self.expected)
- raise self.failureException(
- "{0} not raised".format(exc_name))
- if not issubclass(exc_type, self.expected):
- # let unexpected exceptions pass through
- return False
- self.exception = exc_value # store for later retrieval
- if self.expected_regexp is None:
- return True
-
- expected_regexp = self.expected_regexp
- if isinstance(expected_regexp, basestring):
- expected_regexp = re.compile(expected_regexp)
- if not expected_regexp.search(str(exc_value)):
- raise self.failureException(
- '"%s" does not match "%s"' %
- (expected_regexp.pattern, str(exc_value)))
- return True
-
- def impl(cls, regex, callable_obj, *a, **kw):
- mgr = _AssertRaisesContext(cls, regex)
- if callable_obj is None:
- return mgr
- with mgr:
- callable_obj(*a, **kw)
- assert_raises_regex_impl = impl
-
- return assert_raises_regex_impl(exception_class, expected_regexp,
- callable_obj, *args, **kwargs)
+ if sys.version_info.major >= 3:
+ funcname = nose.tools.assert_raises_regex
+ else:
+ # Only present in Python 2.7, missing from unittest in 2.6
+ funcname = nose.tools.assert_raises_regexp
+
+ return funcname(exception_class, expected_regexp, callable_obj,
+ *args, **kwargs)
def decorate_methods(cls, decorator, testmatch=None):
@@ -1263,7 +1253,7 @@ def measure(code_str,times=1,label=None):
--------
>>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)',
... times=times)
- >>> print "Time for a single execution : ", etime / times, "s"
+ >>> print("Time for a single execution : ", etime / times, "s")
Time for a single execution : 0.005 s
"""
@@ -1287,6 +1277,8 @@ def _assert_valid_refcount(op):
Check that ufuncs don't mishandle refcount of object `1`.
Used in a few regression tests.
"""
+ if not HAS_REFCOUNT:
+ return True
import numpy as np
b = np.arange(100*100).reshape(100, 100)
@@ -1357,6 +1349,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=False,
assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
verbose=verbose, header=header)
+
def assert_array_almost_equal_nulp(x, y, nulp=1):
"""
Compare two arrays relatively to their spacing.
@@ -1419,6 +1412,7 @@ def assert_array_almost_equal_nulp(x, y, nulp=1):
msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp)
raise AssertionError(msg)
+
def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
"""
Check that all items of arrays differ in at most N Units in the Last Place.
@@ -1463,6 +1457,7 @@ def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
maxulp)
return ret
+
def nulp_diff(x, y, dtype=None):
"""For each item in x and y, return the number of representable floating
points between them.
@@ -1516,6 +1511,7 @@ def nulp_diff(x, y, dtype=None):
ry = integer_repr(y)
return _diff(rx, ry, t)
+
def _integer_repr(x, vdt, comp):
# Reinterpret binary representation of the float as sign-magnitude:
# take into account two-complement representation
@@ -1530,6 +1526,7 @@ def _integer_repr(x, vdt, comp):
return rx
+
def integer_repr(x):
"""Return the signed-magnitude interpretation of the binary representation of
x."""
@@ -1541,6 +1538,7 @@ def integer_repr(x):
else:
raise ValueError("Unsupported dtype %s" % x.dtype)
+
# The following two classes are copied from python 2.6 warnings module (context
# manager)
class WarningMessage(object):
@@ -1575,6 +1573,7 @@ class WarningMessage(object):
"line : %r}" % (self.message, self._category_name,
self.filename, self.lineno, self.line))
+
class WarningManager(object):
"""
A context manager that copies and restores the warnings filter upon
@@ -1632,7 +1631,22 @@ class WarningManager(object):
self._module.showwarning = self._showwarning
-def assert_warns(warning_class, func, *args, **kw):
+@contextlib.contextmanager
+def _assert_warns_context(warning_class, name=None):
+ __tracebackhide__ = True # Hide traceback for py.test
+ with warnings.catch_warnings(record=True) as l:
+ warnings.simplefilter('always')
+ yield
+ if not len(l) > 0:
+ name_str = " when calling %s" % name if name is not None else ""
+ raise AssertionError("No warning raised" + name_str)
+ if not l[0].category is warning_class:
+ name_str = "%s " % name if name is not None else ""
+ raise AssertionError("First warning %sis not a %s (is %s)"
+ % (name_str, warning_class, l[0]))
+
+
+def assert_warns(warning_class, *args, **kwargs):
"""
Fail unless the given callable throws the specified warning.
@@ -1641,6 +1655,14 @@ def assert_warns(warning_class, func, *args, **kw):
If a different type of warning is thrown, it will not be caught, and the
test case will be deemed to have suffered an error.
+ If called with all arguments other than the warning class omitted, may be
+ used as a context manager:
+
+ with assert_warns(SomeWarning):
+ do_something()
+
+ The ability to be used as a context manager is new in NumPy v1.11.0.
+
.. versionadded:: 1.4.0
Parameters
@@ -1659,22 +1681,37 @@ def assert_warns(warning_class, func, *args, **kw):
The value returned by `func`.
"""
+ if not args:
+ return _assert_warns_context(warning_class)
+
+ func = args[0]
+ args = args[1:]
+ with _assert_warns_context(warning_class, name=func.__name__):
+ return func(*args, **kwargs)
+
+
+@contextlib.contextmanager
+def _assert_no_warnings_context(name=None):
__tracebackhide__ = True # Hide traceback for py.test
with warnings.catch_warnings(record=True) as l:
warnings.simplefilter('always')
- result = func(*args, **kw)
- if not len(l) > 0:
- raise AssertionError("No warning raised when calling %s"
- % func.__name__)
- if not l[0].category is warning_class:
- raise AssertionError("First warning for %s is not a "
- "%s( is %s)" % (func.__name__, warning_class, l[0]))
- return result
+ yield
+ if len(l) > 0:
+ name_str = " when calling %s" % name if name is not None else ""
+ raise AssertionError("Got warnings%s: %s" % (name_str, l))
+
-def assert_no_warnings(func, *args, **kw):
+def assert_no_warnings(*args, **kwargs):
"""
Fail if the given callable produces any warnings.
+ If called with all arguments omitted, may be used as a context manager:
+
+ with assert_no_warnings():
+ do_something()
+
+ The ability to be used as a context manager is new in NumPy v1.11.0.
+
.. versionadded:: 1.7.0
Parameters
@@ -1691,14 +1728,13 @@ def assert_no_warnings(func, *args, **kw):
The value returned by `func`.
"""
- __tracebackhide__ = True # Hide traceback for py.test
- with warnings.catch_warnings(record=True) as l:
- warnings.simplefilter('always')
- result = func(*args, **kw)
- if len(l) > 0:
- raise AssertionError("Got warnings when calling %s: %s"
- % (func.__name__, l))
- return result
+ if not args:
+ return _assert_no_warnings_context()
+
+ func = args[0]
+ args = args[1:]
+ with _assert_no_warnings_context(name=func.__name__):
+ return func(*args, **kwargs)
def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
@@ -1780,8 +1816,32 @@ def tempdir(*args, **kwargs):
"""
tmpdir = mkdtemp(*args, **kwargs)
- yield tmpdir
- shutil.rmtree(tmpdir)
+ try:
+ yield tmpdir
+ finally:
+ shutil.rmtree(tmpdir)
+
+
+@contextlib.contextmanager
+def temppath(*args, **kwargs):
+ """Context manager for temporary files.
+
+ Context manager that returns the path to a closed temporary file. Its
+ parameters are the same as for tempfile.mkstemp and are passed directly
+ to that function. The underlying file is removed when the context is
+ exited, so it should be closed at that time.
+
+ Windows does not allow a temporary file to be opened if it is already
+ open, so the underlying file must be closed after opening before it
+ can be opened again.
+
+ """
+ fd, path = mkstemp(*args, **kwargs)
+ os.close(fd)
+ try:
+ yield path
+ finally:
+ os.remove(path)
class clear_and_catch_warnings(warnings.catch_warnings):