summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/testing/tests/test_utils.py68
-rw-r--r--numpy/testing/utils.py28
2 files changed, 86 insertions, 10 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 5956a4294..aa0a2669f 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -134,6 +134,49 @@ class TestArrayEqual(_GenericTest, unittest.TestCase):
self._test_not_equal(c, b)
+class TestBuildErrorMessage(unittest.TestCase):
+ def test_build_err_msg_defaults(self):
+ x = np.array([1.00001, 2.00002, 3.00003])
+ y = np.array([1.00002, 2.00003, 3.00004])
+ err_msg = 'There is a mismatch'
+
+ a = build_err_msg([x, y], err_msg)
+ b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ '
+ '1.00001, 2.00002, 3.00003])\n DESIRED: array([ 1.00002, '
+ '2.00003, 3.00004])')
+ self.assertEqual(a, b)
+
+ def test_build_err_msg_no_verbose(self):
+ x = np.array([1.00001, 2.00002, 3.00003])
+ y = np.array([1.00002, 2.00003, 3.00004])
+ err_msg = 'There is a mismatch'
+
+ a = build_err_msg([x, y], err_msg, verbose=False)
+ b = '\nItems are not equal: There is a mismatch'
+ self.assertEqual(a, b)
+
+ def test_build_err_msg_custom_names(self):
+ x = np.array([1.00001, 2.00002, 3.00003])
+ y = np.array([1.00002, 2.00003, 3.00004])
+ err_msg = 'There is a mismatch'
+
+ a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR'))
+ b = ('\nItems are not equal: There is a mismatch\n FOO: array([ '
+ '1.00001, 2.00002, 3.00003])\n BAR: array([ 1.00002, 2.00003, '
+ '3.00004])')
+ self.assertEqual(a, b)
+
+ def test_build_err_msg_custom_precision(self):
+ x = np.array([1.000000001, 2.00002, 3.00003])
+ y = np.array([1.000000002, 2.00003, 3.00004])
+ err_msg = 'There is a mismatch'
+
+ a = build_err_msg([x, y], err_msg, precision=10)
+ b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ '
+ '1.000000001, 2.00002 , 3.00003 ])\n DESIRED: array([ '
+ '1.000000002, 2.00003 , 3.00004 ])')
+ self.assertEqual(a, b)
+
class TestEqual(TestArrayEqual):
def setUp(self):
self._assert_func = assert_equal
@@ -239,6 +282,31 @@ class TestAlmostEqual(_GenericTest, unittest.TestCase):
self._test_not_equal(x, y)
self._test_not_equal(x, z)
+ def test_error_message(self):
+ """Check the message is formatted correctly for the decimal value"""
+ x = np.array([1.00000000001, 2.00000000002, 3.00003])
+ y = np.array([1.00000000002, 2.00000000003, 3.00004])
+
+ # test with a different amount of decimal digits
+ # note that we only check for the formatting of the arrays themselves
+ b = ('x: array([ 1.00000000001, 2.00000000002, 3.00003 '
+ ' ])\n y: array([ 1.00000000002, 2.00000000003, 3.00004 ])')
+ try:
+ self._assert_func(x, y, decimal=12)
+ except AssertionError as e:
+ # remove anything that's not the array string
+ self.assertEqual(str(e).split('%)\n ')[1], b)
+
+ # with the default value of decimal digits, only the 3rd element differs
+ # note that we only check for the formatting of the arrays themselves
+ b = ('x: array([ 1. , 2. , 3.00003])\n y: array([ 1. , '
+ '2. , 3.00004])')
+ try:
+ self._assert_func(x, y)
+ except AssertionError as e:
+ # remove anything that's not the array string
+ self.assertEqual(str(e).split('%)\n ')[1], b)
+
class TestApproxEqual(unittest.TestCase):
def setUp(self):
self._assert_func = assert_approx_equal
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 01c0e769b..930900bf2 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -9,8 +9,9 @@ import sys
import re
import operator
import warnings
+from functools import partial
from .nosetester import import_nose
-from numpy.core import float32, empty, arange
+from numpy.core import float32, empty, arange, array_repr, ndarray
if sys.version_info[0] >= 3:
from io import StringIO
@@ -190,8 +191,7 @@ if os.name=='nt' and sys.version[:3] > '2.3':
win32pdh.PDH_FMT_LONG, None)
def build_err_msg(arrays, err_msg, header='Items are not equal:',
- verbose=True,
- names=('ACTUAL', 'DESIRED')):
+ verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
msg = ['\n' + header]
if err_msg:
if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header):
@@ -200,8 +200,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
msg.append(err_msg)
if verbose:
for i, a in enumerate(arrays):
+
+ if isinstance(a, ndarray):
+ # precision argument is only needed if the objects are ndarrays
+ r_func = partial(array_repr, precision=precision)
+ else:
+ r_func = repr
+
try:
- r = repr(a)
+ r = r_func(a)
except:
r = '[repr failed]'
if r.count('\n') > 3:
@@ -575,7 +582,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
raise AssertionError(msg)
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
- header=''):
+ header='', precision=6):
from numpy.core import array, isnan, isinf, any, all, inf
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -592,7 +599,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
msg = build_err_msg([x, y],
err_msg + '\nx and y %s location mismatch:' \
% (hasval), verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
raise AssertionError(msg)
try:
@@ -603,7 +610,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
+ '\n(shapes %s, %s mismatch)' % (x.shape,
y.shape),
verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
if not cond :
raise AssertionError(msg)
@@ -648,7 +655,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
err_msg
+ '\n(mismatch %s%%)' % (match,),
verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
if not cond :
raise AssertionError(msg)
except ValueError as e:
@@ -657,7 +664,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header)
msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
- names=('x', 'y'))
+ names=('x', 'y'), precision=precision)
raise ValueError(msg)
def assert_array_equal(x, y, err_msg='', verbose=True):
@@ -825,7 +832,8 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
return around(z, decimal) <= 10.0**(-decimal)
assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
- header=('Arrays are not almost equal to %d decimals' % decimal))
+ header=('Arrays are not almost equal to %d decimals' % decimal),
+ precision=decimal)
def assert_array_less(x, y, err_msg='', verbose=True):