diff options
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r-- | numpy/testing/_private/utils.py | 266 |
1 files changed, 211 insertions, 55 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 8a31fcf15..4569efa91 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -2,8 +2,6 @@ Utility function to facilitate testing. """ -from __future__ import division, absolute_import, print_function - import os import sys import platform @@ -21,11 +19,9 @@ import pprint from numpy.core import( intp, float32, empty, arange, array_repr, ndarray, isnat, array) +import numpy.linalg.lapack_lite -if sys.version_info[0] >= 3: - from io import StringIO -else: - from StringIO import StringIO +from io import StringIO __all__ = [ 'assert_equal', 'assert_almost_equal', 'assert_approx_equal', @@ -39,7 +35,7 @@ __all__ = [ 'SkipTest', 'KnownFailureException', 'temppath', 'tempdir', 'IS_PYPY', 'HAS_REFCOUNT', 'suppress_warnings', 'assert_array_compare', '_assert_valid_refcount', '_gen_alignment_data', 'assert_no_gc_cycles', - 'break_cycles', + 'break_cycles', 'HAS_LAPACK64' ] @@ -53,6 +49,7 @@ verbose = 0 IS_PYPY = platform.python_implementation() == 'PyPy' HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None +HAS_LAPACK64 = numpy.linalg.lapack_lite._ilp64 def import_nose(): @@ -195,9 +192,8 @@ elif sys.platform[:5] == 'linux': """ try: - f = open(_proc_pid_stat, 'r') - l = f.readline().split(' ') - f.close() + with open(_proc_pid_stat, 'r') as f: + l = f.readline().split(' ') return int(l[22]) except Exception: return @@ -224,9 +220,8 @@ if sys.platform[:5] == 'linux': if not _load_time: _load_time.append(time.time()) try: - f = open(_proc_pid_stat, 'r') - l = f.readline().split(' ') - f.close() + with open(_proc_pid_stat, 'r') as f: + l = f.readline().split(' ') return int(l[13]) except Exception: return int(100*(time.time()-_load_time[0])) @@ -284,8 +279,12 @@ def assert_equal(actual, desired, err_msg='', verbose=True): check that all elements of these objects are equal. An exception is raised at the first conflicting values. + When one of `actual` and `desired` is a scalar and the other is array_like, + the function checks that each element of the array_like object is equal to + the scalar. + This function handles NaN comparisons as if NaN was a "normal" number. - That is, no assertion is raised if both objects have NaNs in the same + That is, AssertionError is not raised if both objects have NaNs in the same positions. This is in contrast to the IEEE standard on NaNs, which says that NaN compared to anything must return False. @@ -374,21 +373,6 @@ def assert_equal(actual, desired, err_msg='', verbose=True): if isscalar(desired) != isscalar(actual): raise AssertionError(msg) - # Inf/nan/negative zero handling - try: - isdesnan = gisnan(desired) - isactnan = gisnan(actual) - if isdesnan and isactnan: - return # both nan, so equal - - # handle signed zero specially for floats - if desired == 0 and actual == 0: - if not signbit(desired) == signbit(actual): - raise AssertionError(msg) - - except (TypeError, ValueError, NotImplementedError): - pass - try: isdesnat = isnat(desired) isactnat = isnat(actual) @@ -404,6 +388,33 @@ def assert_equal(actual, desired, err_msg='', verbose=True): except (TypeError, ValueError, NotImplementedError): pass + # Inf/nan/negative zero handling + try: + isdesnan = gisnan(desired) + isactnan = gisnan(actual) + if isdesnan and isactnan: + return # both nan, so equal + + # handle signed zero specially for floats + array_actual = array(actual) + array_desired = array(desired) + if (array_actual.dtype.char in 'Mm' or + array_desired.dtype.char in 'Mm'): + # version 1.18 + # until this version, gisnan failed for datetime64 and timedelta64. + # Now it succeeds but comparison to scalar with a different type + # emits a DeprecationWarning. + # Avoid that by skipping the next check + raise NotImplementedError('cannot compare to a scalar ' + 'with a different type') + + if desired == 0 and actual == 0: + if not signbit(desired) == signbit(actual): + raise AssertionError(msg) + + except (TypeError, ValueError, NotImplementedError): + pass + try: # Explicitly use __eq__ for comparison, gh-2552 if not (desired == actual): @@ -519,7 +530,8 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): ... AssertionError: Arrays are not almost equal to 9 decimals - Mismatch: 50% + <BLANKLINE> + Mismatched elements: 1 / 2 (50%) Max absolute difference: 6.66669964e-09 Max relative difference: 2.85715698e-09 x: array([1. , 2.333333333]) @@ -841,10 +853,11 @@ def assert_array_equal(x, y, err_msg='', verbose=True): Raises an AssertionError if two array_like objects are not equal. Given two array_like objects, check that the shape is equal and all - elements of these objects are equal. An exception is raised at - shape mismatch or conflicting values. In contrast to the standard usage - in numpy, NaNs are compared like numbers, no assertion is raised if - both objects have NaNs in the same positions. + elements of these objects are equal (but see the Notes for the special + handling of a scalar). An exception is raised at shape mismatch or + conflicting values. In contrast to the standard usage in numpy, NaNs + are compared like numbers, no assertion is raised if both objects have + NaNs in the same positions. The usual caution for verifying equality with floating point numbers is advised. @@ -871,6 +884,12 @@ def assert_array_equal(x, y, err_msg='', verbose=True): relative and/or absolute precision. assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + Notes + ----- + When one of `x` and `y` is a scalar and the other is array_like, the + function checks that each element of the array_like object is equal to + the scalar. + Examples -------- The first assert does not raise an exception: @@ -878,7 +897,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True): >>> np.testing.assert_array_equal([1.0,2.33333,np.nan], ... [np.exp(0),2.33333, np.nan]) - Assert fails with numerical inprecision with floats: + Assert fails with numerical imprecision with floats: >>> np.testing.assert_array_equal([1.0,np.pi,np.nan], ... [1, np.sqrt(np.pi)**2, np.nan]) @@ -886,7 +905,8 @@ def assert_array_equal(x, y, err_msg='', verbose=True): ... AssertionError: Arrays are not equal - Mismatch: 33.3% + <BLANKLINE> + Mismatched elements: 1 / 3 (33.3%) Max absolute difference: 4.4408921e-16 Max relative difference: 1.41357986e-16 x: array([1. , 3.141593, nan]) @@ -899,6 +919,12 @@ def assert_array_equal(x, y, err_msg='', verbose=True): ... [1, np.sqrt(np.pi)**2, np.nan], ... rtol=1e-10, atol=0) + As mentioned in the Notes section, `assert_array_equal` has special + handling for scalars. Here the test checks that each value in `x` is 3: + + >>> x = np.full((2, 5), fill_value=3) + >>> np.testing.assert_array_equal(x, 3) + """ __tracebackhide__ = True # Hide traceback for py.test assert_array_compare(operator.__eq__, x, y, err_msg=err_msg, @@ -963,7 +989,8 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): ... AssertionError: Arrays are not almost equal to 5 decimals - Mismatch: 33.3% + <BLANKLINE> + Mismatched elements: 1 / 3 (33.3%) Max absolute difference: 6.e-05 Max relative difference: 2.57136612e-05 x: array([1. , 2.33333, nan]) @@ -975,6 +1002,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): ... AssertionError: Arrays are not almost equal to 5 decimals + <BLANKLINE> x and y nan location mismatch: x: array([1. , 2.33333, nan]) y: array([1. , 2.33333, 5. ]) @@ -1062,7 +1090,8 @@ def assert_array_less(x, y, err_msg='', verbose=True): ... AssertionError: Arrays are not less-ordered - Mismatch: 33.3% + <BLANKLINE> + Mismatched elements: 1 / 3 (33.3%) Max absolute difference: 1. Max relative difference: 0.5 x: array([ 1., 1., nan]) @@ -1073,7 +1102,8 @@ def assert_array_less(x, y, err_msg='', verbose=True): ... AssertionError: Arrays are not less-ordered - Mismatch: 50% + <BLANKLINE> + Mismatched elements: 1 / 2 (50%) Max absolute difference: 2. Max relative difference: 0.66666667 x: array([1., 4.]) @@ -1084,6 +1114,7 @@ def assert_array_less(x, y, err_msg='', verbose=True): ... AssertionError: Arrays are not less-ordered + <BLANKLINE> (shapes (3,), (1,) mismatch) x: array([1., 2., 3.]) y: array([4]) @@ -1315,14 +1346,7 @@ def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs): """ __tracebackhide__ = True # Hide traceback for py.test - - if sys.version_info.major >= 3: - funcname = _d.assertRaisesRegex - else: - # Only present in Python 2.7, missing from unittest in 2.6 - funcname = _d.assertRaisesRegexp - - return funcname(exception_class, expected_regexp, *args, **kwargs) + return _d.assertRaisesRegex(exception_class, expected_regexp, *args, **kwargs) def decorate_methods(cls, decorator, testmatch=None): @@ -1427,7 +1451,9 @@ def _assert_valid_refcount(op): """ if not HAS_REFCOUNT: return True - import numpy as np, gc + + import gc + import numpy as np b = np.arange(100*100).reshape(100, 100) c = b @@ -1590,6 +1616,12 @@ def assert_array_max_ulp(a, b, maxulp=1, dtype=None): AssertionError If one or more elements differ by more than `maxulp`. + Notes + ----- + For computing the ULP difference, this API does not differentiate between + various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 + is zero. + See Also -------- assert_array_almost_equal_nulp : Compare two arrays relatively to their @@ -1605,8 +1637,9 @@ def assert_array_max_ulp(a, b, maxulp=1, dtype=None): import numpy as np ret = nulp_diff(a, b, dtype) if not np.all(ret <= maxulp): - raise AssertionError("Arrays are not almost equal up to %g ULP" % - maxulp) + raise AssertionError("Arrays are not almost equal up to %g " + "ULP (max difference is %g ULP)" % + (maxulp, np.max(ret))) return ret @@ -1629,6 +1662,12 @@ def nulp_diff(x, y, dtype=None): number of representable floating point numbers between each item in x and y. + Notes + ----- + For computing the ULP difference, this API does not differentiate between + various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 + is zero. + Examples -------- # By definition, epsilon is the smallest number such as 1 + eps != 1, so @@ -1648,8 +1687,11 @@ def nulp_diff(x, y, dtype=None): if np.iscomplexobj(x) or np.iscomplexobj(y): raise NotImplementedError("_nulp not implemented for complex array") - x = np.array(x, dtype=t) - y = np.array(y, dtype=t) + x = np.array([x], dtype=t) + y = np.array([y], dtype=t) + + x[np.isnan(x)] = np.nan + y[np.isnan(y)] = np.nan if not x.shape == y.shape: raise ValueError("x and y do not have the same shape: %s - %s" % @@ -1971,7 +2013,7 @@ class clear_and_catch_warnings(warnings.catch_warnings): mod.__warningregistry__.update(self._warnreg_copies[mod]) -class suppress_warnings(object): +class suppress_warnings: """ Context manager and decorator doing much the same as ``warnings.catch_warnings``. @@ -2186,8 +2228,7 @@ class suppress_warnings(object): del self._filters def _showwarning(self, message, category, filename, lineno, - *args, **kwargs): - use_warnmsg = kwargs.pop("use_warnmsg", None) + *args, use_warnmsg=None, **kwargs): for cat, _, pattern, mod, rec in ( self._suppressions + self._tmp_suppressions)[::-1]: if (issubclass(category, cat) and @@ -2351,3 +2392,118 @@ def break_cycles(): gc.collect() # one more, just to make sure gc.collect() + + +def requires_memory(free_bytes): + """Decorator to skip a test if not enough memory is available""" + import pytest + + def decorator(func): + @wraps(func) + def wrapper(*a, **kw): + msg = check_free_memory(free_bytes) + if msg is not None: + pytest.skip(msg) + + try: + return func(*a, **kw) + except MemoryError: + # Probably ran out of memory regardless: don't regard as failure + pytest.xfail("MemoryError raised") + + return wrapper + + return decorator + + +def check_free_memory(free_bytes): + """ + Check whether `free_bytes` amount of memory is currently free. + Returns: None if enough memory available, otherwise error message + """ + env_var = 'NPY_AVAILABLE_MEM' + env_value = os.environ.get(env_var) + if env_value is not None: + try: + mem_free = _parse_size(env_value) + except ValueError as exc: + raise ValueError('Invalid environment variable {}: {!s}'.format( + env_var, exc)) + + msg = ('{0} GB memory required, but environment variable ' + 'NPY_AVAILABLE_MEM={1} set'.format( + free_bytes/1e9, env_value)) + else: + mem_free = _get_mem_available() + + if mem_free is None: + msg = ("Could not determine available memory; set NPY_AVAILABLE_MEM " + "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run " + "the test.") + mem_free = -1 + else: + msg = '{0} GB memory required, but {1} GB available'.format( + free_bytes/1e9, mem_free/1e9) + + return msg if mem_free < free_bytes else None + + +def _parse_size(size_str): + """Convert memory size strings ('12 GB' etc.) to float""" + suffixes = {'': 1, 'b': 1, + 'k': 1000, 'm': 1000**2, 'g': 1000**3, 't': 1000**4, + 'kb': 1000, 'mb': 1000**2, 'gb': 1000**3, 'tb': 1000**4, + 'kib': 1024, 'mib': 1024**2, 'gib': 1024**3, 'tib': 1024**4} + + size_re = re.compile(r'^\s*(\d+|\d+\.\d+)\s*({0})\s*$'.format( + '|'.join(suffixes.keys())), re.I) + + m = size_re.match(size_str.lower()) + if not m or m.group(2) not in suffixes: + raise ValueError("value {!r} not a valid size".format(size_str)) + return int(float(m.group(1)) * suffixes[m.group(2)]) + + +def _get_mem_available(): + """Return available memory in bytes, or None if unknown.""" + try: + import psutil + return psutil.virtual_memory().available + except (ImportError, AttributeError): + pass + + if sys.platform.startswith('linux'): + info = {} + with open('/proc/meminfo', 'r') as f: + for line in f: + p = line.split() + info[p[0].strip(':').lower()] = int(p[1]) * 1024 + + if 'memavailable' in info: + # Linux >= 3.14 + return info['memavailable'] + else: + return info['memfree'] + info['cached'] + + return None + + +def _no_tracing(func): + """ + Decorator to temporarily turn off tracing for the duration of a test. + Needed in tests that check refcounting, otherwise the tracing itself + influences the refcounts + """ + if not hasattr(sys, 'gettrace'): + return func + else: + @wraps(func) + def wrapper(*args, **kwargs): + original_trace = sys.gettrace() + try: + sys.settrace(None) + return func(*args, **kwargs) + finally: + sys.settrace(original_trace) + return wrapper + |