summaryrefslogtreecommitdiff
path: root/numpy/testing/_private
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-05-19 12:21:47 -0400
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-05-27 08:09:38 -0400
commit5718b3306303b4899c017641a9282714dcf8c9b8 (patch)
tree4a9dfd8c1f7ee5fd3dcdbb1fd48b0a1b64a3ccc3 /numpy/testing/_private
parent84f582f25e168dbfd59b3be470bc8ebc46ee2d92 (diff)
downloadnumpy-5718b3306303b4899c017641a9282714dcf8c9b8.tar.gz
BUG,MAINT: Ensure masked elements can be tested against nan and inf.
The removal of nan and inf from arrays that are compared using test routines like assert_array_equal treated the two arrays separately, which for masked arrays meant that some elements would not be removed when they should have been. This PR corrects this.
Diffstat (limited to 'numpy/testing/_private')
-rw-r--r--numpy/testing/_private/utils.py67
1 files changed, 29 insertions, 38 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index a7935f175..5b7e35366 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -685,7 +685,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header='', precision=6, equal_nan=True,
equal_inf=True):
__tracebackhide__ = True # Hide traceback for py.test
- from numpy.core import array, isnan, isinf, any, inf
+ from numpy.core import array, isnan, any, inf, ndim
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -695,9 +695,14 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
def istime(x):
return x.dtype.char in "Mm"
- def chk_same_position(x_id, y_id, hasval='nan'):
- """Handling nan/inf: check that x and y have the nan/inf at the same
- locations."""
+ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
+ """Handling nan/inf: combine results of running func on x and y,
+ checking that they are True at the same locations."""
+ x_id = func(x)
+ y_id = func(y)
+ if not any(x_id) and not any(y_id):
+ return False
+
try:
assert_array_equal(x_id, y_id)
except AssertionError:
@@ -706,6 +711,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
% (hasval), verbose=verbose, header=header,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)
+ # If there is a scalar, then here we know the array has the same
+ # flag as it everywhere, so we should return the scalar flag.
+ return x_id if x_id.ndim == 0 else y_id
try:
cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
@@ -718,49 +726,32 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)
+ flagged = False
if isnumber(x) and isnumber(y):
- has_nan = has_inf = False
if equal_nan:
- x_isnan, y_isnan = isnan(x), isnan(y)
- # Validate that NaNs are in the same place
- has_nan = any(x_isnan) or any(y_isnan)
- if has_nan:
- chk_same_position(x_isnan, y_isnan, hasval='nan')
+ flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
if equal_inf:
- x_isinf, y_isinf = isinf(x), isinf(y)
- # Validate that infinite values are in the same place
- has_inf = any(x_isinf) or any(y_isinf)
- if has_inf:
- # Check +inf and -inf separately, since they are different
- chk_same_position(x == +inf, y == +inf, hasval='+inf')
- chk_same_position(x == -inf, y == -inf, hasval='-inf')
-
- if has_nan and has_inf:
- x = x[~(x_isnan | x_isinf)]
- y = y[~(y_isnan | y_isinf)]
- elif has_nan:
- x = x[~x_isnan]
- y = y[~y_isnan]
- elif has_inf:
- x = x[~x_isinf]
- y = y[~y_isinf]
-
- # Only do the comparison if actual values are left
- if x.size == 0:
- return
+ flagged |= func_assert_same_pos(x, y,
+ func=lambda xy: xy == +inf,
+ hasval='+inf')
+ flagged |= func_assert_same_pos(x, y,
+ func=lambda xy: xy == -inf,
+ hasval='-inf')
elif istime(x) and istime(y):
# If one is datetime64 and the other timedelta64 there is no point
if equal_nan and x.dtype.type == y.dtype.type:
- x_isnat, y_isnat = isnat(x), isnat(y)
+ flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT")
- if any(x_isnat) or any(y_isnat):
- chk_same_position(x_isnat, y_isnat, hasval="NaT")
-
- if any(x_isnat) or any(y_isnat):
- x = x[~x_isnat]
- y = y[~y_isnat]
+ if ndim(flagged):
+ x, y = x[~flagged], y[~flagged]
+ # Only do the comparison if actual values are left
+ if x.size == 0:
+ return
+ elif flagged:
+ # no sense doing comparison if everything is flagged.
+ return
val = comparison(x, y)