summaryrefslogtreecommitdiff
path: root/numpy/testing/_private/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r--numpy/testing/_private/utils.py36
1 files changed, 22 insertions, 14 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index a3832fcde..20a7dfd0b 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -19,7 +19,7 @@ from warnings import WarningMessage
import pprint
from numpy.core import(
- float32, empty, arange, array_repr, ndarray, isnat, array)
+ bool_, float32, empty, arange, array_repr, ndarray, isnat, array)
from numpy.lib.utils import deprecate
if sys.version_info[0] >= 3:
@@ -352,7 +352,7 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
# XXX: catch ValueError for subclasses of ndarray where iscomplex fail
try:
usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
- except ValueError:
+ except (ValueError, TypeError):
usecomplex = False
if usecomplex:
@@ -692,6 +692,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
+ # original array for output formating
+ ox, oy = x, y
+
def isnumber(x):
return x.dtype.char in '?bhilqpBHILQPefdgFDG'
@@ -705,15 +708,20 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
at the same locations.
"""
- # Both the != True comparison here and the cast to bool_ at the end are
- # done to deal with `masked`, which cannot be compared usefully, and
- # for which np.all yields masked. The use of the function np.all is
- # for back compatibility with ndarray subclasses that changed the
- # return values of the all method. We are not committed to supporting
- # such subclasses, but some used to work.
x_id = func(x)
y_id = func(y)
- if npall(x_id == y_id) != True:
+ # We include work-arounds here to handle three types of slightly
+ # pathological ndarray subclasses:
+ # (1) all() on `masked` array scalars can return masked arrays, so we
+ # use != True
+ # (2) __eq__ on some ndarray subclasses returns Python booleans
+ # instead of element-wise comparisons, so we cast to bool_() and
+ # use isinstance(..., bool) checks
+ # (3) subclasses with bare-bones __array_function__ implemenations may
+ # not implement np.all(), so favor using the .all() method
+ # We are not committed to supporting such subclasses, but it's nice to
+ # support them if possible.
+ if bool_(x_id == y_id).all() != True:
msg = build_err_msg([x, y],
err_msg + '\nx and y %s location mismatch:'
% (hasval), verbose=verbose, header=header,
@@ -721,9 +729,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
raise AssertionError(msg)
# If there is a scalar, then here we know the array has the same
# flag as it everywhere, so we should return the scalar flag.
- if x_id.ndim == 0:
+ if isinstance(x_id, bool) or x_id.ndim == 0:
return bool_(x_id)
- elif y_id.ndim == 0:
+ elif isinstance(x_id, bool) or y_id.ndim == 0:
return bool_(y_id)
else:
return y_id
@@ -780,10 +788,10 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
# do not trigger a failure (np.ma.masked != True evaluates as
# np.ma.masked, which is falsy).
if cond != True:
- match = 100-100.0*reduced.count(1)/len(reduced)
- msg = build_err_msg([x, y],
+ mismatch = 100.0 * reduced.count(0) / ox.size
+ msg = build_err_msg([ox, oy],
err_msg
- + '\n(mismatch %s%%)' % (match,),
+ + '\n(mismatch %s%%)' % (mismatch,),
verbose=verbose, header=header,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)