summaryrefslogtreecommitdiff
path: root/numpy/testing/_private/utils.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-06-14 14:07:18 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-06-14 14:07:18 -0600
commit8c78b84968e580f24b3705378fb35705a434cdf1 (patch)
treec9f82beeb5a2c3f0301f7984d4b6d19539c35d23 /numpy/testing/_private/utils.py
parent8bf3a4618f1de951c7a4ccdb8bc3e36825a1b744 (diff)
parent75f852edf94a7293e7982ad516bee314d7187c2d (diff)
downloadnumpy-8c78b84968e580f24b3705378fb35705a434cdf1.tar.gz
Merge branch 'main' into matrix_rank-doc-fix
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r--numpy/testing/_private/utils.py48
1 files changed, 24 insertions, 24 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index fb33bdcbd..487aa0b4c 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -17,6 +17,7 @@ from unittest.case import SkipTest
from warnings import WarningMessage
import pprint
+import numpy as np
from numpy.core import(
intp, float32, empty, arange, array_repr, ndarray, isnat, array)
import numpy.linalg.lapack_lite
@@ -34,8 +35,7 @@ __all__ = [
'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings',
'SkipTest', 'KnownFailureException', 'temppath', 'tempdir', 'IS_PYPY',
'HAS_REFCOUNT', 'suppress_warnings', 'assert_array_compare',
- '_assert_valid_refcount', '_gen_alignment_data', 'assert_no_gc_cycles',
- 'break_cycles', 'HAS_LAPACK64'
+ 'assert_no_gc_cycles', 'break_cycles', 'HAS_LAPACK64'
]
@@ -378,7 +378,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
try:
isdesnat = isnat(desired)
isactnat = isnat(actual)
- dtypes_match = array(desired).dtype.type == array(actual).dtype.type
+ dtypes_match = (np.asarray(desired).dtype.type ==
+ np.asarray(actual).dtype.type)
if isdesnat and isactnat:
# If both are NaT (and have the same dtype -- datetime or
# timedelta) they are considered equal.
@@ -398,8 +399,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
return # both nan, so equal
# handle signed zero specially for floats
- array_actual = array(actual)
- array_desired = array(desired)
+ array_actual = np.asarray(actual)
+ array_desired = np.asarray(desired)
if (array_actual.dtype.char in 'Mm' or
array_desired.dtype.char in 'Mm'):
# version 1.18
@@ -481,7 +482,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
instead of this function for more consistent floating point
comparisons.
- The test verifies that the elements of ``actual`` and ``desired`` satisfy.
+ The test verifies that the elements of `actual` and `desired` satisfy.
``abs(desired-actual) < 1.5 * 10**(-decimal)``
@@ -516,9 +517,9 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
Examples
--------
- >>> import numpy.testing as npt
- >>> npt.assert_almost_equal(2.3333333333333, 2.33333334)
- >>> npt.assert_almost_equal(2.3333333333333, 2.33333334, decimal=10)
+ >>> from numpy.testing import assert_almost_equal
+ >>> assert_almost_equal(2.3333333333333, 2.33333334)
+ >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10)
Traceback (most recent call last):
...
AssertionError:
@@ -526,8 +527,8 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
ACTUAL: 2.3333333333333
DESIRED: 2.33333334
- >>> npt.assert_almost_equal(np.array([1.0,2.3333333333333]),
- ... np.array([1.0,2.33333334]), decimal=9)
+ >>> assert_almost_equal(np.array([1.0,2.3333333333333]),
+ ... np.array([1.0,2.33333334]), decimal=9)
Traceback (most recent call last):
...
AssertionError:
@@ -701,8 +702,8 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
__tracebackhide__ = True # Hide traceback for py.test
from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_
- x = array(x, copy=False, subok=True)
- y = array(y, copy=False, subok=True)
+ x = np.asanyarray(x)
+ y = np.asanyarray(y)
# original array for output formatting
ox, oy = x, y
@@ -745,7 +746,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
# flag as it everywhere, so we should return the scalar flag.
if isinstance(x_id, bool) or x_id.ndim == 0:
return bool_(x_id)
- elif isinstance(x_id, bool) or y_id.ndim == 0:
+ elif isinstance(y_id, bool) or y_id.ndim == 0:
return bool_(y_id)
else:
return y_id
@@ -1033,7 +1034,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
# make sure y is an inexact type to avoid abs(MIN_INT); will cause
# casting of x later.
dtype = result_type(y, 1.)
- y = array(y, dtype=dtype, copy=False, subok=True)
+ y = np.asanyarray(y, dtype)
z = abs(x - y)
if not issubdtype(z.dtype, number):
@@ -1678,11 +1679,11 @@ def nulp_diff(x, y, dtype=None):
"""
import numpy as np
if dtype:
- x = np.array(x, dtype=dtype)
- y = np.array(y, dtype=dtype)
+ x = np.asarray(x, dtype=dtype)
+ y = np.asarray(y, dtype=dtype)
else:
- x = np.array(x)
- y = np.array(y)
+ x = np.asarray(x)
+ y = np.asarray(y)
t = np.common_type(x, y)
if np.iscomplexobj(x) or np.iscomplexobj(y):
@@ -1699,7 +1700,7 @@ def nulp_diff(x, y, dtype=None):
(x.shape, y.shape))
def _diff(rx, ry, vdt):
- diff = np.array(rx-ry, dtype=vdt)
+ diff = np.asarray(rx-ry, dtype=vdt)
return np.abs(diff)
rx = integer_repr(x)
@@ -2006,7 +2007,7 @@ class clear_and_catch_warnings(warnings.catch_warnings):
def __init__(self, record=False, modules=()):
self.modules = set(modules).union(self.class_modules)
self._warnreg_copies = {}
- super(clear_and_catch_warnings, self).__init__(record=record)
+ super().__init__(record=record)
def __enter__(self):
for mod in self.modules:
@@ -2014,10 +2015,10 @@ class clear_and_catch_warnings(warnings.catch_warnings):
mod_reg = mod.__warningregistry__
self._warnreg_copies[mod] = mod_reg.copy()
mod_reg.clear()
- return super(clear_and_catch_warnings, self).__enter__()
+ return super().__enter__()
def __exit__(self, *exc_info):
- super(clear_and_catch_warnings, self).__exit__(*exc_info)
+ super().__exit__(*exc_info)
for mod in self.modules:
if hasattr(mod, '__warningregistry__'):
mod.__warningregistry__.clear()
@@ -2516,4 +2517,3 @@ def _no_tracing(func):
finally:
sys.settrace(original_trace)
return wrapper
-