diff options
Diffstat (limited to 'numpy/testing')
-rw-r--r-- | numpy/testing/__init__.py | 4 | ||||
-rw-r--r-- | numpy/testing/info.py | 30 | ||||
-rw-r--r-- | numpy/testing/numpytest.py | 661 | ||||
-rwxr-xr-x | numpy/testing/setup.py | 16 | ||||
-rw-r--r-- | numpy/testing/utils.py | 238 |
5 files changed, 949 insertions, 0 deletions
diff --git a/numpy/testing/__init__.py b/numpy/testing/__init__.py new file mode 100644 index 000000000..028890a49 --- /dev/null +++ b/numpy/testing/__init__.py @@ -0,0 +1,4 @@ + +from info import __doc__ +from numpytest import * +from utils import * diff --git a/numpy/testing/info.py b/numpy/testing/info.py new file mode 100644 index 000000000..8b09d8ed3 --- /dev/null +++ b/numpy/testing/info.py @@ -0,0 +1,30 @@ +""" +Numpy testing tools +=================== + +Numpy-style unit-testing +------------------------ + + NumpyTest -- Numpy tests site manager + NumpyTestCase -- unittest.TestCase with measure method + IgnoreException -- raise when checking disabled feature, it'll be ignored + set_package_path -- prepend package build directory to path + set_local_path -- prepend local directory (to tests files) to path + restore_path -- restore path after set_package_path + +Utility functions +----------------- + + jiffies -- return 1/100ths of a second that the current process has used + memusage -- virtual memory size in bytes of the running python [linux] + rand -- array of random numbers from given shape + assert_equal -- assert equality + assert_almost_equal -- assert equality with decimal tolerance + assert_approx_equal -- assert equality with significant digits tolerance + assert_array_equal -- assert arrays equality + assert_array_almost_equal -- assert arrays equality with decimal tolerance + assert_array_less -- assert arrays less-ordering + +""" + +global_symbols = ['ScipyTest','NumpyTest'] diff --git a/numpy/testing/numpytest.py b/numpy/testing/numpytest.py new file mode 100644 index 000000000..da09a830d --- /dev/null +++ b/numpy/testing/numpytest.py @@ -0,0 +1,661 @@ +import os +import re +import sys +import imp +import glob +import types +import unittest +import traceback +import warnings + +__all__ = ['set_package_path', 'set_local_path', 'restore_path', + 'IgnoreException', 'NumpyTestCase', 'NumpyTest', + 'ScipyTestCase', 'ScipyTest', # for backward compatibility + 'importall', + ] + +DEBUG=0 +from numpy.testing.utils import jiffies +get_frame = sys._getframe + +class IgnoreException(Exception): + "Ignoring this exception due to disabled feature" + + +def set_package_path(level=1): + """ Prepend package directory to sys.path. + + set_package_path should be called from a test_file.py that + satisfies the following tree structure: + + <somepath>/<somedir>/test_file.py + + Then the first existing path name from the following list + + <somepath>/build/lib.<platform>-<version> + <somepath>/.. + + is prepended to sys.path. + The caller is responsible for removing this path by using + + restore_path() + """ + from distutils.util import get_platform + f = get_frame(level) + if f.f_locals['__name__']=='__main__': + testfile = sys.argv[0] + else: + testfile = f.f_locals['__file__'] + d = os.path.dirname(os.path.dirname(os.path.abspath(testfile))) + d1 = os.path.join(d,'build','lib.%s-%s'%(get_platform(),sys.version[:3])) + if not os.path.isdir(d1): + d1 = os.path.dirname(d) + if DEBUG: + print 'Inserting %r to sys.path for test_file %r' % (d1, testfile) + sys.path.insert(0,d1) + return + + +def set_local_path(reldir='', level=1): + """ Prepend local directory to sys.path. + + The caller is responsible for removing this path by using + + restore_path() + """ + f = get_frame(level) + if f.f_locals['__name__']=='__main__': + testfile = sys.argv[0] + else: + testfile = f.f_locals['__file__'] + local_path = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(testfile)),reldir)) + if DEBUG: + print 'Inserting %r to sys.path' % (local_path) + sys.path.insert(0,local_path) + return + + +def restore_path(): + if DEBUG: + print 'Removing %r from sys.path' % (sys.path[0]) + del sys.path[0] + return + + +def output_exception(printstream = sys.stdout): + try: + type, value, tb = sys.exc_info() + info = traceback.extract_tb(tb) + #this is more verbose + #traceback.print_exc() + filename, lineno, function, text = info[-1] # last line only + print>>printstream, "%s:%d: %s: %s (in %s)" %\ + (filename, lineno, type.__name__, str(value), function) + finally: + type = value = tb = None # clean up + return + + +class _dummy_stream: + def __init__(self,stream): + self.data = [] + self.stream = stream + def write(self,message): + if not self.data and not message.startswith('E'): + self.stream.write(message) + self.stream.flush() + message = '' + self.data.append(message) + def writeln(self,message): + self.write(message+'\n') + + +class NumpyTestCase (unittest.TestCase): + + def measure(self,code_str,times=1): + """ Return elapsed time for executing code_str in the + namespace of the caller for given times. + """ + frame = get_frame(1) + locs,globs = frame.f_locals,frame.f_globals + code = compile(code_str, + 'NumpyTestCase runner for '+self.__class__.__name__, + 'exec') + i = 0 + elapsed = jiffies() + while i<times: + i += 1 + exec code in globs,locs + elapsed = jiffies() - elapsed + return 0.01*elapsed + + def __call__(self, result=None): + if result is None: + return unittest.TestCase.__call__(self, result) + + nof_errors = len(result.errors) + save_stream = result.stream + result.stream = _dummy_stream(save_stream) + unittest.TestCase.__call__(self, result) + if nof_errors != len(result.errors): + test, errstr = result.errors[-1][:2] + if isinstance(errstr, tuple): + errstr = str(errstr[0]) + elif isinstance(errstr, str): + errstr = errstr.split('\n')[-2] + else: + # allow for proxy classes + errstr = str(errstr).split('\n')[-2] + l = len(result.stream.data) + if errstr.startswith('IgnoreException:'): + if l==1: + assert result.stream.data[-1]=='E', \ + repr(result.stream.data) + result.stream.data[-1] = 'i' + else: + assert result.stream.data[-1]=='ERROR\n', \ + repr(result.stream.data) + result.stream.data[-1] = 'ignoring\n' + del result.errors[-1] + map(save_stream.write, result.stream.data) + save_stream.flush() + result.stream = save_stream + + def warn(self, message): + from numpy.distutils.misc_util import yellow_text + print>>sys.stderr,yellow_text('Warning: %s' % (message)) + sys.stderr.flush() + def info(self, message): + print>>sys.stdout, message + sys.stdout.flush() + + def rundocs(self, filename=None): + """ Run doc string tests found in filename. + """ + import doctest + if filename is None: + f = get_frame(1) + filename = f.f_globals['__file__'] + name = os.path.splitext(os.path.basename(filename))[0] + path = [os.path.dirname(filename)] + file, pathname, description = imp.find_module(name, path) + try: + m = imp.load_module(name, file, pathname, description) + finally: + file.close() + if sys.version[:3]<'2.4': + doctest.testmod(m, verbose=False) + else: + tests = doctest.DocTestFinder().find(m) + runner = doctest.DocTestRunner(verbose=False) + for test in tests: + runner.run(test) + return + +class ScipyTestCase(NumpyTestCase): + def __init__(self, package=None): + warnings.warn("ScipyTestCase is now called NumpyTestCase; please update your code", + DeprecationWarning, stacklevel=2) + NumpyTestCase.__init__(self, package) + + +def _get_all_method_names(cls): + names = dir(cls) + if sys.version[:3]<='2.1': + for b in cls.__bases__: + for n in dir(b)+_get_all_method_names(b): + if n not in names: + names.append(n) + return names + + +# for debug build--check for memory leaks during the test. +class _NumPyTextTestResult(unittest._TextTestResult): + def startTest(self, test): + unittest._TextTestResult.startTest(self, test) + if self.showAll: + N = len(sys.getobjects(0)) + self._totnumobj = N + self._totrefcnt = sys.gettotalrefcount() + return + + def stopTest(self, test): + if self.showAll: + N = len(sys.getobjects(0)) + self.stream.write("objects: %d ===> %d; " % (self._totnumobj, N)) + self.stream.write("refcnts: %d ===> %d\n" % (self._totrefcnt, + sys.gettotalrefcount())) + return + +class NumPyTextTestRunner(unittest.TextTestRunner): + def _makeResult(self): + return _NumPyTextTestResult(self.stream, self.descriptions, self.verbosity) + + +class NumpyTest: + """ Numpy tests site manager. + + Usage: NumpyTest(<package>).test(level=1,verbosity=1) + + <package> is package name or its module object. + + Package is supposed to contain a directory tests/ with test_*.py + files where * refers to the names of submodules. See .rename() + method to redefine name mapping between test_*.py files and names of + submodules. Pattern test_*.py can be overwritten by redefining + .get_testfile() method. + + test_*.py files are supposed to define a classes, derived from + NumpyTestCase or unittest.TestCase, with methods having names + starting with test or bench or check. The names of TestCase classes + must have a prefix test. This can be overwritten by redefining + .check_testcase_name() method. + + And that is it! No need to implement test or test_suite functions + in each .py file. + + Old-style test_suite(level=1) hooks are also supported. + """ + _check_testcase_name = re.compile(r'test.*').match + def check_testcase_name(self, name): + """ Return True if name matches TestCase class. + """ + return not not self._check_testcase_name(name) + + testfile_patterns = ['test_%(modulename)s.py'] + def get_testfile(self, module, verbosity = 0): + """ Return path to module test file. + """ + mstr = self._module_str + short_module_name = self._get_short_module_name(module) + d = os.path.split(module.__file__)[0] + test_dir = os.path.join(d,'tests') + local_test_dir = os.path.join(os.getcwd(),'tests') + if os.path.basename(os.path.dirname(local_test_dir)) \ + == os.path.basename(os.path.dirname(test_dir)): + test_dir = local_test_dir + for pat in self.testfile_patterns: + fn = os.path.join(test_dir, pat % {'modulename':short_module_name}) + if os.path.isfile(fn): + return fn + if verbosity>1: + self.warn('No test file found in %s for module %s' \ + % (test_dir, mstr(module))) + return + + def __init__(self, package=None): + if package is None: + from numpy.distutils.misc_util import get_frame + f = get_frame(1) + package = f.f_locals.get('__name__',f.f_globals.get('__name__',None)) + assert package is not None + self.package = package + self._rename_map = {} + + def rename(self, **kws): + """Apply renaming submodule test file test_<name>.py to + test_<newname>.py. + + Usage: self.rename(name='newname') before calling the + self.test() method. + + If 'newname' is None, then no tests will be executed for a given + module. + """ + for k,v in kws.items(): + self._rename_map[k] = v + return + + def _module_str(self, module): + filename = module.__file__[-30:] + if filename!=module.__file__: + filename = '...'+filename + return '<module %r from %r>' % (module.__name__, filename) + + def _get_method_names(self,clsobj,level): + names = [] + for mthname in _get_all_method_names(clsobj): + if mthname[:5] not in ['bench','check'] \ + and mthname[:4] not in ['test']: + continue + mth = getattr(clsobj, mthname) + if type(mth) is not types.MethodType: + continue + d = mth.im_func.func_defaults + if d is not None: + mthlevel = d[0] + else: + mthlevel = 1 + if level>=mthlevel: + if mthname not in names: + names.append(mthname) + for base in clsobj.__bases__: + for n in self._get_method_names(base,level): + if n not in names: + names.append(n) + return names + + def _get_short_module_name(self, module): + d,f = os.path.split(module.__file__) + short_module_name = os.path.splitext(os.path.basename(f))[0] + if short_module_name=='__init__': + short_module_name = module.__name__.split('.')[-1] + short_module_name = self._rename_map.get(short_module_name,short_module_name) + return short_module_name + + def _get_module_tests(self, module, level, verbosity): + mstr = self._module_str + + short_module_name = self._get_short_module_name(module) + if short_module_name is None: + return [] + + test_file = self.get_testfile(module, verbosity) + + if test_file is None: + return [] + + if not os.path.isfile(test_file): + if short_module_name[:5]=='info_' \ + and short_module_name[5:]==module.__name__.split('.')[-2]: + return [] + if short_module_name in ['__cvs_version__','__svn_version__']: + return [] + if short_module_name[-8:]=='_version' \ + and short_module_name[:-8]==module.__name__.split('.')[-2]: + return [] + if verbosity>1: + self.warn(test_file) + self.warn(' !! No test file %r found for %s' \ + % (os.path.basename(test_file), mstr(module))) + return [] + + if test_file in self.test_files: + return [] + + parent_module_name = '.'.join(module.__name__.split('.')[:-1]) + test_module_name,ext = os.path.splitext(os.path.basename(test_file)) + test_dir_module = parent_module_name+'.tests' + test_module_name = test_dir_module+'.'+test_module_name + + if not sys.modules.has_key(test_dir_module): + sys.modules[test_dir_module] = imp.new_module(test_dir_module) + + old_sys_path = sys.path[:] + try: + f = open(test_file,'r') + test_module = imp.load_module(test_module_name, f, + test_file, ('.py', 'r', 1)) + f.close() + except: + sys.path[:] = old_sys_path + self.warn('FAILURE importing tests for %s' % (mstr(module))) + output_exception(sys.stderr) + return [] + sys.path[:] = old_sys_path + + self.test_files.append(test_file) + + return self._get_suite_list(test_module, level, module.__name__) + + def _get_suite_list(self, test_module, level, module_name='__main__', + verbosity=1): + suite_list = [] + if hasattr(test_module, 'test_suite'): + suite_list.extend(test_module.test_suite(level)._tests) + for name in dir(test_module): + obj = getattr(test_module, name) + if type(obj) is not type(unittest.TestCase) \ + or not issubclass(obj, unittest.TestCase) \ + or not self.check_testcase_name(obj.__name__): + continue + for mthname in self._get_method_names(obj,level): + suite = obj(mthname) + if getattr(suite,'isrunnable',lambda mthname:1)(mthname): + suite_list.append(suite) + if verbosity>=0: + self.info(' Found %s tests for %s' % (len(suite_list), module_name)) + return suite_list + + def _test_suite_from_modules(self, this_package, level, verbosity): + package_name = this_package.__name__ + modules = [] + for name, module in sys.modules.items(): + if not name.startswith(package_name) or module is None: + continue + if not hasattr(module,'__file__'): + continue + if os.path.basename(os.path.dirname(module.__file__))=='tests': + continue + modules.append((name, module)) + + modules.sort() + modules = [m[1] for m in modules] + + self.test_files = [] + suites = [] + for module in modules: + suites.extend(self._get_module_tests(module, abs(level), verbosity)) + + suites.extend(self._get_suite_list(sys.modules[package_name], + abs(level), verbosity=verbosity)) + return unittest.TestSuite(suites) + + def _test_suite_from_all_tests(self, this_package, level, verbosity): + importall(this_package) + package_name = this_package.__name__ + + # Find all tests/ directories under the package + test_dirs_names = {} + for name, module in sys.modules.items(): + if not name.startswith(package_name) or module is None: + continue + if not hasattr(module, '__file__'): + continue + d = os.path.dirname(module.__file__) + if os.path.basename(d)=='tests': + continue + d = os.path.join(d, 'tests') + if not os.path.isdir(d): + continue + if test_dirs_names.has_key(d): continue + test_dir_module = '.'.join(name.split('.')[:-1]+['tests']) + test_dirs_names[d] = test_dir_module + + test_dirs = test_dirs_names.keys() + test_dirs.sort() + + # For each file in each tests/ directory with a test case in it, + # import the file, and add the test cases to our list + suite_list = [] + testcase_match = re.compile(r'\s*class\s+\w+\s*\(.*TestCase').match + for test_dir in test_dirs: + test_dir_module = test_dirs_names[test_dir] + + if not sys.modules.has_key(test_dir_module): + sys.modules[test_dir_module] = imp.new_module(test_dir_module) + + for fn in os.listdir(test_dir): + base, ext = os.path.splitext(fn) + if ext != '.py': + continue + f = os.path.join(test_dir, fn) + + # check that file contains TestCase class definitions: + fid = open(f, 'r') + skip = True + for line in fid: + if testcase_match(line): + skip = False + break + fid.close() + if skip: + continue + + # import the test file + n = test_dir_module + '.' + base + # in case test files import local modules + sys.path.insert(0, test_dir) + fo = None + try: + try: + fo = open(f) + test_module = imp.load_module(n, fo, f, + ('.py', 'U', 1)) + except Exception, msg: + print 'Failed importing %s: %s' % (f,msg) + continue + finally: + if fo: + fo.close() + del sys.path[0] + + suites = self._get_suite_list(test_module, level, + module_name=n, + verbosity=verbosity) + suite_list.extend(suites) + + all_tests = unittest.TestSuite(suite_list) + return all_tests + + def test(self, level=1, verbosity=1, all=False): + """Run Numpy module test suite with level and verbosity. + + level: + None --- do nothing, return None + < 0 --- scan for tests of level=abs(level), + don't run them, return TestSuite-list + > 0 --- scan for tests of level, run them, + return TestRunner + > 10 --- run all tests (same as specifying all=True). + (backward compatibility). + + verbosity: + >= 0 --- show information messages + > 1 --- show warnings on missing tests + + all: + True --- run all test files (like self.testall()) + False (default) --- only run test files associated with a module + + It is assumed (when all=False) that package tests suite follows + the following convention: for each package module, there exists + file <packagepath>/tests/test_<modulename>.py that defines + TestCase classes (with names having prefix 'test_') with methods + (with names having prefixes 'check_' or 'bench_'); each of these + methods are called when running unit tests. + """ + if level is None: # Do nothing. + return + + if isinstance(self.package, str): + exec 'import %s as this_package' % (self.package) + else: + this_package = self.package + + if all: + all_tests = self._test_suite_from_all_tests(this_package, + level, verbosity) + else: + all_tests = self._test_suite_from_modules(this_package, + level, verbosity) + + if level < 0: + return all_tests + + runner = unittest.TextTestRunner(verbosity=verbosity) + # Use the builtin displayhook. If the tests are being run + # under IPython (for instance), any doctest test suites will + # fail otherwise. + old_displayhook = sys.displayhook + sys.displayhook = sys.__displayhook__ + try: + runner.run(all_tests) + finally: + sys.displayhook = old_displayhook + return runner + + def testall(self, level=1,verbosity=1): + """ Run Numpy module test suite with level and verbosity. + + level: + None --- do nothing, return None + < 0 --- scan for tests of level=abs(level), + don't run them, return TestSuite-list + > 0 --- scan for tests of level, run them, + return TestRunner + + verbosity: + >= 0 --- show information messages + > 1 --- show warnings on missing tests + + Different from .test(..) method, this method looks for + TestCase classes from all files in <packagedir>/tests/ + directory and no assumptions are made for naming the + TestCase classes or their methods. + """ + return self.test(level=level, verbosity=verbosity, all=True) + + def run(self): + """ Run Numpy module test suite with level and verbosity + taken from sys.argv. Requires optparse module. + """ + try: + from optparse import OptionParser + except ImportError: + self.warn('Failed to import optparse module, ignoring.') + return self.test() + usage = r'usage: %prog [-v <verbosity>] [-l <level>]' + parser = OptionParser(usage) + parser.add_option("-v", "--verbosity", + action="store", + dest="verbosity", + default=1, + type='int') + parser.add_option("-l", "--level", + action="store", + dest="level", + default=1, + type='int') + (options, args) = parser.parse_args() + self.test(options.level,options.verbosity) + return + + def warn(self, message): + from numpy.distutils.misc_util import yellow_text + print>>sys.stderr,yellow_text('Warning: %s' % (message)) + sys.stderr.flush() + def info(self, message): + print>>sys.stdout, message + sys.stdout.flush() + +class ScipyTest(NumpyTest): + def __init__(self, package=None): + warnings.warn("ScipyTest is now called NumpyTest; please update your code", + DeprecationWarning, stacklevel=2) + NumpyTest.__init__(self, package) + + +def importall(package): + """ + Try recursively to import all subpackages under package. + """ + if isinstance(package,str): + package = __import__(package) + + package_name = package.__name__ + package_dir = os.path.dirname(package.__file__) + for subpackage_name in os.listdir(package_dir): + subdir = os.path.join(package_dir, subpackage_name) + if not os.path.isdir(subdir): + continue + if not os.path.isfile(os.path.join(subdir,'__init__.py')): + continue + name = package_name+'.'+subpackage_name + try: + exec 'import %s as m' % (name) + except Exception, msg: + print 'Failed importing %s: %s' %(name, msg) + continue + importall(m) + return diff --git a/numpy/testing/setup.py b/numpy/testing/setup.py new file mode 100755 index 000000000..ad248d27f --- /dev/null +++ b/numpy/testing/setup.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python + +def configuration(parent_package='',top_path=None): + from numpy.distutils.misc_util import Configuration + config = Configuration('testing',parent_package,top_path) + return config + +if __name__ == '__main__': + from numpy.distutils.core import setup + setup(maintainer = "NumPy Developers", + maintainer_email = "numpy-dev@numpy.org", + description = "NumPy test module", + url = "http://www.numpy.org", + license = "NumPy License (BSD Style)", + configuration = configuration, + ) diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py new file mode 100644 index 000000000..8e01afb56 --- /dev/null +++ b/numpy/testing/utils.py @@ -0,0 +1,238 @@ +""" +Utility function to facilitate testing. +""" + +import os +import sys +import operator + +__all__ = ['assert_equal', 'assert_almost_equal','assert_approx_equal', + 'assert_array_equal', 'assert_array_less', + 'assert_array_almost_equal', 'jiffies', 'memusage', 'rand', + 'runstring'] + +def rand(*args): + """Returns an array of random numbers with the given shape. + + This only uses the standard library, so it is useful for testing purposes. + """ + import random + from numpy.core import zeros, float64 + results = zeros(args, float64) + f = results.flat + for i in range(len(f)): + f[i] = random.random() + return results + +if sys.platform[:5]=='linux': + def jiffies(_proc_pid_stat = '/proc/%s/stat'%(os.getpid()), + _load_time=[]): + """ Return number of jiffies (1/100ths of a second) that this + process has been scheduled in user mode. See man 5 proc. """ + import time + if not _load_time: + _load_time.append(time.time()) + try: + f=open(_proc_pid_stat,'r') + l = f.readline().split(' ') + f.close() + return int(l[13]) + except: + return int(100*(time.time()-_load_time[0])) + + def memusage(_proc_pid_stat = '/proc/%s/stat'%(os.getpid())): + """ Return virtual memory size in bytes of the running python. + """ + try: + f=open(_proc_pid_stat,'r') + l = f.readline().split(' ') + f.close() + return int(l[22]) + except: + return +else: + # os.getpid is not in all platforms available. + # Using time is safe but inaccurate, especially when process + # was suspended or sleeping. + def jiffies(_load_time=[]): + """ Return number of jiffies (1/100ths of a second) that this + process has been scheduled in user mode. [Emulation with time.time]. """ + import time + if not _load_time: + _load_time.append(time.time()) + return int(100*(time.time()-_load_time[0])) + def memusage(): + """ Return memory usage of running python. [Not implemented]""" + raise NotImplementedError + +if os.name=='nt' and sys.version[:3] > '2.3': + # Code "stolen" from enthought/debug/memusage.py + def GetPerformanceAttributes(object, counter, instance = None, + inum=-1, format = None, machine=None): + # NOTE: Many counters require 2 samples to give accurate results, + # including "% Processor Time" (as by definition, at any instant, a + # thread's CPU usage is either 0 or 100). To read counters like this, + # you should copy this function, but keep the counter open, and call + # CollectQueryData() each time you need to know. + # See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp + # My older explanation for this was that the "AddCounter" process forced + # the CPU to 100%, but the above makes more sense :) + import win32pdh + if format is None: format = win32pdh.PDH_FMT_LONG + path = win32pdh.MakeCounterPath( (machine,object,instance, None, inum,counter) ) + hq = win32pdh.OpenQuery() + try: + hc = win32pdh.AddCounter(hq, path) + try: + win32pdh.CollectQueryData(hq) + type, val = win32pdh.GetFormattedCounterValue(hc, format) + return val + finally: + win32pdh.RemoveCounter(hc) + finally: + win32pdh.CloseQuery(hq) + + def memusage(processName="python", instance=0): + # from win32pdhutil, part of the win32all package + import win32pdh + return GetPerformanceAttributes("Process", "Virtual Bytes", + processName, instance, + win32pdh.PDH_FMT_LONG, None) + +def build_err_msg(arrays, err_msg, header='Items are not equal:', + verbose=True, + names=('ACTUAL', 'DESIRED')): + msg = ['\n' + header] + if err_msg: + if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header): + msg = [msg[0] + ' ' + err_msg] + else: + msg.append(err_msg) + if verbose: + for i, a in enumerate(arrays): + try: + r = repr(a) + except: + r = '[repr failed]' + if r.count('\n') > 3: + r = '\n'.join(r.splitlines()[:3]) + r += '...' + msg.append(' %s: %s' % (names[i], r)) + return '\n'.join(msg) + +def assert_equal(actual,desired,err_msg='',verbose=True): + """ Raise an assertion if two items are not + equal. I think this should be part of unittest.py + """ + if isinstance(desired, dict): + assert isinstance(actual, dict), repr(type(actual)) + assert_equal(len(actual),len(desired),err_msg,verbose) + for k,i in desired.items(): + assert actual.has_key(k), repr(k) + assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k,err_msg), verbose) + return + if isinstance(desired, (list,tuple)) and isinstance(actual, (list,tuple)): + assert_equal(len(actual),len(desired),err_msg,verbose) + for k in range(len(desired)): + assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg), verbose) + return + from numpy.core import ndarray + if isinstance(actual, ndarray) or isinstance(desired, ndarray): + return assert_array_equal(actual, desired, err_msg) + msg = build_err_msg([actual, desired], err_msg, verbose=verbose) + assert desired == actual, msg + +def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): + """ Raise an assertion if two items are not equal. + + I think this should be part of unittest.py + + The test is equivalent to abs(desired-actual) < 0.5 * 10**(-decimal) + """ + from numpy.core import ndarray + if isinstance(actual, ndarray) or isinstance(desired, ndarray): + return assert_array_almost_equal(actual, desired, decimal, err_msg) + msg = build_err_msg([actual, desired], err_msg, verbose=verbose) + assert round(abs(desired - actual),decimal) == 0, msg + + +def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): + """ Raise an assertion if two items are not + equal. I think this should be part of unittest.py + Approximately equal is defined as the number of significant digits + correct + """ + import math + actual, desired = map(float, (actual, desired)) + if desired==actual: + return + # Normalized the numbers to be in range (-10.0,10.0) + scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual)))))) + try: + sc_desired = desired/scale + except ZeroDivisionError: + sc_desired = 0.0 + try: + sc_actual = actual/scale + except ZeroDivisionError: + sc_actual = 0.0 + msg = build_err_msg([actual, desired], err_msg, + header='Items are not equal to %d significant digits:' % + significant, + verbose=verbose) + assert math.fabs(sc_desired - sc_actual) < pow(10.,-(significant-1)), msg + +def assert_array_compare(comparison, x, y, err_msg='', verbose=True, + header=''): + from numpy.core import asarray + x = asarray(x) + y = asarray(y) + try: + cond = (x.shape==() or y.shape==()) or x.shape == y.shape + if not cond: + msg = build_err_msg([x, y], + err_msg + + '\n(shapes %s, %s mismatch)' % (x.shape, + y.shape), + verbose=verbose, header=header, + names=('x', 'y')) + assert cond, msg + val = comparison(x,y) + if isinstance(val, bool): + cond = val + reduced = [0] + else: + reduced = val.ravel() + cond = reduced.all() + reduced = reduced.tolist() + if not cond: + match = 100-100.0*reduced.count(1)/len(reduced) + msg = build_err_msg([x, y], + err_msg + + '\n(mismatch %s%%)' % (match,), + verbose=verbose, header=header, + names=('x', 'y')) + assert cond, msg + except ValueError: + msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header, + names=('x', 'y')) + raise ValueError(msg) + +def assert_array_equal(x, y, err_msg='', verbose=True): + assert_array_compare(operator.__eq__, x, y, err_msg=err_msg, + verbose=verbose, header='Arrays are not equal') + +def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): + from numpy.core import around + def compare(x, y): + return around(abs(x-y),decimal) <= 10.0**(-decimal) + assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, + header='Arrays are not almost equal') + +def assert_array_less(x, y, err_msg='', verbose=True): + assert_array_compare(operator.__lt__, x, y, err_msg=err_msg, + verbose=verbose, + header='Arrays are not less-ordered') + +def runstring(astr, dict): + exec astr in dict |