summaryrefslogtreecommitdiff
path: root/numpy/testing
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2018-10-23 12:35:38 +0300
committerGitHub <noreply@github.com>2018-10-23 12:35:38 +0300
commit7fbcc4eaf22a01ac3282179b49c6363485263fbf (patch)
tree336664500a6c426f9ef4cf3ccf624d0e1ce49be2 /numpy/testing
parent2705bd5f9e6c7999e25d42daddd5bd11223472c8 (diff)
parentbe5ea7d92d542e7c7eb055c5831a79850f4bfbee (diff)
downloadnumpy-7fbcc4eaf22a01ac3282179b49c6363485263fbf.tar.gz
Merge pull request #12243 from liwt31/fix_misleading_msg
BUG: Fix misleading assert message in assert_almost_equal #12200
Diffstat (limited to 'numpy/testing')
-rw-r--r--numpy/testing/_private/utils.py9
-rw-r--r--numpy/testing/tests/test_utils.py18
2 files changed, 22 insertions, 5 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index b37dac69d..20a7dfd0b 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -692,6 +692,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
+ # original array for output formating
+ ox, oy = x, y
+
def isnumber(x):
return x.dtype.char in '?bhilqpBHILQPefdgFDG'
@@ -785,10 +788,10 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
# do not trigger a failure (np.ma.masked != True evaluates as
# np.ma.masked, which is falsy).
if cond != True:
- match = 100-100.0*reduced.count(1)/len(reduced)
- msg = build_err_msg([x, y],
+ mismatch = 100.0 * reduced.count(0) / ox.size
+ msg = build_err_msg([ox, oy],
err_msg
- + '\n(mismatch %s%%)' % (match,),
+ + '\n(mismatch %s%%)' % (mismatch,),
verbose=verbose, header=header,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 8099db0b7..e54fbc390 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -507,7 +507,8 @@ class TestAlmostEqual(_GenericTest):
self._test_not_equal(x, z)
def test_error_message(self):
- """Check the message is formatted correctly for the decimal value"""
+ """Check the message is formatted correctly for the decimal value.
+ Also check the message when input includes inf or nan (gh12200)"""
x = np.array([1.00000000001, 2.00000000002, 3.00003])
y = np.array([1.00000000002, 2.00000000003, 3.00004])
@@ -531,6 +532,19 @@ class TestAlmostEqual(_GenericTest):
# remove anything that's not the array string
assert_equal(str(e).split('%)\n ')[1], b)
+ # Check the error message when input includes inf or nan
+ x = np.array([np.inf, 0])
+ y = np.array([np.inf, 1])
+ try:
+ self._assert_func(x, y)
+ except AssertionError as e:
+ msgs = str(e).split('\n')
+ # assert error percentage is 50%
+ assert_equal(msgs[3], '(mismatch 50.0%)')
+ # assert output array contains inf
+ assert_equal(msgs[4], ' x: array([inf, 0.])')
+ assert_equal(msgs[5], ' y: array([inf, 1.])')
+
def test_subclass_that_cannot_be_bool(self):
# While we cannot guarantee testing functions will always work for
# subclasses, the tests should ideally rely only on subclasses having
@@ -1115,7 +1129,7 @@ class TestStringEqual(object):
assert_raises(AssertionError,
lambda: assert_string_equal("foo", "hello"))
-
+
def test_regex(self):
assert_string_equal("a+*b", "a+*b")