summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/testing/_private/utils.py67
-rw-r--r--numpy/testing/tests/test_utils.py12
2 files changed, 41 insertions, 38 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index a7935f175..5b7e35366 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -685,7 +685,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header='', precision=6, equal_nan=True,
equal_inf=True):
__tracebackhide__ = True # Hide traceback for py.test
- from numpy.core import array, isnan, isinf, any, inf
+ from numpy.core import array, isnan, any, inf, ndim
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -695,9 +695,14 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
def istime(x):
return x.dtype.char in "Mm"
- def chk_same_position(x_id, y_id, hasval='nan'):
- """Handling nan/inf: check that x and y have the nan/inf at the same
- locations."""
+ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
+ """Handling nan/inf: combine results of running func on x and y,
+ checking that they are True at the same locations."""
+ x_id = func(x)
+ y_id = func(y)
+ if not any(x_id) and not any(y_id):
+ return False
+
try:
assert_array_equal(x_id, y_id)
except AssertionError:
@@ -706,6 +711,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
% (hasval), verbose=verbose, header=header,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)
+ # If there is a scalar, then here we know the array has the same
+ # flag as it everywhere, so we should return the scalar flag.
+ return x_id if x_id.ndim == 0 else y_id
try:
cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
@@ -718,49 +726,32 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)
+ flagged = False
if isnumber(x) and isnumber(y):
- has_nan = has_inf = False
if equal_nan:
- x_isnan, y_isnan = isnan(x), isnan(y)
- # Validate that NaNs are in the same place
- has_nan = any(x_isnan) or any(y_isnan)
- if has_nan:
- chk_same_position(x_isnan, y_isnan, hasval='nan')
+ flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
if equal_inf:
- x_isinf, y_isinf = isinf(x), isinf(y)
- # Validate that infinite values are in the same place
- has_inf = any(x_isinf) or any(y_isinf)
- if has_inf:
- # Check +inf and -inf separately, since they are different
- chk_same_position(x == +inf, y == +inf, hasval='+inf')
- chk_same_position(x == -inf, y == -inf, hasval='-inf')
-
- if has_nan and has_inf:
- x = x[~(x_isnan | x_isinf)]
- y = y[~(y_isnan | y_isinf)]
- elif has_nan:
- x = x[~x_isnan]
- y = y[~y_isnan]
- elif has_inf:
- x = x[~x_isinf]
- y = y[~y_isinf]
-
- # Only do the comparison if actual values are left
- if x.size == 0:
- return
+ flagged |= func_assert_same_pos(x, y,
+ func=lambda xy: xy == +inf,
+ hasval='+inf')
+ flagged |= func_assert_same_pos(x, y,
+ func=lambda xy: xy == -inf,
+ hasval='-inf')
elif istime(x) and istime(y):
# If one is datetime64 and the other timedelta64 there is no point
if equal_nan and x.dtype.type == y.dtype.type:
- x_isnat, y_isnat = isnat(x), isnat(y)
+ flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT")
- if any(x_isnat) or any(y_isnat):
- chk_same_position(x_isnat, y_isnat, hasval="NaT")
-
- if any(x_isnat) or any(y_isnat):
- x = x[~x_isnat]
- y = y[~y_isnat]
+ if ndim(flagged):
+ x, y = x[~flagged], y[~flagged]
+ # Only do the comparison if actual values are left
+ if x.size == 0:
+ return
+ elif flagged:
+ # no sense doing comparison if everything is flagged.
+ return
val = comparison(x, y)
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 602cdf5f2..235b3750c 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -151,6 +151,17 @@ class TestArrayEqual(_GenericTest):
self._test_not_equal(c, b)
assert_equal(len(l), 1)
+ def test_masked_nan_inf(self):
+ # Regression test for gh-11121
+ a = np.ma.MaskedArray([3., 4., 6.5], mask=[False, True, False])
+ b = np.array([3., np.nan, 6.5])
+ self._test_equal(a, b)
+ self._test_equal(b, a)
+ a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, False, False])
+ b = np.array([np.inf, 4., 6.5])
+ self._test_equal(a, b)
+ self._test_equal(b, a)
+
class TestBuildErrorMessage(object):
@@ -650,6 +661,7 @@ class TestArrayAssertLess(object):
assert_raises(AssertionError, lambda: self._assert_func(-ainf, -x))
self._assert_func(-ainf, x)
+
@pytest.mark.skip(reason="The raises decorator depends on Nose")
class TestRaises(object):