From 0d3aeb0dde468a15ad15a44017842115bc3313e4 Mon Sep 17 00:00:00 2001 From: David Cournapeau Date: Mon, 27 Jul 2009 15:10:30 +0000 Subject: Handle complex special values and negative zero. Complex with nan/inf are correctly handled in assert_equal, as well as negative zero. --- numpy/testing/tests/test_utils.py | 16 +++++++++++++ numpy/testing/utils.py | 47 +++++++++++++++++++++++++++++++++++---- 2 files changed, 59 insertions(+), 4 deletions(-) (limited to 'numpy/testing') diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index e29edacff..240752dcf 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -145,6 +145,22 @@ class TestEqual(TestArrayEqual): self._assert_func('ab', 'ab') self._test_not_equal('ab', 'abb') + def test_complex_item(self): + self._assert_func(complex(1, 2), complex(1, 2)) + self._assert_func(complex(1, np.nan), complex(1, np.nan)) + self._test_not_equal(complex(1, np.nan), complex(1, 2)) + self._test_not_equal(complex(np.nan, 1), complex(1, np.nan)) + self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2)) + + def test_negative_zero(self): + self._test_not_equal(np.PZERO, np.NZERO) + + def test_complex(self): + x = np.array([complex(1, 2), complex(1, np.nan)]) + y = np.array([complex(1, 2), complex(1, 2)]) + self._assert_func(x, x) + self._test_not_equal(x, y) + class TestArrayAlmostEqual(_GenericTest, unittest.TestCase): def setUp(self): self._assert_func = assert_array_almost_equal diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 5103e28dc..2fc8729a6 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -215,12 +215,47 @@ def assert_equal(actual,desired,err_msg='',verbose=True): for k in range(len(desired)): assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg), verbose) return - from numpy.core import ndarray, isscalar + from numpy.core import ndarray, isscalar, signbit + from numpy.lib import iscomplexobj, real, imag if isinstance(actual, ndarray) or isinstance(desired, ndarray): return assert_array_equal(actual, desired, err_msg, verbose) msg = build_err_msg([actual, desired], err_msg, verbose=verbose) + # Handle complex numbers: separate into real/imag to handle + # nan/inf/negative zero correctly + # XXX: catch ValueError for subclasses of ndarray where iscomplex fail try: + usecomplex = iscomplexobj(actual) or iscomplexobj(desired) + except ValueError: + usecomplex = False + + if usecomplex: + if iscomplexobj(actual): + actualr = real(actual) + actuali = imag(actual) + else: + actualr = actual + actuali = 0 + if iscomplexobj(desired): + desiredr = real(desired) + desiredi = imag(desired) + else: + desiredr = desired + desiredi = 0 + try: + assert_equal(actualr, desiredr) + assert_equal(actuali, desiredi) + except AssertionError: + raise AssertionError("Items are not equal:\n" \ + "ACTUAL: %s\n" \ + "DESIRED: %s\n" % (str(actual), str(desired))) + + # Inf/nan/negative zero handling + try: + # isscalar test to check cases such as [np.nan] != np.nan + if isscalar(desired) != isscalar(actual): + raise AssertionError(msg) + # If one of desired/actual is not finite, handle it specially here: # check that both are nan if any is a nan, and test for equality # otherwise @@ -228,14 +263,15 @@ def assert_equal(actual,desired,err_msg='',verbose=True): isdesnan = gisnan(desired) isactnan = gisnan(actual) if isdesnan or isactnan: - # isscalar test to check so that [np.nan] != np.nan - if not (isdesnan and isactnan) \ - or (isscalar(isdesnan) != isscalar(isactnan)): + if not (isdesnan and isactnan): raise AssertionError(msg) else: if not desired == actual: raise AssertionError(msg) return + elif desired == 0 and actual == 0: + if not signbit(desired) == signbit(actual): + raise AssertionError(msg) # If TypeError or ValueError raised while using isnan and co, just handle # as before except TypeError: @@ -461,6 +497,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, verbose=verbose, header=header, names=('x', 'y')) raise AssertionError(msg) + # If only one item, it was a nan, so just return + if x.size == y.size == 1: + return val = comparison(x[~xnanid], y[~ynanid]) else: val = comparison(x,y) -- cgit v1.2.1