summaryrefslogtreecommitdiff
path: root/numpy/testing/utils.py
diff options
context:
space:
mode:
authorDavid Cournapeau <cournape@gmail.com>2009-07-27 08:06:48 +0000
committerDavid Cournapeau <cournape@gmail.com>2009-07-27 08:06:48 +0000
commitf49d6d36b666eaff6a82614badf452df1844e33b (patch)
treef8b060ad4d39c1565ad83878e3d8f599ee9f054a /numpy/testing/utils.py
parenta711484bf1bda459b67aad86c11556bdfc69e279 (diff)
downloadnumpy-f49d6d36b666eaff6a82614badf452df1844e33b.tar.gz
Handle nan and inf in assert_almost_equal.
Diffstat (limited to 'numpy/testing/utils.py')
-rw-r--r--numpy/testing/utils.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index ba9b16b18..b18cef6e3 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -6,6 +6,7 @@ import os
import sys
import re
import operator
+import types
from nosetester import import_nose
__all__ = ['assert_equal', 'assert_almost_equal','assert_approx_equal',
@@ -22,6 +23,39 @@ def assert_(val, msg='') :
if not val :
raise AssertionError(msg)
+def gisnan(x):
+ """like isnan, but always raise an error if type not supported instead of
+ returning a TypeError object.
+
+ Notes
+ -----
+ isnan and other ufunc sometimes return a NotImplementedType object instead
+ of raising any exception. This function is a wrapper to make sure an
+ exception is always raised.
+
+ This should be removed once this problem is solved at the Ufunc level."""
+ from numpy.core import isnan
+ st = isnan(x)
+ if isinstance(st, types.NotImplementedType):
+ raise TypeError("isnan not supported for this type")
+ return st
+
+def gisfinite(x):
+ """like isfinite, but always raise an error if type not supported instead of
+ returning a TypeError object.
+
+ Notes
+ -----
+ isfinite and other ufunc sometimes return a NotImplementedType object instead
+ of raising any exception. This function is a wrapper to make sure an
+ exception is always raised.
+
+ This should be removed once this problem is solved at the Ufunc level."""
+ from numpy.core import isfinite
+ st = isfinite(x)
+ if isinstance(st, types.NotImplementedType):
+ raise TypeError("isfinite not supported for this type")
+ return st
def rand(*args):
"""Returns an array of random numbers with the given shape.
@@ -261,6 +295,21 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
if isinstance(actual, ndarray) or isinstance(desired, ndarray):
return assert_array_almost_equal(actual, desired, decimal, err_msg)
msg = build_err_msg([actual, desired], err_msg, verbose=verbose)
+
+ try:
+ # If one of desired/actual is not finite, handle it specially here:
+ # check that both are nan if any is a nan, and test for equality
+ # otherwise
+ if not (gisfinite(desired) and gisfinite(actual)):
+ if gisnan(desired) or gisnan(actual):
+ if not (gisnan(desired) and gisnan(actual)):
+ raise AssertionError(msg)
+ else:
+ if not desired == actual:
+ raise AssertionError(msg)
+ return
+ except TypeError:
+ pass
if round(abs(desired - actual),decimal) != 0 :
raise AssertionError(msg)