summaryrefslogtreecommitdiff
path: root/numpy/testing/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r--numpy/testing/utils.py28
1 files changed, 18 insertions, 10 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 01c0e769b..930900bf2 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -9,8 +9,9 @@ import sys
import re
import operator
import warnings
+from functools import partial
from .nosetester import import_nose
-from numpy.core import float32, empty, arange
+from numpy.core import float32, empty, arange, array_repr, ndarray
if sys.version_info[0] >= 3:
from io import StringIO
@@ -190,8 +191,7 @@ if os.name=='nt' and sys.version[:3] > '2.3':
win32pdh.PDH_FMT_LONG, None)
def build_err_msg(arrays, err_msg, header='Items are not equal:',
- verbose=True,
- names=('ACTUAL', 'DESIRED')):
+ verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
msg = ['\n' + header]
if err_msg:
if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header):
@@ -200,8 +200,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
msg.append(err_msg)
if verbose:
for i, a in enumerate(arrays):
+
+ if isinstance(a, ndarray):
+ # precision argument is only needed if the objects are ndarrays
+ r_func = partial(array_repr, precision=precision)
+ else:
+ r_func = repr
+
try:
- r = repr(a)
+ r = r_func(a)
except:
r = '[repr failed]'
if r.count('\n') > 3:
@@ -575,7 +582,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
raise AssertionError(msg)
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
- header=''):
+ header='', precision=6):
from numpy.core import array, isnan, isinf, any, all, inf
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -592,7 +599,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
msg = build_err_msg([x, y],
err_msg + '\nx and y %s location mismatch:' \
% (hasval), verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
raise AssertionError(msg)
try:
@@ -603,7 +610,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
+ '\n(shapes %s, %s mismatch)' % (x.shape,
y.shape),
verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
if not cond :
raise AssertionError(msg)
@@ -648,7 +655,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
err_msg
+ '\n(mismatch %s%%)' % (match,),
verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
if not cond :
raise AssertionError(msg)
except ValueError as e:
@@ -657,7 +664,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header)
msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
raise ValueError(msg)
def assert_array_equal(x, y, err_msg='', verbose=True):
@@ -825,7 +832,8 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
return around(z, decimal) <= 10.0**(-decimal)
assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
- header=('Arrays are not almost equal to %d decimals' % decimal))
+ header=('Arrays are not almost equal to %d decimals' % decimal),
+ precision=decimal)
def assert_array_less(x, y, err_msg='', verbose=True):