summaryrefslogtreecommitdiff
path: root/numpy/testing/_private/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r--numpy/testing/_private/utils.py103
1 files changed, 67 insertions, 36 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index 23267a9e1..4097a6738 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -2,8 +2,6 @@
Utility function to facilitate testing.
"""
-from __future__ import division, absolute_import, print_function
-
import os
import sys
import platform
@@ -21,12 +19,9 @@ import pprint
from numpy.core import(
intp, float32, empty, arange, array_repr, ndarray, isnat, array)
-import numpy.__config__
+import numpy.linalg.lapack_lite
-if sys.version_info[0] >= 3:
- from io import StringIO
-else:
- from StringIO import StringIO
+from io import StringIO
__all__ = [
'assert_equal', 'assert_almost_equal', 'assert_approx_equal',
@@ -54,7 +49,7 @@ verbose = 0
IS_PYPY = platform.python_implementation() == 'PyPy'
HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None
-HAS_LAPACK64 = hasattr(numpy.__config__, 'lapack_ilp64_opt_info')
+HAS_LAPACK64 = numpy.linalg.lapack_lite._ilp64
def import_nose():
@@ -197,9 +192,8 @@ elif sys.platform[:5] == 'linux':
"""
try:
- f = open(_proc_pid_stat, 'r')
- l = f.readline().split(' ')
- f.close()
+ with open(_proc_pid_stat, 'r') as f:
+ l = f.readline().split(' ')
return int(l[22])
except Exception:
return
@@ -226,9 +220,8 @@ if sys.platform[:5] == 'linux':
if not _load_time:
_load_time.append(time.time())
try:
- f = open(_proc_pid_stat, 'r')
- l = f.readline().split(' ')
- f.close()
+ with open(_proc_pid_stat, 'r') as f:
+ l = f.readline().split(' ')
return int(l[13])
except Exception:
return int(100*(time.time()-_load_time[0]))
@@ -291,7 +284,7 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
the scalar.
This function handles NaN comparisons as if NaN was a "normal" number.
- That is, no assertion is raised if both objects have NaNs in the same
+ That is, AssertionError is not raised if both objects have NaNs in the same
positions. This is in contrast to the IEEE standard on NaNs, which says
that NaN compared to anything must return False.
@@ -537,7 +530,8 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
...
AssertionError:
Arrays are not almost equal to 9 decimals
- Mismatch: 50%
+ <BLANKLINE>
+ Mismatched elements: 1 / 2 (50%)
Max absolute difference: 6.66669964e-09
Max relative difference: 2.85715698e-09
x: array([1. , 2.333333333])
@@ -911,7 +905,8 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
...
AssertionError:
Arrays are not equal
- Mismatch: 33.3%
+ <BLANKLINE>
+ Mismatched elements: 1 / 3 (33.3%)
Max absolute difference: 4.4408921e-16
Max relative difference: 1.41357986e-16
x: array([1. , 3.141593, nan])
@@ -994,7 +989,8 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
...
AssertionError:
Arrays are not almost equal to 5 decimals
- Mismatch: 33.3%
+ <BLANKLINE>
+ Mismatched elements: 1 / 3 (33.3%)
Max absolute difference: 6.e-05
Max relative difference: 2.57136612e-05
x: array([1. , 2.33333, nan])
@@ -1006,6 +1002,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
...
AssertionError:
Arrays are not almost equal to 5 decimals
+ <BLANKLINE>
x and y nan location mismatch:
x: array([1. , 2.33333, nan])
y: array([1. , 2.33333, 5. ])
@@ -1093,7 +1090,8 @@ def assert_array_less(x, y, err_msg='', verbose=True):
...
AssertionError:
Arrays are not less-ordered
- Mismatch: 33.3%
+ <BLANKLINE>
+ Mismatched elements: 1 / 3 (33.3%)
Max absolute difference: 1.
Max relative difference: 0.5
x: array([ 1., 1., nan])
@@ -1104,7 +1102,8 @@ def assert_array_less(x, y, err_msg='', verbose=True):
...
AssertionError:
Arrays are not less-ordered
- Mismatch: 50%
+ <BLANKLINE>
+ Mismatched elements: 1 / 2 (50%)
Max absolute difference: 2.
Max relative difference: 0.66666667
x: array([1., 4.])
@@ -1115,6 +1114,7 @@ def assert_array_less(x, y, err_msg='', verbose=True):
...
AssertionError:
Arrays are not less-ordered
+ <BLANKLINE>
(shapes (3,), (1,) mismatch)
x: array([1., 2., 3.])
y: array([4])
@@ -1346,14 +1346,7 @@ def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs):
"""
__tracebackhide__ = True # Hide traceback for py.test
-
- if sys.version_info.major >= 3:
- funcname = _d.assertRaisesRegex
- else:
- # Only present in Python 2.7, missing from unittest in 2.6
- funcname = _d.assertRaisesRegexp
-
- return funcname(exception_class, expected_regexp, *args, **kwargs)
+ return _d.assertRaisesRegex(exception_class, expected_regexp, *args, **kwargs)
def decorate_methods(cls, decorator, testmatch=None):
@@ -1458,7 +1451,9 @@ def _assert_valid_refcount(op):
"""
if not HAS_REFCOUNT:
return True
- import numpy as np, gc
+
+ import gc
+ import numpy as np
b = np.arange(100*100).reshape(100, 100)
c = b
@@ -1621,6 +1616,12 @@ def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
AssertionError
If one or more elements differ by more than `maxulp`.
+ Notes
+ -----
+ For computing the ULP difference, this API does not differentiate between
+ various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
+ is zero).
+
See Also
--------
assert_array_almost_equal_nulp : Compare two arrays relatively to their
@@ -1636,8 +1637,9 @@ def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
import numpy as np
ret = nulp_diff(a, b, dtype)
if not np.all(ret <= maxulp):
- raise AssertionError("Arrays are not almost equal up to %g ULP" %
- maxulp)
+ raise AssertionError("Arrays are not almost equal up to %g "
+ "ULP (max difference is %g ULP)" %
+ (maxulp, np.max(ret)))
return ret
@@ -1660,6 +1662,12 @@ def nulp_diff(x, y, dtype=None):
number of representable floating point numbers between each item in x
and y.
+ Notes
+ -----
+ For computing the ULP difference, this API does not differentiate between
+ various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
+ is zero).
+
Examples
--------
# By definition, epsilon is the smallest number such as 1 + eps != 1, so
@@ -1679,8 +1687,11 @@ def nulp_diff(x, y, dtype=None):
if np.iscomplexobj(x) or np.iscomplexobj(y):
raise NotImplementedError("_nulp not implemented for complex array")
- x = np.array(x, dtype=t)
- y = np.array(y, dtype=t)
+ x = np.array([x], dtype=t)
+ y = np.array([y], dtype=t)
+
+ x[np.isnan(x)] = np.nan
+ y[np.isnan(y)] = np.nan
if not x.shape == y.shape:
raise ValueError("x and y do not have the same shape: %s - %s" %
@@ -2002,7 +2013,7 @@ class clear_and_catch_warnings(warnings.catch_warnings):
mod.__warningregistry__.update(self._warnreg_copies[mod])
-class suppress_warnings(object):
+class suppress_warnings:
"""
Context manager and decorator doing much the same as
``warnings.catch_warnings``.
@@ -2217,8 +2228,7 @@ class suppress_warnings(object):
del self._filters
def _showwarning(self, message, category, filename, lineno,
- *args, **kwargs):
- use_warnmsg = kwargs.pop("use_warnmsg", None)
+ *args, use_warnmsg=None, **kwargs):
for cat, _, pattern, mod, rec in (
self._suppressions + self._tmp_suppressions)[::-1]:
if (issubclass(category, cat) and
@@ -2476,3 +2486,24 @@ def _get_mem_available():
return info['memfree'] + info['cached']
return None
+
+
+def _no_tracing(func):
+ """
+ Decorator to temporarily turn off tracing for the duration of a test.
+ Needed in tests that check refcounting, otherwise the tracing itself
+ influences the refcounts
+ """
+ if not hasattr(sys, 'gettrace'):
+ return func
+ else:
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ original_trace = sys.gettrace()
+ try:
+ sys.settrace(None)
+ return func(*args, **kwargs)
+ finally:
+ sys.settrace(original_trace)
+ return wrapper
+