diff options
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r-- | numpy/testing/_private/utils.py | 36 |
1 files changed, 22 insertions, 14 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index a3832fcde..20a7dfd0b 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -19,7 +19,7 @@ from warnings import WarningMessage import pprint from numpy.core import( - float32, empty, arange, array_repr, ndarray, isnat, array) + bool_, float32, empty, arange, array_repr, ndarray, isnat, array) from numpy.lib.utils import deprecate if sys.version_info[0] >= 3: @@ -352,7 +352,7 @@ def assert_equal(actual, desired, err_msg='', verbose=True): # XXX: catch ValueError for subclasses of ndarray where iscomplex fail try: usecomplex = iscomplexobj(actual) or iscomplexobj(desired) - except ValueError: + except (ValueError, TypeError): usecomplex = False if usecomplex: @@ -692,6 +692,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) + # original array for output formating + ox, oy = x, y + def isnumber(x): return x.dtype.char in '?bhilqpBHILQPefdgFDG' @@ -705,15 +708,20 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, at the same locations. """ - # Both the != True comparison here and the cast to bool_ at the end are - # done to deal with `masked`, which cannot be compared usefully, and - # for which np.all yields masked. The use of the function np.all is - # for back compatibility with ndarray subclasses that changed the - # return values of the all method. We are not committed to supporting - # such subclasses, but some used to work. x_id = func(x) y_id = func(y) - if npall(x_id == y_id) != True: + # We include work-arounds here to handle three types of slightly + # pathological ndarray subclasses: + # (1) all() on `masked` array scalars can return masked arrays, so we + # use != True + # (2) __eq__ on some ndarray subclasses returns Python booleans + # instead of element-wise comparisons, so we cast to bool_() and + # use isinstance(..., bool) checks + # (3) subclasses with bare-bones __array_function__ implemenations may + # not implement np.all(), so favor using the .all() method + # We are not committed to supporting such subclasses, but it's nice to + # support them if possible. + if bool_(x_id == y_id).all() != True: msg = build_err_msg([x, y], err_msg + '\nx and y %s location mismatch:' % (hasval), verbose=verbose, header=header, @@ -721,9 +729,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, raise AssertionError(msg) # If there is a scalar, then here we know the array has the same # flag as it everywhere, so we should return the scalar flag. - if x_id.ndim == 0: + if isinstance(x_id, bool) or x_id.ndim == 0: return bool_(x_id) - elif y_id.ndim == 0: + elif isinstance(x_id, bool) or y_id.ndim == 0: return bool_(y_id) else: return y_id @@ -780,10 +788,10 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, # do not trigger a failure (np.ma.masked != True evaluates as # np.ma.masked, which is falsy). if cond != True: - match = 100-100.0*reduced.count(1)/len(reduced) - msg = build_err_msg([x, y], + mismatch = 100.0 * reduced.count(0) / ox.size + msg = build_err_msg([ox, oy], err_msg - + '\n(mismatch %s%%)' % (match,), + + '\n(mismatch %s%%)' % (mismatch,), verbose=verbose, header=header, names=('x', 'y'), precision=precision) raise AssertionError(msg) |