summaryrefslogtreecommitdiff
path: root/numpy/testing/_private
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/_private')
-rw-r--r--numpy/testing/_private/parameterized.py19
-rw-r--r--numpy/testing/_private/utils.py20
2 files changed, 22 insertions, 17 deletions
diff --git a/numpy/testing/_private/parameterized.py b/numpy/testing/_private/parameterized.py
index a5fa4fb5e..489d8e09a 100644
--- a/numpy/testing/_private/parameterized.py
+++ b/numpy/testing/_private/parameterized.py
@@ -45,11 +45,18 @@ except ImportError:
from unittest import TestCase
-PY3 = sys.version_info[0] == 3
PY2 = sys.version_info[0] == 2
-if PY3:
+if PY2:
+ from types import InstanceType
+ lzip = zip
+ text_type = unicode
+ bytes_type = str
+ string_types = basestring,
+ def make_method(func, instance, type):
+ return MethodType(func, instance, type)
+else:
# Python 3 doesn't have an InstanceType, so just use a dummy type.
class InstanceType():
pass
@@ -61,14 +68,6 @@ if PY3:
if instance is None:
return func
return MethodType(func, instance)
-else:
- from types import InstanceType
- lzip = zip
- text_type = unicode
- bytes_type = str
- string_types = basestring,
- def make_method(func, instance, type):
- return MethodType(func, instance, type)
_param = namedtuple("param", "args kwargs")
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index 97a5eac17..8a31fcf15 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -686,7 +686,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header='', precision=6, equal_nan=True,
equal_inf=True):
__tracebackhide__ = True # Hide traceback for py.test
- from numpy.core import array, array2string, isnan, inf, bool_, errstate
+ from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -788,17 +788,18 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
# np.ma.masked, which is falsy).
if cond != True:
n_mismatch = reduced.size - reduced.sum(dtype=intp)
- percent_mismatch = 100 * n_mismatch / ox.size
+ n_elements = flagged.size if flagged.ndim != 0 else reduced.size
+ percent_mismatch = 100 * n_mismatch / n_elements
remarks = [
'Mismatched elements: {} / {} ({:.3g}%)'.format(
- n_mismatch, ox.size, percent_mismatch)]
+ n_mismatch, n_elements, percent_mismatch)]
with errstate(invalid='ignore', divide='ignore'):
# ignore errors for non-numeric types
with contextlib.suppress(TypeError):
error = abs(x - y)
- max_abs_error = error.max()
- if error.dtype == 'object':
+ max_abs_error = max(error)
+ if getattr(error, 'dtype', object_) == object_:
remarks.append('Max absolute difference: '
+ str(max_abs_error))
else:
@@ -807,8 +808,13 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
# note: this definition of relative error matches that one
# used by assert_allclose (found in np.isclose)
- max_rel_error = (error / abs(y)).max()
- if error.dtype == 'object':
+ # Filter values where the divisor would be zero
+ nonzero = bool_(y != 0)
+ if all(~nonzero):
+ max_rel_error = array(inf)
+ else:
+ max_rel_error = max(error[nonzero] / abs(y[nonzero]))
+ if getattr(error, 'dtype', object_) == object_:
remarks.append('Max relative difference: '
+ str(max_rel_error))
else: