diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-06-14 14:07:18 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-06-14 14:07:18 -0600 |
commit | 8c78b84968e580f24b3705378fb35705a434cdf1 (patch) | |
tree | c9f82beeb5a2c3f0301f7984d4b6d19539c35d23 /numpy/testing/_private/utils.py | |
parent | 8bf3a4618f1de951c7a4ccdb8bc3e36825a1b744 (diff) | |
parent | 75f852edf94a7293e7982ad516bee314d7187c2d (diff) | |
download | numpy-8c78b84968e580f24b3705378fb35705a434cdf1.tar.gz |
Merge branch 'main' into matrix_rank-doc-fix
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r-- | numpy/testing/_private/utils.py | 48 |
1 files changed, 24 insertions, 24 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index fb33bdcbd..487aa0b4c 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -17,6 +17,7 @@ from unittest.case import SkipTest from warnings import WarningMessage import pprint +import numpy as np from numpy.core import( intp, float32, empty, arange, array_repr, ndarray, isnat, array) import numpy.linalg.lapack_lite @@ -34,8 +35,7 @@ __all__ = [ 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings', 'SkipTest', 'KnownFailureException', 'temppath', 'tempdir', 'IS_PYPY', 'HAS_REFCOUNT', 'suppress_warnings', 'assert_array_compare', - '_assert_valid_refcount', '_gen_alignment_data', 'assert_no_gc_cycles', - 'break_cycles', 'HAS_LAPACK64' + 'assert_no_gc_cycles', 'break_cycles', 'HAS_LAPACK64' ] @@ -378,7 +378,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True): try: isdesnat = isnat(desired) isactnat = isnat(actual) - dtypes_match = array(desired).dtype.type == array(actual).dtype.type + dtypes_match = (np.asarray(desired).dtype.type == + np.asarray(actual).dtype.type) if isdesnat and isactnat: # If both are NaT (and have the same dtype -- datetime or # timedelta) they are considered equal. @@ -398,8 +399,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True): return # both nan, so equal # handle signed zero specially for floats - array_actual = array(actual) - array_desired = array(desired) + array_actual = np.asarray(actual) + array_desired = np.asarray(desired) if (array_actual.dtype.char in 'Mm' or array_desired.dtype.char in 'Mm'): # version 1.18 @@ -481,7 +482,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): instead of this function for more consistent floating point comparisons. - The test verifies that the elements of ``actual`` and ``desired`` satisfy. + The test verifies that the elements of `actual` and `desired` satisfy. ``abs(desired-actual) < 1.5 * 10**(-decimal)`` @@ -516,9 +517,9 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): Examples -------- - >>> import numpy.testing as npt - >>> npt.assert_almost_equal(2.3333333333333, 2.33333334) - >>> npt.assert_almost_equal(2.3333333333333, 2.33333334, decimal=10) + >>> from numpy.testing import assert_almost_equal + >>> assert_almost_equal(2.3333333333333, 2.33333334) + >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10) Traceback (most recent call last): ... AssertionError: @@ -526,8 +527,8 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): ACTUAL: 2.3333333333333 DESIRED: 2.33333334 - >>> npt.assert_almost_equal(np.array([1.0,2.3333333333333]), - ... np.array([1.0,2.33333334]), decimal=9) + >>> assert_almost_equal(np.array([1.0,2.3333333333333]), + ... np.array([1.0,2.33333334]), decimal=9) Traceback (most recent call last): ... AssertionError: @@ -701,8 +702,8 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', __tracebackhide__ = True # Hide traceback for py.test from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_ - x = array(x, copy=False, subok=True) - y = array(y, copy=False, subok=True) + x = np.asanyarray(x) + y = np.asanyarray(y) # original array for output formatting ox, oy = x, y @@ -745,7 +746,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', # flag as it everywhere, so we should return the scalar flag. if isinstance(x_id, bool) or x_id.ndim == 0: return bool_(x_id) - elif isinstance(x_id, bool) or y_id.ndim == 0: + elif isinstance(y_id, bool) or y_id.ndim == 0: return bool_(y_id) else: return y_id @@ -1033,7 +1034,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): # make sure y is an inexact type to avoid abs(MIN_INT); will cause # casting of x later. dtype = result_type(y, 1.) - y = array(y, dtype=dtype, copy=False, subok=True) + y = np.asanyarray(y, dtype) z = abs(x - y) if not issubdtype(z.dtype, number): @@ -1678,11 +1679,11 @@ def nulp_diff(x, y, dtype=None): """ import numpy as np if dtype: - x = np.array(x, dtype=dtype) - y = np.array(y, dtype=dtype) + x = np.asarray(x, dtype=dtype) + y = np.asarray(y, dtype=dtype) else: - x = np.array(x) - y = np.array(y) + x = np.asarray(x) + y = np.asarray(y) t = np.common_type(x, y) if np.iscomplexobj(x) or np.iscomplexobj(y): @@ -1699,7 +1700,7 @@ def nulp_diff(x, y, dtype=None): (x.shape, y.shape)) def _diff(rx, ry, vdt): - diff = np.array(rx-ry, dtype=vdt) + diff = np.asarray(rx-ry, dtype=vdt) return np.abs(diff) rx = integer_repr(x) @@ -2006,7 +2007,7 @@ class clear_and_catch_warnings(warnings.catch_warnings): def __init__(self, record=False, modules=()): self.modules = set(modules).union(self.class_modules) self._warnreg_copies = {} - super(clear_and_catch_warnings, self).__init__(record=record) + super().__init__(record=record) def __enter__(self): for mod in self.modules: @@ -2014,10 +2015,10 @@ class clear_and_catch_warnings(warnings.catch_warnings): mod_reg = mod.__warningregistry__ self._warnreg_copies[mod] = mod_reg.copy() mod_reg.clear() - return super(clear_and_catch_warnings, self).__enter__() + return super().__enter__() def __exit__(self, *exc_info): - super(clear_and_catch_warnings, self).__exit__(*exc_info) + super().__exit__(*exc_info) for mod in self.modules: if hasattr(mod, '__warningregistry__'): mod.__warningregistry__.clear() @@ -2516,4 +2517,3 @@ def _no_tracing(func): finally: sys.settrace(original_trace) return wrapper - |