summaryrefslogtreecommitdiff
path: root/numpy/testing/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r--numpy/testing/utils.py76
1 files changed, 39 insertions, 37 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index b5a7e05c4..f54995870 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -15,7 +15,8 @@ import contextlib
from tempfile import mkdtemp, mkstemp
from unittest.case import SkipTest
-from numpy.core import float32, empty, arange, array_repr, ndarray
+from numpy.core import(
+ float32, empty, arange, array_repr, ndarray, isnat, array)
from numpy.lib.utils import deprecate
if sys.version_info[0] >= 3:
@@ -286,7 +287,7 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
return '\n'.join(msg)
-def assert_equal(actual,desired,err_msg='',verbose=True):
+def assert_equal(actual, desired, err_msg='', verbose=True):
"""
Raises an AssertionError if two objects are not equal.
@@ -369,12 +370,12 @@ def assert_equal(actual,desired,err_msg='',verbose=True):
except AssertionError:
raise AssertionError(msg)
+ # isscalar test to check cases such as [np.nan] != np.nan
+ if isscalar(desired) != isscalar(actual):
+ raise AssertionError(msg)
+
# Inf/nan/negative zero handling
try:
- # isscalar test to check cases such as [np.nan] != np.nan
- if isscalar(desired) != isscalar(actual):
- raise AssertionError(msg)
-
# If one of desired/actual is not finite, handle it specially here:
# check that both are nan if any is a nan, and test for equality
# otherwise
@@ -396,14 +397,24 @@ def assert_equal(actual,desired,err_msg='',verbose=True):
except (TypeError, ValueError, NotImplementedError):
pass
- # Explicitly use __eq__ for comparison, ticket #2552
- with suppress_warnings() as sup:
- # TODO: Better handling will to needed when change happens!
- sup.filter(DeprecationWarning, ".*NAT ==")
- sup.filter(FutureWarning, ".*NAT ==")
- if not (desired == actual):
+ try:
+ # If both are NaT (and have the same dtype -- datetime or timedelta)
+ # they are considered equal.
+ if (isnat(desired) == isnat(actual) and
+ array(desired).dtype.type == array(actual).dtype.type):
+ return
+ else:
raise AssertionError(msg)
+ # If TypeError or ValueError raised while using isnan and co, just handle
+ # as before
+ except (TypeError, ValueError, NotImplementedError):
+ pass
+
+ # Explicitly use __eq__ for comparison, ticket #2552
+ if not (desired == actual):
+ raise AssertionError(msg)
+
def print_assert_equal(test_string, actual, desired):
"""
@@ -674,33 +685,12 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
- def safe_comparison(*args, **kwargs):
- # There are a number of cases where comparing two arrays hits special
- # cases in array_richcompare, specifically around strings and void
- # dtypes. Basically, we just can't do comparisons involving these
- # types, unless both arrays have exactly the *same* type. So
- # e.g. you can apply == to two string arrays, or two arrays with
- # identical structured dtypes. But if you compare a non-string array
- # to a string array, or two arrays with non-identical structured
- # dtypes, or anything like that, then internally stuff blows up.
- # Currently, when things blow up, we just return a scalar False or
- # True. But we also emit a DeprecationWarning, b/c eventually we
- # should raise an error here. (Ideally we might even make this work
- # properly, but since that will require rewriting a bunch of how
- # ufuncs work then we are not counting on that.)
- #
- # The point of this little function is to let the DeprecationWarning
- # pass (or maybe eventually catch the errors and return False, I
- # dunno, that's a little trickier and we can figure that out when the
- # time comes).
- with suppress_warnings() as sup:
- sup.filter(DeprecationWarning, ".*==")
- sup.filter(FutureWarning, ".*==")
- return comparison(*args, **kwargs)
-
def isnumber(x):
return x.dtype.char in '?bhilqpBHILQPefdgFDG'
+ def istime(x):
+ return x.dtype.char in "Mm"
+
def chk_same_position(x_id, y_id, hasval='nan'):
"""Handling nan/inf: check that x and y have the nan/inf at the same
locations."""
@@ -756,7 +746,19 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
if x.size == 0:
return
- val = safe_comparison(x, y)
+ elif istime(x) and istime(y):
+ # If one is datetime64 and the other timedelta64 there is no point
+ if equal_nan and x.dtype.type == y.dtype.type:
+ x_isnat, y_isnat = isnat(x), isnat(y)
+
+ if any(x_isnat) or any(y_isnat):
+ chk_same_position(x_isnat, y_isnat, hasval="NaT")
+
+ if any(x_isnat) or any(y_isnat):
+ x = x[~x_isnat]
+ y = y[~y_isnat]
+
+ val = comparison(x, y)
if isinstance(val, bool):
cond = val