diff options
author | Anthony Vo <anthonyhvo12@gmail.com> | 2021-04-05 23:27:23 +0700 |
---|---|---|
committer | Anthony Vo <anthonyhvo12@gmail.com> | 2021-04-05 23:27:23 +0700 |
commit | e4856c1197274a4b57b6ddc0e8ea7d7e4854986d (patch) | |
tree | d2a5dd5209cdd367a953b8c25f625cf94300f464 /numpy/testing/_private/utils.py | |
parent | 2c1410becc7fbe660426e2a946d54304fc470148 (diff) | |
parent | 7bb6a502ebaecd829e3c763e9f90220835e7b733 (diff) | |
download | numpy-e4856c1197274a4b57b6ddc0e8ea7d7e4854986d.tar.gz |
Merge branch 'main' of https://github.com/numpy/numpy into avo-exceptions-chaining
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r-- | numpy/testing/_private/utils.py | 30 |
1 files changed, 16 insertions, 14 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index b4d42728e..1bdb00fd5 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -17,6 +17,7 @@ from unittest.case import SkipTest from warnings import WarningMessage import pprint +import numpy as np from numpy.core import( intp, float32, empty, arange, array_repr, ndarray, isnat, array) import numpy.linalg.lapack_lite @@ -378,7 +379,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True): try: isdesnat = isnat(desired) isactnat = isnat(actual) - dtypes_match = array(desired).dtype.type == array(actual).dtype.type + dtypes_match = (np.asarray(desired).dtype.type == + np.asarray(actual).dtype.type) if isdesnat and isactnat: # If both are NaT (and have the same dtype -- datetime or # timedelta) they are considered equal. @@ -398,8 +400,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True): return # both nan, so equal # handle signed zero specially for floats - array_actual = array(actual) - array_desired = array(desired) + array_actual = np.asarray(actual) + array_desired = np.asarray(desired) if (array_actual.dtype.char in 'Mm' or array_desired.dtype.char in 'Mm'): # version 1.18 @@ -701,8 +703,8 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', __tracebackhide__ = True # Hide traceback for py.test from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_ - x = array(x, copy=False, subok=True) - y = array(y, copy=False, subok=True) + x = np.asanyarray(x) + y = np.asanyarray(y) # original array for output formatting ox, oy = x, y @@ -1033,7 +1035,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): # make sure y is an inexact type to avoid abs(MIN_INT); will cause # casting of x later. dtype = result_type(y, 1.) - y = array(y, dtype=dtype, copy=False, subok=True) + y = np.asanyarray(y, dtype) z = abs(x - y) if not issubdtype(z.dtype, number): @@ -1678,11 +1680,11 @@ def nulp_diff(x, y, dtype=None): """ import numpy as np if dtype: - x = np.array(x, dtype=dtype) - y = np.array(y, dtype=dtype) + x = np.asarray(x, dtype=dtype) + y = np.asarray(y, dtype=dtype) else: - x = np.array(x) - y = np.array(y) + x = np.asarray(x) + y = np.asarray(y) t = np.common_type(x, y) if np.iscomplexobj(x) or np.iscomplexobj(y): @@ -1699,7 +1701,7 @@ def nulp_diff(x, y, dtype=None): (x.shape, y.shape)) def _diff(rx, ry, vdt): - diff = np.array(rx-ry, dtype=vdt) + diff = np.asarray(rx-ry, dtype=vdt) return np.abs(diff) rx = integer_repr(x) @@ -2006,7 +2008,7 @@ class clear_and_catch_warnings(warnings.catch_warnings): def __init__(self, record=False, modules=()): self.modules = set(modules).union(self.class_modules) self._warnreg_copies = {} - super(clear_and_catch_warnings, self).__init__(record=record) + super().__init__(record=record) def __enter__(self): for mod in self.modules: @@ -2014,10 +2016,10 @@ class clear_and_catch_warnings(warnings.catch_warnings): mod_reg = mod.__warningregistry__ self._warnreg_copies[mod] = mod_reg.copy() mod_reg.clear() - return super(clear_and_catch_warnings, self).__enter__() + return super().__enter__() def __exit__(self, *exc_info): - super(clear_and_catch_warnings, self).__exit__(*exc_info) + super().__exit__(*exc_info) for mod in self.modules: if hasattr(mod, '__warningregistry__'): mod.__warningregistry__.clear() |