diff options
author | David Cournapeau <cournape@gmail.com> | 2009-07-27 08:06:48 +0000 |
---|---|---|
committer | David Cournapeau <cournape@gmail.com> | 2009-07-27 08:06:48 +0000 |
commit | f49d6d36b666eaff6a82614badf452df1844e33b (patch) | |
tree | f8b060ad4d39c1565ad83878e3d8f599ee9f054a | |
parent | a711484bf1bda459b67aad86c11556bdfc69e279 (diff) | |
download | numpy-f49d6d36b666eaff6a82614badf452df1844e33b.tar.gz |
Handle nan and inf in assert_almost_equal.
-rw-r--r-- | numpy/testing/tests/test_utils.py | 21 | ||||
-rw-r--r-- | numpy/testing/utils.py | 49 |
2 files changed, 69 insertions, 1 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index aabdc88a0..1703ddbc7 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -127,10 +127,29 @@ class TestEqual(_GenericTest, unittest.TestCase): self._test_not_equal(c, b) -class TestAlmostEqual(_GenericTest, unittest.TestCase): +class TestArrayAlmostEqual(_GenericTest, unittest.TestCase): def setUp(self): self._assert_func = assert_array_almost_equal +class TestAlmostEqual(_GenericTest, unittest.TestCase): + def setUp(self): + self._assert_func = assert_almost_equal + + def test_nan_item(self): + self._assert_func(np.nan, np.nan) + self.failUnlessRaises(AssertionError, + lambda : self._assert_func(np.nan, 1)) + self.failUnlessRaises(AssertionError, + lambda : self._assert_func(np.nan, np.inf)) + self.failUnlessRaises(AssertionError, + lambda : self._assert_func(np.inf, np.nan)) + + def test_inf_item(self): + self._assert_func(np.inf, np.inf) + self._assert_func(-np.inf, -np.inf) + + def test_simple_item(self): + self._test_not_equal(1, 2) class TestRaises(unittest.TestCase): def setUp(self): diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index ba9b16b18..b18cef6e3 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -6,6 +6,7 @@ import os import sys import re import operator +import types from nosetester import import_nose __all__ = ['assert_equal', 'assert_almost_equal','assert_approx_equal', @@ -22,6 +23,39 @@ def assert_(val, msg='') : if not val : raise AssertionError(msg) +def gisnan(x): + """like isnan, but always raise an error if type not supported instead of + returning a TypeError object. + + Notes + ----- + isnan and other ufunc sometimes return a NotImplementedType object instead + of raising any exception. This function is a wrapper to make sure an + exception is always raised. + + This should be removed once this problem is solved at the Ufunc level.""" + from numpy.core import isnan + st = isnan(x) + if isinstance(st, types.NotImplementedType): + raise TypeError("isnan not supported for this type") + return st + +def gisfinite(x): + """like isfinite, but always raise an error if type not supported instead of + returning a TypeError object. + + Notes + ----- + isfinite and other ufunc sometimes return a NotImplementedType object instead + of raising any exception. This function is a wrapper to make sure an + exception is always raised. + + This should be removed once this problem is solved at the Ufunc level.""" + from numpy.core import isfinite + st = isfinite(x) + if isinstance(st, types.NotImplementedType): + raise TypeError("isfinite not supported for this type") + return st def rand(*args): """Returns an array of random numbers with the given shape. @@ -261,6 +295,21 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): if isinstance(actual, ndarray) or isinstance(desired, ndarray): return assert_array_almost_equal(actual, desired, decimal, err_msg) msg = build_err_msg([actual, desired], err_msg, verbose=verbose) + + try: + # If one of desired/actual is not finite, handle it specially here: + # check that both are nan if any is a nan, and test for equality + # otherwise + if not (gisfinite(desired) and gisfinite(actual)): + if gisnan(desired) or gisnan(actual): + if not (gisnan(desired) and gisnan(actual)): + raise AssertionError(msg) + else: + if not desired == actual: + raise AssertionError(msg) + return + except TypeError: + pass if round(abs(desired - actual),decimal) != 0 : raise AssertionError(msg) |