summaryrefslogtreecommitdiff
path: root/numpy/testing/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r--numpy/testing/utils.py29
1 files changed, 26 insertions, 3 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 4527a51d9..6244b3ea7 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -592,6 +592,29 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
+ def safe_comparison(*args, **kwargs):
+ # There are a number of cases where comparing two arrays hits special
+ # cases in array_richcompare, specifically around strings and void
+ # dtypes. Basically, we just can't do comparisons involving these
+ # types, unless both arrays have exactly the *same* type. So
+ # e.g. you can apply == to two string arrays, or two arrays with
+ # identical structured dtypes. But if you compare a non-string array
+ # to a string array, or two arrays with non-identical structured
+ # dtypes, or anything like that, then internally stuff blows up.
+ # Currently, when things blow up, we just return a scalar False or
+ # True. But we also emit a DeprecationWarning, b/c eventually we
+ # should raise an error here. (Ideally we might even make this work
+ # properly, but since that will require rewriting a bunch of how
+ # ufuncs work then we are not counting on that.)
+ #
+ # The point of this little function is to let the DeprecationWarning
+ # pass (or maybe eventually catch the errors and return False, I
+ # dunno, that's a little trickier and we can figure that out when the
+ # time comes).
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ return comparison(*args, **kwargs)
+
def isnumber(x):
return x.dtype.char in '?bhilqpBHILQPefdgFDG'
@@ -641,11 +664,11 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
return
if any(x_id):
- val = comparison(x[~x_id], y[~y_id])
+ val = safe_comparison(x[~x_id], y[~y_id])
else:
- val = comparison(x, y)
+ val = safe_comparison(x, y)
else:
- val = comparison(x, y)
+ val = safe_comparison(x, y)
if isinstance(val, bool):
cond = val