summaryrefslogtreecommitdiff
path: root/numpy/testing/_private/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r--numpy/testing/_private/utils.py131
1 files changed, 65 insertions, 66 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index 3827b7505..393fedc27 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -17,6 +17,7 @@ from unittest.case import SkipTest
from warnings import WarningMessage
import pprint
+import numpy as np
from numpy.core import(
intp, float32, empty, arange, array_repr, ndarray, isnat, array)
import numpy.linalg.lapack_lite
@@ -113,14 +114,14 @@ def gisnan(x):
def gisfinite(x):
- """like isfinite, but always raise an error if type not supported instead of
- returning a TypeError object.
+ """like isfinite, but always raise an error if type not supported instead
+ of returning a TypeError object.
Notes
-----
- isfinite and other ufunc sometimes return a NotImplementedType object instead
- of raising any exception. This function is a wrapper to make sure an
- exception is always raised.
+ isfinite and other ufunc sometimes return a NotImplementedType object
+ instead of raising any exception. This function is a wrapper to make sure
+ an exception is always raised.
This should be removed once this problem is solved at the Ufunc level."""
from numpy.core import isfinite, errstate
@@ -160,12 +161,13 @@ if os.name == 'nt':
# you should copy this function, but keep the counter open, and call
# CollectQueryData() each time you need to know.
# See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp (dead link)
- # My older explanation for this was that the "AddCounter" process forced
- # the CPU to 100%, but the above makes more sense :)
+ # My older explanation for this was that the "AddCounter" process
+ # forced the CPU to 100%, but the above makes more sense :)
import win32pdh
if format is None:
format = win32pdh.PDH_FMT_LONG
- path = win32pdh.MakeCounterPath( (machine, object, instance, None, inum, counter))
+ path = win32pdh.MakeCounterPath( (machine, object, instance, None,
+ inum, counter))
hq = win32pdh.OpenQuery()
try:
hc = win32pdh.AddCounter(hq, path)
@@ -186,7 +188,7 @@ if os.name == 'nt':
win32pdh.PDH_FMT_LONG, None)
elif sys.platform[:5] == 'linux':
- def memusage(_proc_pid_stat='/proc/%s/stat' % (os.getpid())):
+ def memusage(_proc_pid_stat=f'/proc/{os.getpid()}/stat'):
"""
Return virtual memory size in bytes of the running python.
@@ -207,8 +209,7 @@ else:
if sys.platform[:5] == 'linux':
- def jiffies(_proc_pid_stat='/proc/%s/stat' % (os.getpid()),
- _load_time=[]):
+ def jiffies(_proc_pid_stat=f'/proc/{os.getpid()}/stat', _load_time=[]):
"""
Return number of jiffies elapsed.
@@ -263,11 +264,11 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
try:
r = r_func(a)
except Exception as exc:
- r = '[repr failed for <{}>: {}]'.format(type(a).__name__, exc)
+ r = f'[repr failed for <{type(a).__name__}>: {exc}]'
if r.count('\n') > 3:
r = '\n'.join(r.splitlines()[:3])
r += '...'
- msg.append(' %s: %s' % (names[i], r))
+ msg.append(f' {names[i]}: {r}')
return '\n'.join(msg)
@@ -329,12 +330,14 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
for k, i in desired.items():
if k not in actual:
raise AssertionError(repr(k))
- assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg), verbose)
+ assert_equal(actual[k], desired[k], f'key={k!r}\n{err_msg}',
+ verbose)
return
if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
assert_equal(len(actual), len(desired), err_msg, verbose)
for k in range(len(desired)):
- assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg), verbose)
+ assert_equal(actual[k], desired[k], f'item={k!r}\n{err_msg}',
+ verbose)
return
from numpy.core import ndarray, isscalar, signbit
from numpy.lib import iscomplexobj, real, imag
@@ -376,7 +379,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
try:
isdesnat = isnat(desired)
isactnat = isnat(actual)
- dtypes_match = array(desired).dtype.type == array(actual).dtype.type
+ dtypes_match = (np.asarray(desired).dtype.type ==
+ np.asarray(actual).dtype.type)
if isdesnat and isactnat:
# If both are NaT (and have the same dtype -- datetime or
# timedelta) they are considered equal.
@@ -396,8 +400,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
return # both nan, so equal
# handle signed zero specially for floats
- array_actual = array(actual)
- array_desired = array(desired)
+ array_actual = np.asarray(actual)
+ array_desired = np.asarray(desired)
if (array_actual.dtype.char in 'Mm' or
array_desired.dtype.char in 'Mm'):
# version 1.18
@@ -479,7 +483,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
instead of this function for more consistent floating point
comparisons.
- The test verifies that the elements of ``actual`` and ``desired`` satisfy.
+ The test verifies that the elements of `actual` and `desired` satisfy.
``abs(desired-actual) < 1.5 * 10**(-decimal)``
@@ -514,9 +518,9 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
Examples
--------
- >>> import numpy.testing as npt
- >>> npt.assert_almost_equal(2.3333333333333, 2.33333334)
- >>> npt.assert_almost_equal(2.3333333333333, 2.33333334, decimal=10)
+ >>> from numpy.testing import assert_almost_equal
+ >>> assert_almost_equal(2.3333333333333, 2.33333334)
+ >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10)
Traceback (most recent call last):
...
AssertionError:
@@ -524,8 +528,8 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
ACTUAL: 2.3333333333333
DESIRED: 2.33333334
- >>> npt.assert_almost_equal(np.array([1.0,2.3333333333333]),
- ... np.array([1.0,2.33333334]), decimal=9)
+ >>> assert_almost_equal(np.array([1.0,2.3333333333333]),
+ ... np.array([1.0,2.33333334]), decimal=9)
Traceback (most recent call last):
...
AssertionError:
@@ -694,14 +698,13 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
raise AssertionError(msg)
-def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
- header='', precision=6, equal_nan=True,
- equal_inf=True):
+def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
+ precision=6, equal_nan=True, equal_inf=True):
__tracebackhide__ = True # Hide traceback for py.test
from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_
- x = array(x, copy=False, subok=True)
- y = array(y, copy=False, subok=True)
+ x = np.asanyarray(x)
+ y = np.asanyarray(y)
# original array for output formatting
ox, oy = x, y
@@ -744,7 +747,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
# flag as it everywhere, so we should return the scalar flag.
if isinstance(x_id, bool) or x_id.ndim == 0:
return bool_(x_id)
- elif isinstance(x_id, bool) or y_id.ndim == 0:
+ elif isinstance(y_id, bool) or y_id.ndim == 0:
return bool_(y_id)
else:
return y_id
@@ -754,8 +757,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
if not cond:
msg = build_err_msg([x, y],
err_msg
- + '\n(shapes %s, %s mismatch)' % (x.shape,
- y.shape),
+ + f'\n(shapes {x.shape}, {y.shape} mismatch)',
verbose=verbose, header=header,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)
@@ -843,7 +845,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
except ValueError:
import traceback
efmt = traceback.format_exc()
- header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header)
+ header = f'error during assertion:\n\n{efmt}\n\n{header}'
msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
names=('x', 'y'), precision=precision)
@@ -1033,7 +1035,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
# make sure y is an inexact type to avoid abs(MIN_INT); will cause
# casting of x later.
dtype = result_type(y, 1.)
- y = array(y, dtype=dtype, copy=False, subok=True)
+ y = np.asanyarray(y, dtype)
z = abs(x - y)
if not issubdtype(z.dtype, number):
@@ -1170,7 +1172,8 @@ def assert_string_equal(actual, desired):
if desired == actual:
return
- diff = list(difflib.Differ().compare(actual.splitlines(True), desired.splitlines(True)))
+ diff = list(difflib.Differ().compare(actual.splitlines(True),
+ desired.splitlines(True)))
diff_list = []
while diff:
d1 = diff.pop(0)
@@ -1198,7 +1201,7 @@ def assert_string_equal(actual, desired):
raise AssertionError(repr(d1))
if not diff_list:
return
- msg = 'Differences in strings:\n%s' % (''.join(diff_list)).rstrip()
+ msg = f"Differences in strings:\n{''.join(diff_list).rstrip()}"
if actual != desired:
raise AssertionError(msg)
@@ -1434,9 +1437,7 @@ def measure(code_str, times=1, label=None):
frame = sys._getframe(1)
locs, globs = frame.f_locals, frame.f_globals
- code = compile(code_str,
- 'Test name: %s ' % label,
- 'exec')
+ code = compile(code_str, f'Test name: {label} ', 'exec')
i = 0
elapsed = jiffies()
while i < times:
@@ -1525,7 +1526,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
equal_nan=equal_nan)
actual, desired = np.asanyarray(actual), np.asanyarray(desired)
- header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
+ header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}'
assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
verbose=verbose, header=header, equal_nan=equal_nan)
@@ -1679,11 +1680,11 @@ def nulp_diff(x, y, dtype=None):
"""
import numpy as np
if dtype:
- x = np.array(x, dtype=dtype)
- y = np.array(y, dtype=dtype)
+ x = np.asarray(x, dtype=dtype)
+ y = np.asarray(y, dtype=dtype)
else:
- x = np.array(x)
- y = np.array(y)
+ x = np.asarray(x)
+ y = np.asarray(y)
t = np.common_type(x, y)
if np.iscomplexobj(x) or np.iscomplexobj(y):
@@ -1700,7 +1701,7 @@ def nulp_diff(x, y, dtype=None):
(x.shape, y.shape))
def _diff(rx, ry, vdt):
- diff = np.array(rx-ry, dtype=vdt)
+ diff = np.asarray(rx-ry, dtype=vdt)
return np.abs(diff)
rx = integer_repr(x)
@@ -1724,8 +1725,8 @@ def _integer_repr(x, vdt, comp):
def integer_repr(x):
- """Return the signed-magnitude interpretation of the binary representation of
- x."""
+ """Return the signed-magnitude interpretation of the binary representation
+ of x."""
import numpy as np
if x.dtype == np.float16:
return _integer_repr(x, np.int16, np.int16(-2**15))
@@ -1734,7 +1735,7 @@ def integer_repr(x):
elif x.dtype == np.float64:
return _integer_repr(x, np.int64, np.int64(-2**63))
else:
- raise ValueError("Unsupported dtype %s" % x.dtype)
+ raise ValueError(f'Unsupported dtype {x.dtype}')
@contextlib.contextmanager
@@ -1744,7 +1745,7 @@ def _assert_warns_context(warning_class, name=None):
l = sup.record(warning_class)
yield
if not len(l) > 0:
- name_str = " when calling %s" % name if name is not None else ""
+ name_str = f' when calling {name}' if name is not None else ''
raise AssertionError("No warning raised" + name_str)
@@ -1809,8 +1810,8 @@ def _assert_no_warnings_context(name=None):
warnings.simplefilter('always')
yield
if len(l) > 0:
- name_str = " when calling %s" % name if name is not None else ""
- raise AssertionError("Got warnings%s: %s" % (name_str, l))
+ name_str = f' when calling {name}' if name is not None else ''
+ raise AssertionError(f'Got warnings{name_str}: {l}')
def assert_no_warnings(*args, **kwargs):
@@ -2007,7 +2008,7 @@ class clear_and_catch_warnings(warnings.catch_warnings):
def __init__(self, record=False, modules=()):
self.modules = set(modules).union(self.class_modules)
self._warnreg_copies = {}
- super(clear_and_catch_warnings, self).__init__(record=record)
+ super().__init__(record=record)
def __enter__(self):
for mod in self.modules:
@@ -2015,10 +2016,10 @@ class clear_and_catch_warnings(warnings.catch_warnings):
mod_reg = mod.__warningregistry__
self._warnreg_copies[mod] = mod_reg.copy()
mod_reg.clear()
- return super(clear_and_catch_warnings, self).__enter__()
+ return super().__enter__()
def __exit__(self, *exc_info):
- super(clear_and_catch_warnings, self).__exit__(*exc_info)
+ super().__exit__(*exc_info)
for mod in self.modules:
if hasattr(mod, '__warningregistry__'):
mod.__warningregistry__.clear()
@@ -2322,8 +2323,8 @@ def _assert_no_gc_cycles_context(name=None):
break
else:
raise RuntimeError(
- "Unable to fully collect garbage - perhaps a __del__ method is "
- "creating more reference cycles?")
+ "Unable to fully collect garbage - perhaps a __del__ method "
+ "is creating more reference cycles?")
gc.set_debug(gc.DEBUG_SAVEALL)
yield
@@ -2337,7 +2338,7 @@ def _assert_no_gc_cycles_context(name=None):
gc.enable()
if n_objects_in_cycles:
- name_str = " when calling %s" % name if name is not None else ""
+ name_str = f' when calling {name}' if name is not None else ''
raise AssertionError(
"Reference cycles were found{}: {} objects were collected, "
"of which {} are shown below:{}"
@@ -2403,7 +2404,8 @@ def break_cycles():
if IS_PYPY:
# interpreter runs now, to call deleted objects' __del__ methods
gc.collect()
- # one more, just to make sure
+ # two more, just to make sure
+ gc.collect()
gc.collect()
@@ -2440,12 +2442,10 @@ def check_free_memory(free_bytes):
try:
mem_free = _parse_size(env_value)
except ValueError as exc:
- raise ValueError('Invalid environment variable {}: {!s}'.format(
- env_var, exc))
+ raise ValueError(f'Invalid environment variable {env_var}: {exc}')
- msg = ('{0} GB memory required, but environment variable '
- 'NPY_AVAILABLE_MEM={1} set'.format(
- free_bytes/1e9, env_value))
+ msg = (f'{free_bytes/1e9} GB memory required, but environment variable '
+ f'NPY_AVAILABLE_MEM={env_value} set')
else:
mem_free = _get_mem_available()
@@ -2455,8 +2455,7 @@ def check_free_memory(free_bytes):
"the test.")
mem_free = -1
else:
- msg = '{0} GB memory required, but {1} GB available'.format(
- free_bytes/1e9, mem_free/1e9)
+ msg = f'{free_bytes/1e9} GB memory required, but {mem_free/1e9} GB available'
return msg if mem_free < free_bytes else None
@@ -2473,7 +2472,7 @@ def _parse_size(size_str):
m = size_re.match(size_str.lower())
if not m or m.group(2) not in suffixes:
- raise ValueError("value {!r} not a valid size".format(size_str))
+ raise ValueError(f'value {size_str!r} not a valid size')
return int(float(m.group(1)) * suffixes[m.group(2)])