diff options
Diffstat (limited to 'numpy/testing')
-rw-r--r-- | numpy/testing/_private/utils.py | 9 | ||||
-rw-r--r-- | numpy/testing/tests/test_utils.py | 18 |
2 files changed, 22 insertions, 5 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index b37dac69d..20a7dfd0b 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -692,6 +692,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) + # original array for output formating + ox, oy = x, y + def isnumber(x): return x.dtype.char in '?bhilqpBHILQPefdgFDG' @@ -785,10 +788,10 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, # do not trigger a failure (np.ma.masked != True evaluates as # np.ma.masked, which is falsy). if cond != True: - match = 100-100.0*reduced.count(1)/len(reduced) - msg = build_err_msg([x, y], + mismatch = 100.0 * reduced.count(0) / ox.size + msg = build_err_msg([ox, oy], err_msg - + '\n(mismatch %s%%)' % (match,), + + '\n(mismatch %s%%)' % (mismatch,), verbose=verbose, header=header, names=('x', 'y'), precision=precision) raise AssertionError(msg) diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 8099db0b7..e54fbc390 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -507,7 +507,8 @@ class TestAlmostEqual(_GenericTest): self._test_not_equal(x, z) def test_error_message(self): - """Check the message is formatted correctly for the decimal value""" + """Check the message is formatted correctly for the decimal value. + Also check the message when input includes inf or nan (gh12200)""" x = np.array([1.00000000001, 2.00000000002, 3.00003]) y = np.array([1.00000000002, 2.00000000003, 3.00004]) @@ -531,6 +532,19 @@ class TestAlmostEqual(_GenericTest): # remove anything that's not the array string assert_equal(str(e).split('%)\n ')[1], b) + # Check the error message when input includes inf or nan + x = np.array([np.inf, 0]) + y = np.array([np.inf, 1]) + try: + self._assert_func(x, y) + except AssertionError as e: + msgs = str(e).split('\n') + # assert error percentage is 50% + assert_equal(msgs[3], '(mismatch 50.0%)') + # assert output array contains inf + assert_equal(msgs[4], ' x: array([inf, 0.])') + assert_equal(msgs[5], ' y: array([inf, 1.])') + def test_subclass_that_cannot_be_bool(self): # While we cannot guarantee testing functions will always work for # subclasses, the tests should ideally rely only on subclasses having @@ -1115,7 +1129,7 @@ class TestStringEqual(object): assert_raises(AssertionError, lambda: assert_string_equal("foo", "hello")) - + def test_regex(self): assert_string_equal("a+*b", "a+*b") |