diff options
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r-- | numpy/testing/_private/utils.py | 131 |
1 files changed, 65 insertions, 66 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 3827b7505..393fedc27 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -17,6 +17,7 @@ from unittest.case import SkipTest from warnings import WarningMessage import pprint +import numpy as np from numpy.core import( intp, float32, empty, arange, array_repr, ndarray, isnat, array) import numpy.linalg.lapack_lite @@ -113,14 +114,14 @@ def gisnan(x): def gisfinite(x): - """like isfinite, but always raise an error if type not supported instead of - returning a TypeError object. + """like isfinite, but always raise an error if type not supported instead + of returning a TypeError object. Notes ----- - isfinite and other ufunc sometimes return a NotImplementedType object instead - of raising any exception. This function is a wrapper to make sure an - exception is always raised. + isfinite and other ufunc sometimes return a NotImplementedType object + instead of raising any exception. This function is a wrapper to make sure + an exception is always raised. This should be removed once this problem is solved at the Ufunc level.""" from numpy.core import isfinite, errstate @@ -160,12 +161,13 @@ if os.name == 'nt': # you should copy this function, but keep the counter open, and call # CollectQueryData() each time you need to know. # See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp (dead link) - # My older explanation for this was that the "AddCounter" process forced - # the CPU to 100%, but the above makes more sense :) + # My older explanation for this was that the "AddCounter" process + # forced the CPU to 100%, but the above makes more sense :) import win32pdh if format is None: format = win32pdh.PDH_FMT_LONG - path = win32pdh.MakeCounterPath( (machine, object, instance, None, inum, counter)) + path = win32pdh.MakeCounterPath( (machine, object, instance, None, + inum, counter)) hq = win32pdh.OpenQuery() try: hc = win32pdh.AddCounter(hq, path) @@ -186,7 +188,7 @@ if os.name == 'nt': win32pdh.PDH_FMT_LONG, None) elif sys.platform[:5] == 'linux': - def memusage(_proc_pid_stat='/proc/%s/stat' % (os.getpid())): + def memusage(_proc_pid_stat=f'/proc/{os.getpid()}/stat'): """ Return virtual memory size in bytes of the running python. @@ -207,8 +209,7 @@ else: if sys.platform[:5] == 'linux': - def jiffies(_proc_pid_stat='/proc/%s/stat' % (os.getpid()), - _load_time=[]): + def jiffies(_proc_pid_stat=f'/proc/{os.getpid()}/stat', _load_time=[]): """ Return number of jiffies elapsed. @@ -263,11 +264,11 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:', try: r = r_func(a) except Exception as exc: - r = '[repr failed for <{}>: {}]'.format(type(a).__name__, exc) + r = f'[repr failed for <{type(a).__name__}>: {exc}]' if r.count('\n') > 3: r = '\n'.join(r.splitlines()[:3]) r += '...' - msg.append(' %s: %s' % (names[i], r)) + msg.append(f' {names[i]}: {r}') return '\n'.join(msg) @@ -329,12 +330,14 @@ def assert_equal(actual, desired, err_msg='', verbose=True): for k, i in desired.items(): if k not in actual: raise AssertionError(repr(k)) - assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg), verbose) + assert_equal(actual[k], desired[k], f'key={k!r}\n{err_msg}', + verbose) return if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)): assert_equal(len(actual), len(desired), err_msg, verbose) for k in range(len(desired)): - assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg), verbose) + assert_equal(actual[k], desired[k], f'item={k!r}\n{err_msg}', + verbose) return from numpy.core import ndarray, isscalar, signbit from numpy.lib import iscomplexobj, real, imag @@ -376,7 +379,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True): try: isdesnat = isnat(desired) isactnat = isnat(actual) - dtypes_match = array(desired).dtype.type == array(actual).dtype.type + dtypes_match = (np.asarray(desired).dtype.type == + np.asarray(actual).dtype.type) if isdesnat and isactnat: # If both are NaT (and have the same dtype -- datetime or # timedelta) they are considered equal. @@ -396,8 +400,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True): return # both nan, so equal # handle signed zero specially for floats - array_actual = array(actual) - array_desired = array(desired) + array_actual = np.asarray(actual) + array_desired = np.asarray(desired) if (array_actual.dtype.char in 'Mm' or array_desired.dtype.char in 'Mm'): # version 1.18 @@ -479,7 +483,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): instead of this function for more consistent floating point comparisons. - The test verifies that the elements of ``actual`` and ``desired`` satisfy. + The test verifies that the elements of `actual` and `desired` satisfy. ``abs(desired-actual) < 1.5 * 10**(-decimal)`` @@ -514,9 +518,9 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): Examples -------- - >>> import numpy.testing as npt - >>> npt.assert_almost_equal(2.3333333333333, 2.33333334) - >>> npt.assert_almost_equal(2.3333333333333, 2.33333334, decimal=10) + >>> from numpy.testing import assert_almost_equal + >>> assert_almost_equal(2.3333333333333, 2.33333334) + >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10) Traceback (most recent call last): ... AssertionError: @@ -524,8 +528,8 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): ACTUAL: 2.3333333333333 DESIRED: 2.33333334 - >>> npt.assert_almost_equal(np.array([1.0,2.3333333333333]), - ... np.array([1.0,2.33333334]), decimal=9) + >>> assert_almost_equal(np.array([1.0,2.3333333333333]), + ... np.array([1.0,2.33333334]), decimal=9) Traceback (most recent call last): ... AssertionError: @@ -694,14 +698,13 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): raise AssertionError(msg) -def assert_array_compare(comparison, x, y, err_msg='', verbose=True, - header='', precision=6, equal_nan=True, - equal_inf=True): +def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', + precision=6, equal_nan=True, equal_inf=True): __tracebackhide__ = True # Hide traceback for py.test from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_ - x = array(x, copy=False, subok=True) - y = array(y, copy=False, subok=True) + x = np.asanyarray(x) + y = np.asanyarray(y) # original array for output formatting ox, oy = x, y @@ -744,7 +747,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, # flag as it everywhere, so we should return the scalar flag. if isinstance(x_id, bool) or x_id.ndim == 0: return bool_(x_id) - elif isinstance(x_id, bool) or y_id.ndim == 0: + elif isinstance(y_id, bool) or y_id.ndim == 0: return bool_(y_id) else: return y_id @@ -754,8 +757,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, if not cond: msg = build_err_msg([x, y], err_msg - + '\n(shapes %s, %s mismatch)' % (x.shape, - y.shape), + + f'\n(shapes {x.shape}, {y.shape} mismatch)', verbose=verbose, header=header, names=('x', 'y'), precision=precision) raise AssertionError(msg) @@ -843,7 +845,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, except ValueError: import traceback efmt = traceback.format_exc() - header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header) + header = f'error during assertion:\n\n{efmt}\n\n{header}' msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header, names=('x', 'y'), precision=precision) @@ -1033,7 +1035,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): # make sure y is an inexact type to avoid abs(MIN_INT); will cause # casting of x later. dtype = result_type(y, 1.) - y = array(y, dtype=dtype, copy=False, subok=True) + y = np.asanyarray(y, dtype) z = abs(x - y) if not issubdtype(z.dtype, number): @@ -1170,7 +1172,8 @@ def assert_string_equal(actual, desired): if desired == actual: return - diff = list(difflib.Differ().compare(actual.splitlines(True), desired.splitlines(True))) + diff = list(difflib.Differ().compare(actual.splitlines(True), + desired.splitlines(True))) diff_list = [] while diff: d1 = diff.pop(0) @@ -1198,7 +1201,7 @@ def assert_string_equal(actual, desired): raise AssertionError(repr(d1)) if not diff_list: return - msg = 'Differences in strings:\n%s' % (''.join(diff_list)).rstrip() + msg = f"Differences in strings:\n{''.join(diff_list).rstrip()}" if actual != desired: raise AssertionError(msg) @@ -1434,9 +1437,7 @@ def measure(code_str, times=1, label=None): frame = sys._getframe(1) locs, globs = frame.f_locals, frame.f_globals - code = compile(code_str, - 'Test name: %s ' % label, - 'exec') + code = compile(code_str, f'Test name: {label} ', 'exec') i = 0 elapsed = jiffies() while i < times: @@ -1525,7 +1526,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True, equal_nan=equal_nan) actual, desired = np.asanyarray(actual), np.asanyarray(desired) - header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol) + header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}' assert_array_compare(compare, actual, desired, err_msg=str(err_msg), verbose=verbose, header=header, equal_nan=equal_nan) @@ -1679,11 +1680,11 @@ def nulp_diff(x, y, dtype=None): """ import numpy as np if dtype: - x = np.array(x, dtype=dtype) - y = np.array(y, dtype=dtype) + x = np.asarray(x, dtype=dtype) + y = np.asarray(y, dtype=dtype) else: - x = np.array(x) - y = np.array(y) + x = np.asarray(x) + y = np.asarray(y) t = np.common_type(x, y) if np.iscomplexobj(x) or np.iscomplexobj(y): @@ -1700,7 +1701,7 @@ def nulp_diff(x, y, dtype=None): (x.shape, y.shape)) def _diff(rx, ry, vdt): - diff = np.array(rx-ry, dtype=vdt) + diff = np.asarray(rx-ry, dtype=vdt) return np.abs(diff) rx = integer_repr(x) @@ -1724,8 +1725,8 @@ def _integer_repr(x, vdt, comp): def integer_repr(x): - """Return the signed-magnitude interpretation of the binary representation of - x.""" + """Return the signed-magnitude interpretation of the binary representation + of x.""" import numpy as np if x.dtype == np.float16: return _integer_repr(x, np.int16, np.int16(-2**15)) @@ -1734,7 +1735,7 @@ def integer_repr(x): elif x.dtype == np.float64: return _integer_repr(x, np.int64, np.int64(-2**63)) else: - raise ValueError("Unsupported dtype %s" % x.dtype) + raise ValueError(f'Unsupported dtype {x.dtype}') @contextlib.contextmanager @@ -1744,7 +1745,7 @@ def _assert_warns_context(warning_class, name=None): l = sup.record(warning_class) yield if not len(l) > 0: - name_str = " when calling %s" % name if name is not None else "" + name_str = f' when calling {name}' if name is not None else '' raise AssertionError("No warning raised" + name_str) @@ -1809,8 +1810,8 @@ def _assert_no_warnings_context(name=None): warnings.simplefilter('always') yield if len(l) > 0: - name_str = " when calling %s" % name if name is not None else "" - raise AssertionError("Got warnings%s: %s" % (name_str, l)) + name_str = f' when calling {name}' if name is not None else '' + raise AssertionError(f'Got warnings{name_str}: {l}') def assert_no_warnings(*args, **kwargs): @@ -2007,7 +2008,7 @@ class clear_and_catch_warnings(warnings.catch_warnings): def __init__(self, record=False, modules=()): self.modules = set(modules).union(self.class_modules) self._warnreg_copies = {} - super(clear_and_catch_warnings, self).__init__(record=record) + super().__init__(record=record) def __enter__(self): for mod in self.modules: @@ -2015,10 +2016,10 @@ class clear_and_catch_warnings(warnings.catch_warnings): mod_reg = mod.__warningregistry__ self._warnreg_copies[mod] = mod_reg.copy() mod_reg.clear() - return super(clear_and_catch_warnings, self).__enter__() + return super().__enter__() def __exit__(self, *exc_info): - super(clear_and_catch_warnings, self).__exit__(*exc_info) + super().__exit__(*exc_info) for mod in self.modules: if hasattr(mod, '__warningregistry__'): mod.__warningregistry__.clear() @@ -2322,8 +2323,8 @@ def _assert_no_gc_cycles_context(name=None): break else: raise RuntimeError( - "Unable to fully collect garbage - perhaps a __del__ method is " - "creating more reference cycles?") + "Unable to fully collect garbage - perhaps a __del__ method " + "is creating more reference cycles?") gc.set_debug(gc.DEBUG_SAVEALL) yield @@ -2337,7 +2338,7 @@ def _assert_no_gc_cycles_context(name=None): gc.enable() if n_objects_in_cycles: - name_str = " when calling %s" % name if name is not None else "" + name_str = f' when calling {name}' if name is not None else '' raise AssertionError( "Reference cycles were found{}: {} objects were collected, " "of which {} are shown below:{}" @@ -2403,7 +2404,8 @@ def break_cycles(): if IS_PYPY: # interpreter runs now, to call deleted objects' __del__ methods gc.collect() - # one more, just to make sure + # two more, just to make sure + gc.collect() gc.collect() @@ -2440,12 +2442,10 @@ def check_free_memory(free_bytes): try: mem_free = _parse_size(env_value) except ValueError as exc: - raise ValueError('Invalid environment variable {}: {!s}'.format( - env_var, exc)) + raise ValueError(f'Invalid environment variable {env_var}: {exc}') - msg = ('{0} GB memory required, but environment variable ' - 'NPY_AVAILABLE_MEM={1} set'.format( - free_bytes/1e9, env_value)) + msg = (f'{free_bytes/1e9} GB memory required, but environment variable ' + f'NPY_AVAILABLE_MEM={env_value} set') else: mem_free = _get_mem_available() @@ -2455,8 +2455,7 @@ def check_free_memory(free_bytes): "the test.") mem_free = -1 else: - msg = '{0} GB memory required, but {1} GB available'.format( - free_bytes/1e9, mem_free/1e9) + msg = f'{free_bytes/1e9} GB memory required, but {mem_free/1e9} GB available' return msg if mem_free < free_bytes else None @@ -2473,7 +2472,7 @@ def _parse_size(size_str): m = size_re.match(size_str.lower()) if not m or m.group(2) not in suffixes: - raise ValueError("value {!r} not a valid size".format(size_str)) + raise ValueError(f'value {size_str!r} not a valid size') return int(float(m.group(1)) * suffixes[m.group(2)]) |