From 92aa4080f2a925ff6d03d0b9f80460a53d4c111c Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Mon, 26 Dec 2016 22:55:40 +0100 Subject: BUG: Remove warning filters from comparison assert functions These warning filters are not threadsafe, while it is nicer if the comparison functions can be called in a threaded environment since they do no explicite warning checking. Less filters are also cleaner in general, though may also mean that downstream projects see warnings that were previously hidden. --- numpy/testing/utils.py | 76 ++++++++++++++++++++++++++------------------------ 1 file changed, 39 insertions(+), 37 deletions(-) (limited to 'numpy/testing/utils.py') diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index b5a7e05c4..f54995870 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -15,7 +15,8 @@ import contextlib from tempfile import mkdtemp, mkstemp from unittest.case import SkipTest -from numpy.core import float32, empty, arange, array_repr, ndarray +from numpy.core import( + float32, empty, arange, array_repr, ndarray, isnat, array) from numpy.lib.utils import deprecate if sys.version_info[0] >= 3: @@ -286,7 +287,7 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:', return '\n'.join(msg) -def assert_equal(actual,desired,err_msg='',verbose=True): +def assert_equal(actual, desired, err_msg='', verbose=True): """ Raises an AssertionError if two objects are not equal. @@ -369,12 +370,12 @@ def assert_equal(actual,desired,err_msg='',verbose=True): except AssertionError: raise AssertionError(msg) + # isscalar test to check cases such as [np.nan] != np.nan + if isscalar(desired) != isscalar(actual): + raise AssertionError(msg) + # Inf/nan/negative zero handling try: - # isscalar test to check cases such as [np.nan] != np.nan - if isscalar(desired) != isscalar(actual): - raise AssertionError(msg) - # If one of desired/actual is not finite, handle it specially here: # check that both are nan if any is a nan, and test for equality # otherwise @@ -396,14 +397,24 @@ def assert_equal(actual,desired,err_msg='',verbose=True): except (TypeError, ValueError, NotImplementedError): pass - # Explicitly use __eq__ for comparison, ticket #2552 - with suppress_warnings() as sup: - # TODO: Better handling will to needed when change happens! - sup.filter(DeprecationWarning, ".*NAT ==") - sup.filter(FutureWarning, ".*NAT ==") - if not (desired == actual): + try: + # If both are NaT (and have the same dtype -- datetime or timedelta) + # they are considered equal. + if (isnat(desired) == isnat(actual) and + array(desired).dtype.type == array(actual).dtype.type): + return + else: raise AssertionError(msg) + # If TypeError or ValueError raised while using isnan and co, just handle + # as before + except (TypeError, ValueError, NotImplementedError): + pass + + # Explicitly use __eq__ for comparison, ticket #2552 + if not (desired == actual): + raise AssertionError(msg) + def print_assert_equal(test_string, actual, desired): """ @@ -674,33 +685,12 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) - def safe_comparison(*args, **kwargs): - # There are a number of cases where comparing two arrays hits special - # cases in array_richcompare, specifically around strings and void - # dtypes. Basically, we just can't do comparisons involving these - # types, unless both arrays have exactly the *same* type. So - # e.g. you can apply == to two string arrays, or two arrays with - # identical structured dtypes. But if you compare a non-string array - # to a string array, or two arrays with non-identical structured - # dtypes, or anything like that, then internally stuff blows up. - # Currently, when things blow up, we just return a scalar False or - # True. But we also emit a DeprecationWarning, b/c eventually we - # should raise an error here. (Ideally we might even make this work - # properly, but since that will require rewriting a bunch of how - # ufuncs work then we are not counting on that.) - # - # The point of this little function is to let the DeprecationWarning - # pass (or maybe eventually catch the errors and return False, I - # dunno, that's a little trickier and we can figure that out when the - # time comes). - with suppress_warnings() as sup: - sup.filter(DeprecationWarning, ".*==") - sup.filter(FutureWarning, ".*==") - return comparison(*args, **kwargs) - def isnumber(x): return x.dtype.char in '?bhilqpBHILQPefdgFDG' + def istime(x): + return x.dtype.char in "Mm" + def chk_same_position(x_id, y_id, hasval='nan'): """Handling nan/inf: check that x and y have the nan/inf at the same locations.""" @@ -756,7 +746,19 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, if x.size == 0: return - val = safe_comparison(x, y) + elif istime(x) and istime(y): + # If one is datetime64 and the other timedelta64 there is no point + if equal_nan and x.dtype.type == y.dtype.type: + x_isnat, y_isnat = isnat(x), isnat(y) + + if any(x_isnat) or any(y_isnat): + chk_same_position(x_isnat, y_isnat, hasval="NaT") + + if any(x_isnat) or any(y_isnat): + x = x[~x_isnat] + y = y[~y_isnat] + + val = comparison(x, y) if isinstance(val, bool): cond = val -- cgit v1.2.1