diff options
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r-- | numpy/testing/utils.py | 29 |
1 files changed, 26 insertions, 3 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 4527a51d9..6244b3ea7 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -592,6 +592,29 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) + def safe_comparison(*args, **kwargs): + # There are a number of cases where comparing two arrays hits special + # cases in array_richcompare, specifically around strings and void + # dtypes. Basically, we just can't do comparisons involving these + # types, unless both arrays have exactly the *same* type. So + # e.g. you can apply == to two string arrays, or two arrays with + # identical structured dtypes. But if you compare a non-string array + # to a string array, or two arrays with non-identical structured + # dtypes, or anything like that, then internally stuff blows up. + # Currently, when things blow up, we just return a scalar False or + # True. But we also emit a DeprecationWarning, b/c eventually we + # should raise an error here. (Ideally we might even make this work + # properly, but since that will require rewriting a bunch of how + # ufuncs work then we are not counting on that.) + # + # The point of this little function is to let the DeprecationWarning + # pass (or maybe eventually catch the errors and return False, I + # dunno, that's a little trickier and we can figure that out when the + # time comes). + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + return comparison(*args, **kwargs) + def isnumber(x): return x.dtype.char in '?bhilqpBHILQPefdgFDG' @@ -641,11 +664,11 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, return if any(x_id): - val = comparison(x[~x_id], y[~y_id]) + val = safe_comparison(x[~x_id], y[~y_id]) else: - val = comparison(x, y) + val = safe_comparison(x, y) else: - val = comparison(x, y) + val = safe_comparison(x, y) if isinstance(val, bool): cond = val |