diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2018-08-17 17:47:58 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-17 17:47:58 -0500 |
commit | 18f338acb8ef4b7b7508d8451f458c34566a05f0 (patch) | |
tree | 5a321603251fc22e533ab85d784c4afbbd18c100 /numpy | |
parent | 9245def62e4747324be811f2d4f621a04213c131 (diff) | |
parent | 78efe63c4444b47330395c0f632cd538dbfd2c8b (diff) | |
download | numpy-18f338acb8ef4b7b7508d8451f458c34566a05f0.tar.gz |
Merge pull request #11756 from charris/fix-testing-utils
MAINT: Make assert_array_compare more generic.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/testing/_private/utils.py | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 0e2f8ba91..a3832fcde 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -687,6 +687,8 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, equal_inf=True): __tracebackhide__ = True # Hide traceback for py.test from numpy.core import array, isnan, inf, bool_ + from numpy.core.fromnumeric import all as npall + x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) @@ -697,14 +699,21 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, return x.dtype.char in "Mm" def func_assert_same_pos(x, y, func=isnan, hasval='nan'): - """Handling nan/inf: combine results of running func on x and y, - checking that they are True at the same locations.""" - # Both the != True comparison here and the cast to bool_ at - # the end are done to deal with `masked`, which cannot be - # compared usefully, and for which .all() yields masked. + """Handling nan/inf. + + Combine results of running func on x and y, checking that they are True + at the same locations. + + """ + # Both the != True comparison here and the cast to bool_ at the end are + # done to deal with `masked`, which cannot be compared usefully, and + # for which np.all yields masked. The use of the function np.all is + # for back compatibility with ndarray subclasses that changed the + # return values of the all method. We are not committed to supporting + # such subclasses, but some used to work. x_id = func(x) y_id = func(y) - if (x_id == y_id).all() != True: + if npall(x_id == y_id) != True: msg = build_err_msg([x, y], err_msg + '\nx and y %s location mismatch:' % (hasval), verbose=verbose, header=header, |