From 30c6bcab17dc43b9683ff79bca99f1b37b0f70e1 Mon Sep 17 00:00:00 2001 From: Stefan van der Walt Date: Mon, 20 Aug 2007 13:46:55 +0000 Subject: Fix allclose and add tests (based on a patch by Matthew Brett). --- numpy/core/numeric.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) (limited to 'numpy/core/numeric.py') diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index b4709b392..4c3047720 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -835,21 +835,16 @@ def allclose(a, b, rtol=1.e-5, atol=1.e-8): """ x = array(a, copy=False) y = array(b, copy=False) - d1 = less_equal(absolute(x-y), atol + rtol * absolute(y)) xinf = isinf(x) - yinf = isinf(y) - if (not xinf.any() and not yinf.any()): - return d1.all() - d3 = (x[xinf] == y[yinf]) - d4 = (~xinf & ~yinf) - if d3.size < 2: - if d3.size==0: - return False - return d3 - if d3.all(): - return d1[d4].all() - else: + if not all(xinf == isinf(y)): + return False + if not any(xinf): + return all(less_equal(absolute(x-y), atol + rtol * absolute(y))) + if not all(x[xinf] == y[xinf]): return False + x = x[~xinf] + y = y[~xinf] + return all(less_equal(absolute(x-y), atol + rtol * absolute(y))) def array_equal(a1, a2): try: -- cgit v1.2.1