summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2018-08-17 17:47:58 -0500
committerGitHub <noreply@github.com>2018-08-17 17:47:58 -0500
commit18f338acb8ef4b7b7508d8451f458c34566a05f0 (patch)
tree5a321603251fc22e533ab85d784c4afbbd18c100 /numpy
parent9245def62e4747324be811f2d4f621a04213c131 (diff)
parent78efe63c4444b47330395c0f632cd538dbfd2c8b (diff)
downloadnumpy-18f338acb8ef4b7b7508d8451f458c34566a05f0.tar.gz
Merge pull request #11756 from charris/fix-testing-utils
MAINT: Make assert_array_compare more generic.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/testing/_private/utils.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index 0e2f8ba91..a3832fcde 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -687,6 +687,8 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
equal_inf=True):
__tracebackhide__ = True # Hide traceback for py.test
from numpy.core import array, isnan, inf, bool_
+ from numpy.core.fromnumeric import all as npall
+
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -697,14 +699,21 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
return x.dtype.char in "Mm"
def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
- """Handling nan/inf: combine results of running func on x and y,
- checking that they are True at the same locations."""
- # Both the != True comparison here and the cast to bool_ at
- # the end are done to deal with `masked`, which cannot be
- # compared usefully, and for which .all() yields masked.
+ """Handling nan/inf.
+
+ Combine results of running func on x and y, checking that they are True
+ at the same locations.
+
+ """
+ # Both the != True comparison here and the cast to bool_ at the end are
+ # done to deal with `masked`, which cannot be compared usefully, and
+ # for which np.all yields masked. The use of the function np.all is
+ # for back compatibility with ndarray subclasses that changed the
+ # return values of the all method. We are not committed to supporting
+ # such subclasses, but some used to work.
x_id = func(x)
y_id = func(y)
- if (x_id == y_id).all() != True:
+ if npall(x_id == y_id) != True:
msg = build_err_msg([x, y],
err_msg + '\nx and y %s location mismatch:'
% (hasval), verbose=verbose, header=header,