diff options
author | David Cournapeau <cournape@gmail.com> | 2009-07-28 06:31:46 +0000 |
---|---|---|
committer | David Cournapeau <cournape@gmail.com> | 2009-07-28 06:31:46 +0000 |
commit | 15f9f6911af9b3fd9a28e09d344d40bb721e60b9 (patch) | |
tree | bf316f9c641da284df06b183d95a34961c235630 /numpy/testing/utils.py | |
parent | 3feb77ee4e46962b6c9cfab2a37a0c1ef1ceda82 (diff) | |
download | numpy-15f9f6911af9b3fd9a28e09d344d40bb721e60b9.tar.gz |
BUG: handle inf/nan correctly in assert_array_almost_equal.
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r-- | numpy/testing/utils.py | 32 |
1 files changed, 31 insertions, 1 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 4e82fd510..a349ecf54 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -57,6 +57,23 @@ def gisfinite(x): raise TypeError("isfinite not supported for this type") return st +def gisinf(x): + """like isinf, but always raise an error if type not supported instead of + returning a TypeError object. + + Notes + ----- + isinf and other ufunc sometimes return a NotImplementedType object instead + of raising any exception. This function is a wrapper to make sure an + exception is always raised. + + This should be removed once this problem is solved at the Ufunc level.""" + from numpy.core import isinf + st = isinf(x) + if isinstance(st, types.NotImplementedType): + raise TypeError("isinf not supported for this type") + return st + def rand(*args): """Returns an array of random numbers with the given shape. @@ -683,9 +700,22 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): y: array([ 1. , 2.33333, 5. ]) """ - from numpy.core import around, number, float_ + from numpy.core import around, number, float_, any from numpy.lib import issubdtype def compare(x, y): + try: + if any(gisinf(x)) or any( gisinf(y)): + xinfid = gisinf(x) + yinfid = gisinf(y) + if not xinfid == yinfid: + return False + # if one item, x and y is +- inf + if x.size == y.size == 1: + return x == y + x = x[~xinfid] + y = y[~yinfid] + except TypeError: + pass z = abs(x-y) if not issubdtype(z.dtype, number): z = z.astype(float_) # handle object arrays |