summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorDavid Cournapeau <cournape@gmail.com>2009-07-27 15:09:55 +0000
committerDavid Cournapeau <cournape@gmail.com>2009-07-27 15:09:55 +0000
commit60fd0b786909db327fe4f3e925047e2c2f58f2d2 (patch)
treeaabfff921a44ca4a3e64061ed848d17d5dcb8d10 /numpy
parent0259a4972a81295353407dfaf4ba5adcb34a42db (diff)
downloadnumpy-60fd0b786909db327fe4f3e925047e2c2f58f2d2.tar.gz
Handle nan and inf in assert_equal.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/testing/tests/test_utils.py20
-rw-r--r--numpy/testing/utils.py25
2 files changed, 43 insertions, 2 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index fd07d3985..e29edacff 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -48,7 +48,7 @@ class _GenericTest(object):
a = np.array([1, 1], dtype=np.object)
self._test_equal(a, 1)
-class TestEqual(_GenericTest, unittest.TestCase):
+class TestArrayEqual(_GenericTest, unittest.TestCase):
def setUp(self):
self._assert_func = assert_array_equal
@@ -126,6 +126,24 @@ class TestEqual(_GenericTest, unittest.TestCase):
self._test_not_equal(c, b)
+class TestEqual(TestArrayEqual):
+ def setUp(self):
+ self._assert_func = assert_equal
+
+ def test_nan_items(self):
+ self._assert_func(np.nan, np.nan)
+ self._assert_func([np.nan], [np.nan])
+ self._test_not_equal(np.nan, [np.nan])
+ self._test_not_equal(np.nan, 1)
+
+ def test_inf_items(self):
+ self._assert_func(np.inf, np.inf)
+ self._assert_func([np.inf], [np.inf])
+ self._test_not_equal(np.inf, [np.inf])
+
+ def test_non_numeric(self):
+ self._assert_func('ab', 'ab')
+ self._test_not_equal('ab', 'abb')
class TestArrayAlmostEqual(_GenericTest, unittest.TestCase):
def setUp(self):
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 88d613e2c..5103e28dc 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -215,10 +215,33 @@ def assert_equal(actual,desired,err_msg='',verbose=True):
for k in range(len(desired)):
assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg), verbose)
return
- from numpy.core import ndarray
+ from numpy.core import ndarray, isscalar
if isinstance(actual, ndarray) or isinstance(desired, ndarray):
return assert_array_equal(actual, desired, err_msg, verbose)
msg = build_err_msg([actual, desired], err_msg, verbose=verbose)
+
+ try:
+ # If one of desired/actual is not finite, handle it specially here:
+ # check that both are nan if any is a nan, and test for equality
+ # otherwise
+ if not (gisfinite(desired) and gisfinite(actual)):
+ isdesnan = gisnan(desired)
+ isactnan = gisnan(actual)
+ if isdesnan or isactnan:
+ # isscalar test to check so that [np.nan] != np.nan
+ if not (isdesnan and isactnan) \
+ or (isscalar(isdesnan) != isscalar(isactnan)):
+ raise AssertionError(msg)
+ else:
+ if not desired == actual:
+ raise AssertionError(msg)
+ return
+ # If TypeError or ValueError raised while using isnan and co, just handle
+ # as before
+ except TypeError:
+ pass
+ except ValueError:
+ pass
if desired != actual :
raise AssertionError(msg)