summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py20
1 files changed, 17 insertions, 3 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 846429016..59c10151d 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -30,6 +30,7 @@ from umath import *
import numerictypes
from numerictypes import *
+
if sys.version_info[0] < 3:
__all__.extend(['getbuffer', 'newbuffer'])
@@ -2011,18 +2012,26 @@ def allclose(a, b, rtol=1.e-5, atol=1.e-8):
False
"""
+ def within_tol(x, y, atol, rtol):
+ err = seterr(invalid='ignore')
+ try:
+ result = less_equal(abs(x-y), atol + rtol * abs(y))
+ finally:
+ seterr(**err)
+ return result
+
x = array(a, copy=False, ndmin=1)
y = array(b, copy=False, ndmin=1)
xinf = isinf(x)
if not all(xinf == isinf(y)):
return False
if not any(xinf):
- return all(less_equal(absolute(x-y), atol + rtol * absolute(y)))
+ return all(within_tol(x, y, atol, rtol))
if not all(x[xinf] == y[xinf]):
return False
x = x[~xinf]
y = y[~xinf]
- return all(less_equal(absolute(x-y), atol + rtol * absolute(y)))
+ return all(within_tol(x, y, atol, rtol))
def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
"""
@@ -2084,10 +2093,15 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
array([True, True])
"""
def within_tol(x, y, atol, rtol):
- result = less_equal(abs(x-y), atol + rtol * abs(y))
+ err = seterr(invalid='ignore')
+ try:
+ result = less_equal(abs(x-y), atol + rtol * abs(y))
+ finally:
+ seterr(**err)
if isscalar(a) and isscalar(b):
result = bool(result)
return result
+
x = array(a, copy=False, subok=True, ndmin=1)
y = array(b, copy=False, subok=True, ndmin=1)
xfin = isfinite(x)