diff options
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 846429016..59c10151d 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -30,6 +30,7 @@ from umath import * import numerictypes from numerictypes import * + if sys.version_info[0] < 3: __all__.extend(['getbuffer', 'newbuffer']) @@ -2011,18 +2012,26 @@ def allclose(a, b, rtol=1.e-5, atol=1.e-8): False """ + def within_tol(x, y, atol, rtol): + err = seterr(invalid='ignore') + try: + result = less_equal(abs(x-y), atol + rtol * abs(y)) + finally: + seterr(**err) + return result + x = array(a, copy=False, ndmin=1) y = array(b, copy=False, ndmin=1) xinf = isinf(x) if not all(xinf == isinf(y)): return False if not any(xinf): - return all(less_equal(absolute(x-y), atol + rtol * absolute(y))) + return all(within_tol(x, y, atol, rtol)) if not all(x[xinf] == y[xinf]): return False x = x[~xinf] y = y[~xinf] - return all(less_equal(absolute(x-y), atol + rtol * absolute(y))) + return all(within_tol(x, y, atol, rtol)) def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): """ @@ -2084,10 +2093,15 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): array([True, True]) """ def within_tol(x, y, atol, rtol): - result = less_equal(abs(x-y), atol + rtol * abs(y)) + err = seterr(invalid='ignore') + try: + result = less_equal(abs(x-y), atol + rtol * abs(y)) + finally: + seterr(**err) if isscalar(a) and isscalar(b): result = bool(result) return result + x = array(a, copy=False, subok=True, ndmin=1) y = array(b, copy=False, subok=True, ndmin=1) xfin = isfinite(x) |