summaryrefslogtreecommitdiff
path: root/numpy/testing/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r--numpy/testing/utils.py38
1 files changed, 20 insertions, 18 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 859a0705b..7858eefac 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -666,9 +666,11 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
- header='', precision=6, equal_nan=True):
+ header='', precision=6, equal_nan=True,
+ equal_inf=True):
__tracebackhide__ = True # Hide traceback for py.test
- from numpy.core import array, isnan, isinf, any, all, inf
+ from numpy.core import array, isnan, isinf, any, all, inf, zeros_like
+ from numpy.core.numerictypes import bool_
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
@@ -724,25 +726,24 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
raise AssertionError(msg)
if isnumber(x) and isnumber(y):
+ x_id, y_id = zeros_like(x, dtype=bool_), zeros_like(y, dtype=bool_)
if equal_nan:
x_isnan, y_isnan = isnan(x), isnan(y)
# Validate that NaNs are in the same place
if any(x_isnan) or any(y_isnan):
chk_same_position(x_isnan, y_isnan, hasval='nan')
-
- x_isinf, y_isinf = isinf(x), isinf(y)
-
- # Validate that infinite values are in the same place
- if any(x_isinf) or any(y_isinf):
- # Check +inf and -inf separately, since they are different
- chk_same_position(x == +inf, y == +inf, hasval='+inf')
- chk_same_position(x == -inf, y == -inf, hasval='-inf')
-
- # Combine all the special values
- x_id, y_id = x_isinf, y_isinf
- if equal_nan:
- x_id |= x_isnan
- y_id |= y_isnan
+ x_id |= x_isnan
+ y_id |= y_isnan
+
+ if equal_inf:
+ x_isinf, y_isinf = isinf(x), isinf(y)
+ # Validate that infinite values are in the same place
+ if any(x_isinf) or any(y_isinf):
+ # Check +inf and -inf separately, since they are different
+ chk_same_position(x == +inf, y == +inf, hasval='+inf')
+ chk_same_position(x == -inf, y == -inf, hasval='-inf')
+ x_id |= x_isinf
+ y_id |= y_isinf
# Only do the comparison if actual values are left
if all(x_id):
@@ -930,7 +931,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
if npany(gisinf(x)) or npany( gisinf(y)):
xinfid = gisinf(x)
yinfid = gisinf(y)
- if not xinfid == yinfid:
+ if not (xinfid == yinfid).all():
return False
# if one item, x and y is +- inf
if x.size == y.size == 1:
@@ -1025,7 +1026,8 @@ def assert_array_less(x, y, err_msg='', verbose=True):
__tracebackhide__ = True # Hide traceback for py.test
assert_array_compare(operator.__lt__, x, y, err_msg=err_msg,
verbose=verbose,
- header='Arrays are not less-ordered')
+ header='Arrays are not less-ordered',
+ equal_inf=False)
def runstring(astr, dict):