summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/testing/_private/utils.py15
-rw-r--r--numpy/testing/tests/test_utils.py20
2 files changed, 28 insertions, 7 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index 4ac0715bf..8a31fcf15 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -686,7 +686,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header='', precision=6, equal_nan=True,
equal_inf=True):
__tracebackhide__ = True # Hide traceback for py.test
- from numpy.core import array, array2string, isnan, inf, bool_, errstate, all
+ from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -788,17 +788,18 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
# np.ma.masked, which is falsy).
if cond != True:
n_mismatch = reduced.size - reduced.sum(dtype=intp)
- percent_mismatch = 100 * n_mismatch / ox.size
+ n_elements = flagged.size if flagged.ndim != 0 else reduced.size
+ percent_mismatch = 100 * n_mismatch / n_elements
remarks = [
'Mismatched elements: {} / {} ({:.3g}%)'.format(
- n_mismatch, ox.size, percent_mismatch)]
+ n_mismatch, n_elements, percent_mismatch)]
with errstate(invalid='ignore', divide='ignore'):
# ignore errors for non-numeric types
with contextlib.suppress(TypeError):
error = abs(x - y)
- max_abs_error = error.max()
- if error.dtype == 'object':
+ max_abs_error = max(error)
+ if getattr(error, 'dtype', object_) == object_:
remarks.append('Max absolute difference: '
+ str(max_abs_error))
else:
@@ -812,8 +813,8 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
if all(~nonzero):
max_rel_error = array(inf)
else:
- max_rel_error = (error[nonzero] / abs(y[nonzero])).max()
- if error.dtype == 'object':
+ max_rel_error = max(error[nonzero] / abs(y[nonzero]))
+ if getattr(error, 'dtype', object_) == object_:
remarks.append('Max relative difference: '
+ str(max_rel_error))
else:
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 688bedc16..44f93a693 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -564,6 +564,26 @@ class TestAlmostEqual(_GenericTest):
assert_equal(msgs[4], 'Max absolute difference: 2')
assert_equal(msgs[5], 'Max relative difference: inf')
+ def test_error_message_2(self):
+ """Check the message is formatted correctly when either x or y is a scalar."""
+ x = 2
+ y = np.ones(20)
+ with pytest.raises(AssertionError) as exc_info:
+ self._assert_func(x, y)
+ msgs = str(exc_info.value).split('\n')
+ assert_equal(msgs[3], 'Mismatched elements: 20 / 20 (100%)')
+ assert_equal(msgs[4], 'Max absolute difference: 1.')
+ assert_equal(msgs[5], 'Max relative difference: 1.')
+
+ y = 2
+ x = np.ones(20)
+ with pytest.raises(AssertionError) as exc_info:
+ self._assert_func(x, y)
+ msgs = str(exc_info.value).split('\n')
+ assert_equal(msgs[3], 'Mismatched elements: 20 / 20 (100%)')
+ assert_equal(msgs[4], 'Max absolute difference: 1.')
+ assert_equal(msgs[5], 'Max relative difference: 0.5')
+
def test_subclass_that_cannot_be_bool(self):
# While we cannot guarantee testing functions will always work for
# subclasses, the tests should ideally rely only on subclasses having