diff options
author | David Cournapeau <cournape@gmail.com> | 2009-07-27 15:09:55 +0000 |
---|---|---|
committer | David Cournapeau <cournape@gmail.com> | 2009-07-27 15:09:55 +0000 |
commit | 60fd0b786909db327fe4f3e925047e2c2f58f2d2 (patch) | |
tree | aabfff921a44ca4a3e64061ed848d17d5dcb8d10 | |
parent | 0259a4972a81295353407dfaf4ba5adcb34a42db (diff) | |
download | numpy-60fd0b786909db327fe4f3e925047e2c2f58f2d2.tar.gz |
Handle nan and inf in assert_equal.
-rw-r--r-- | numpy/testing/tests/test_utils.py | 20 | ||||
-rw-r--r-- | numpy/testing/utils.py | 25 |
2 files changed, 43 insertions, 2 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index fd07d3985..e29edacff 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -48,7 +48,7 @@ class _GenericTest(object): a = np.array([1, 1], dtype=np.object) self._test_equal(a, 1) -class TestEqual(_GenericTest, unittest.TestCase): +class TestArrayEqual(_GenericTest, unittest.TestCase): def setUp(self): self._assert_func = assert_array_equal @@ -126,6 +126,24 @@ class TestEqual(_GenericTest, unittest.TestCase): self._test_not_equal(c, b) +class TestEqual(TestArrayEqual): + def setUp(self): + self._assert_func = assert_equal + + def test_nan_items(self): + self._assert_func(np.nan, np.nan) + self._assert_func([np.nan], [np.nan]) + self._test_not_equal(np.nan, [np.nan]) + self._test_not_equal(np.nan, 1) + + def test_inf_items(self): + self._assert_func(np.inf, np.inf) + self._assert_func([np.inf], [np.inf]) + self._test_not_equal(np.inf, [np.inf]) + + def test_non_numeric(self): + self._assert_func('ab', 'ab') + self._test_not_equal('ab', 'abb') class TestArrayAlmostEqual(_GenericTest, unittest.TestCase): def setUp(self): diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 88d613e2c..5103e28dc 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -215,10 +215,33 @@ def assert_equal(actual,desired,err_msg='',verbose=True): for k in range(len(desired)): assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg), verbose) return - from numpy.core import ndarray + from numpy.core import ndarray, isscalar if isinstance(actual, ndarray) or isinstance(desired, ndarray): return assert_array_equal(actual, desired, err_msg, verbose) msg = build_err_msg([actual, desired], err_msg, verbose=verbose) + + try: + # If one of desired/actual is not finite, handle it specially here: + # check that both are nan if any is a nan, and test for equality + # otherwise + if not (gisfinite(desired) and gisfinite(actual)): + isdesnan = gisnan(desired) + isactnan = gisnan(actual) + if isdesnan or isactnan: + # isscalar test to check so that [np.nan] != np.nan + if not (isdesnan and isactnan) \ + or (isscalar(isdesnan) != isscalar(isactnan)): + raise AssertionError(msg) + else: + if not desired == actual: + raise AssertionError(msg) + return + # If TypeError or ValueError raised while using isnan and co, just handle + # as before + except TypeError: + pass + except ValueError: + pass if desired != actual : raise AssertionError(msg) |