summaryrefslogtreecommitdiff
path: root/numpy/testing/_private/utils.py
diff options
context:
space:
mode:
authorAnthony Vo <anthonyhvo12@gmail.com>2021-04-05 23:27:23 +0700
committerAnthony Vo <anthonyhvo12@gmail.com>2021-04-05 23:27:23 +0700
commite4856c1197274a4b57b6ddc0e8ea7d7e4854986d (patch)
treed2a5dd5209cdd367a953b8c25f625cf94300f464 /numpy/testing/_private/utils.py
parent2c1410becc7fbe660426e2a946d54304fc470148 (diff)
parent7bb6a502ebaecd829e3c763e9f90220835e7b733 (diff)
downloadnumpy-e4856c1197274a4b57b6ddc0e8ea7d7e4854986d.tar.gz
Merge branch 'main' of https://github.com/numpy/numpy into avo-exceptions-chaining
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r--numpy/testing/_private/utils.py30
1 files changed, 16 insertions, 14 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index b4d42728e..1bdb00fd5 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -17,6 +17,7 @@ from unittest.case import SkipTest
from warnings import WarningMessage
import pprint
+import numpy as np
from numpy.core import(
intp, float32, empty, arange, array_repr, ndarray, isnat, array)
import numpy.linalg.lapack_lite
@@ -378,7 +379,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
try:
isdesnat = isnat(desired)
isactnat = isnat(actual)
- dtypes_match = array(desired).dtype.type == array(actual).dtype.type
+ dtypes_match = (np.asarray(desired).dtype.type ==
+ np.asarray(actual).dtype.type)
if isdesnat and isactnat:
# If both are NaT (and have the same dtype -- datetime or
# timedelta) they are considered equal.
@@ -398,8 +400,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
return # both nan, so equal
# handle signed zero specially for floats
- array_actual = array(actual)
- array_desired = array(desired)
+ array_actual = np.asarray(actual)
+ array_desired = np.asarray(desired)
if (array_actual.dtype.char in 'Mm' or
array_desired.dtype.char in 'Mm'):
# version 1.18
@@ -701,8 +703,8 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
__tracebackhide__ = True # Hide traceback for py.test
from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_
- x = array(x, copy=False, subok=True)
- y = array(y, copy=False, subok=True)
+ x = np.asanyarray(x)
+ y = np.asanyarray(y)
# original array for output formatting
ox, oy = x, y
@@ -1033,7 +1035,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
# make sure y is an inexact type to avoid abs(MIN_INT); will cause
# casting of x later.
dtype = result_type(y, 1.)
- y = array(y, dtype=dtype, copy=False, subok=True)
+ y = np.asanyarray(y, dtype)
z = abs(x - y)
if not issubdtype(z.dtype, number):
@@ -1678,11 +1680,11 @@ def nulp_diff(x, y, dtype=None):
"""
import numpy as np
if dtype:
- x = np.array(x, dtype=dtype)
- y = np.array(y, dtype=dtype)
+ x = np.asarray(x, dtype=dtype)
+ y = np.asarray(y, dtype=dtype)
else:
- x = np.array(x)
- y = np.array(y)
+ x = np.asarray(x)
+ y = np.asarray(y)
t = np.common_type(x, y)
if np.iscomplexobj(x) or np.iscomplexobj(y):
@@ -1699,7 +1701,7 @@ def nulp_diff(x, y, dtype=None):
(x.shape, y.shape))
def _diff(rx, ry, vdt):
- diff = np.array(rx-ry, dtype=vdt)
+ diff = np.asarray(rx-ry, dtype=vdt)
return np.abs(diff)
rx = integer_repr(x)
@@ -2006,7 +2008,7 @@ class clear_and_catch_warnings(warnings.catch_warnings):
def __init__(self, record=False, modules=()):
self.modules = set(modules).union(self.class_modules)
self._warnreg_copies = {}
- super(clear_and_catch_warnings, self).__init__(record=record)
+ super().__init__(record=record)
def __enter__(self):
for mod in self.modules:
@@ -2014,10 +2016,10 @@ class clear_and_catch_warnings(warnings.catch_warnings):
mod_reg = mod.__warningregistry__
self._warnreg_copies[mod] = mod_reg.copy()
mod_reg.clear()
- return super(clear_and_catch_warnings, self).__enter__()
+ return super().__enter__()
def __exit__(self, *exc_info):
- super(clear_and_catch_warnings, self).__exit__(*exc_info)
+ super().__exit__(*exc_info)
for mod in self.modules:
if hasattr(mod, '__warningregistry__'):
mod.__warningregistry__.clear()