diff options
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 34 |
1 files changed, 17 insertions, 17 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 59c10151d..a9a42ea09 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -2012,26 +2012,26 @@ def allclose(a, b, rtol=1.e-5, atol=1.e-8): False """ - def within_tol(x, y, atol, rtol): - err = seterr(invalid='ignore') - try: - result = less_equal(abs(x-y), atol + rtol * abs(y)) - finally: - seterr(**err) - return result - x = array(a, copy=False, ndmin=1) y = array(b, copy=False, ndmin=1) - xinf = isinf(x) - if not all(xinf == isinf(y)): - return False - if not any(xinf): - return all(within_tol(x, y, atol, rtol)) - if not all(x[xinf] == y[xinf]): + + if any(isnan(x)) or any(isnan(y)): return False - x = x[~xinf] - y = y[~xinf] - return all(within_tol(x, y, atol, rtol)) + + xinf = isinf(x) + yinf = isinf(y) + if any(xinf) or any(yinf): + # Check that x and y have inf's only in the same positions + if not all(xinf == yinf): + return False + # Check that sign of inf's in x and y is the same + if not all(x[xinf] == y[xinf]): + return False + + x = x[~xinf] + y = y[~xinf] + + return all(less_equal(abs(x-y), atol + rtol * abs(y))) def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): """ |