summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorchebee7i <chebee7i@gmail.com>2015-02-20 21:31:42 -0600
committerchebee7i <chebee7i@gmail.com>2015-02-20 21:31:42 -0600
commit2ae0529a5935c0eff0f72bb9f86d063537be250a (patch)
tree08fd35883e557632683b1cad79cd6c17c9beddf4 /numpy/core/numeric.py
parentd770034969e35907e7497d5fe9053df4bdac2fd2 (diff)
downloadnumpy-2ae0529a5935c0eff0f72bb9f86d063537be250a.tar.gz
MAINT: Make allclose use isclose.
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py43
1 files changed, 7 insertions, 36 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index a464562c4..1519abcd4 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -2206,41 +2206,6 @@ def identity(n, dtype=None):
from numpy import eye
return eye(n, dtype=dtype)
-def _allclose_points(a, b, rtol=1.e-5, atol=1.e-8):
- """
- This is the point-wise inner calculation of 'allclose', which is subtly
- different from 'isclose'. Provided as a comparison routine for use in
- testing.assert_allclose.
- See those routines for further details.
-
- """
- x = array(a, copy=False, ndmin=1)
- y = array(b, copy=False, ndmin=1)
-
- # make sure y is an inexact type to avoid abs(MIN_INT); will cause
- # casting of x later.
- dtype = multiarray.result_type(y, 1.)
- y = array(y, dtype=dtype, copy=False)
-
- xinf = isinf(x)
- yinf = isinf(y)
- if any(xinf) or any(yinf):
- # Check that x and y have inf's only in the same positions
- if not all(xinf == yinf):
- return False
- # Check that sign of inf's in x and y is the same
- if not all(x[xinf] == y[xinf]):
- return False
-
- x = x[~xinf]
- y = y[~xinf]
-
- # ignore invalid fpe's
- with errstate(invalid='ignore'):
- r = less_equal(abs(x - y), atol + rtol * abs(y))
-
- return r
-
def allclose(a, b, rtol=1.e-5, atol=1.e-8):
"""
Returns True if two arrays are element-wise equal within a tolerance.
@@ -2296,7 +2261,7 @@ def allclose(a, b, rtol=1.e-5, atol=1.e-8):
False
"""
- return all(_allclose_points(a, b, rtol=rtol, atol=atol))
+ return all(isclose(a, b, rtol=rtol, atol=atol))
def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
"""
@@ -2366,6 +2331,12 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
x = array(a, copy=False, subok=True, ndmin=1)
y = array(b, copy=False, subok=True, ndmin=1)
+
+ # make sure y is an inexact type to avoid abs(MIN_INT); will cause
+ # casting of x later. Make sure to allow subclasses (e.g., for numpy.ma).
+ dtype = multiarray.result_type(y, 1.)
+ y = array(y, dtype=dtype, copy=False, subok=True)
+
xfin = isfinite(x)
yfin = isfinite(y)
if all(xfin) and all(yfin):