summaryrefslogtreecommitdiff
path: root/numpy/testing/_private/utils.py
diff options
context:
space:
mode:
authorMaxwell Aladago <maxwell.aladago@gmail.com>2019-09-06 14:39:42 -0400
committerMaxwell Aladago <maxwell.aladago@gmail.com>2019-09-06 14:39:42 -0400
commit151802288ba4ffb2595ac3f6e07ee08b654863b6 (patch)
tree7a9e3f829ebcce318a214ffc78e18b2743774910 /numpy/testing/_private/utils.py
parent94ae1c59d78e64a2eda219ade28f1180e3c2d9af (diff)
downloadnumpy-151802288ba4ffb2595ac3f6e07ee08b654863b6.tar.gz
assert_array_compare
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r--numpy/testing/_private/utils.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index 4ac0715bf..8a31fcf15 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -686,7 +686,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header='', precision=6, equal_nan=True,
equal_inf=True):
__tracebackhide__ = True # Hide traceback for py.test
- from numpy.core import array, array2string, isnan, inf, bool_, errstate, all
+ from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -788,17 +788,18 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
# np.ma.masked, which is falsy).
if cond != True:
n_mismatch = reduced.size - reduced.sum(dtype=intp)
- percent_mismatch = 100 * n_mismatch / ox.size
+ n_elements = flagged.size if flagged.ndim != 0 else reduced.size
+ percent_mismatch = 100 * n_mismatch / n_elements
remarks = [
'Mismatched elements: {} / {} ({:.3g}%)'.format(
- n_mismatch, ox.size, percent_mismatch)]
+ n_mismatch, n_elements, percent_mismatch)]
with errstate(invalid='ignore', divide='ignore'):
# ignore errors for non-numeric types
with contextlib.suppress(TypeError):
error = abs(x - y)
- max_abs_error = error.max()
- if error.dtype == 'object':
+ max_abs_error = max(error)
+ if getattr(error, 'dtype', object_) == object_:
remarks.append('Max absolute difference: '
+ str(max_abs_error))
else:
@@ -812,8 +813,8 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
if all(~nonzero):
max_rel_error = array(inf)
else:
- max_rel_error = (error[nonzero] / abs(y[nonzero])).max()
- if error.dtype == 'object':
+ max_rel_error = max(error[nonzero] / abs(y[nonzero]))
+ if getattr(error, 'dtype', object_) == object_:
remarks.append('Max relative difference: '
+ str(max_rel_error))
else: