summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorDavid Cournapeau <cournape@gmail.com>2009-07-28 06:31:46 +0000
committerDavid Cournapeau <cournape@gmail.com>2009-07-28 06:31:46 +0000
commit15f9f6911af9b3fd9a28e09d344d40bb721e60b9 (patch)
treebf316f9c641da284df06b183d95a34961c235630 /numpy
parent3feb77ee4e46962b6c9cfab2a37a0c1ef1ceda82 (diff)
downloadnumpy-15f9f6911af9b3fd9a28e09d344d40bb721e60b9.tar.gz
BUG: handle inf/nan correctly in assert_array_almost_equal.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/testing/tests/test_utils.py1
-rw-r--r--numpy/testing/utils.py32
2 files changed, 32 insertions, 1 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 448a2e3e7..0ecf0622d 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -209,6 +209,7 @@ class TestAlmostEqual(_GenericTest, unittest.TestCase):
def test_complex_item(self):
self._assert_func(complex(1, 2), complex(1, 2))
self._assert_func(complex(1, np.nan), complex(1, np.nan))
+ self._assert_func(complex(np.inf, np.nan), complex(np.inf, np.nan))
self._test_not_equal(complex(1, np.nan), complex(1, 2))
self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 4e82fd510..a349ecf54 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -57,6 +57,23 @@ def gisfinite(x):
raise TypeError("isfinite not supported for this type")
return st
+def gisinf(x):
+ """like isinf, but always raise an error if type not supported instead of
+ returning a TypeError object.
+
+ Notes
+ -----
+ isinf and other ufunc sometimes return a NotImplementedType object instead
+ of raising any exception. This function is a wrapper to make sure an
+ exception is always raised.
+
+ This should be removed once this problem is solved at the Ufunc level."""
+ from numpy.core import isinf
+ st = isinf(x)
+ if isinstance(st, types.NotImplementedType):
+ raise TypeError("isinf not supported for this type")
+ return st
+
def rand(*args):
"""Returns an array of random numbers with the given shape.
@@ -683,9 +700,22 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
y: array([ 1. , 2.33333, 5. ])
"""
- from numpy.core import around, number, float_
+ from numpy.core import around, number, float_, any
from numpy.lib import issubdtype
def compare(x, y):
+ try:
+ if any(gisinf(x)) or any( gisinf(y)):
+ xinfid = gisinf(x)
+ yinfid = gisinf(y)
+ if not xinfid == yinfid:
+ return False
+ # if one item, x and y is +- inf
+ if x.size == y.size == 1:
+ return x == y
+ x = x[~xinfid]
+ y = y[~yinfid]
+ except TypeError:
+ pass
z = abs(x-y)
if not issubdtype(z.dtype, number):
z = z.astype(float_) # handle object arrays