diff options
Diffstat (limited to 'numpy/testing/numpytest.py')
-rw-r--r-- | numpy/testing/numpytest.py | 78 |
1 files changed, 54 insertions, 24 deletions
diff --git a/numpy/testing/numpytest.py b/numpy/testing/numpytest.py index 127a5d3e2..5d6df6f51 100644 --- a/numpy/testing/numpytest.py +++ b/numpy/testing/numpytest.py @@ -3,6 +3,7 @@ import os import re import sys import imp +import glob import types import unittest import traceback @@ -50,7 +51,7 @@ def set_package_path(level=1): if not os.path.isdir(d1): d1 = os.path.dirname(d) if DEBUG: - print 'Inserting %r to sys.path' % (d1) + print 'Inserting %r to sys.path for test_file %r' % (d1, testfile) sys.path.insert(0,d1) return @@ -156,7 +157,8 @@ class NumpyTestCase (unittest.TestCase): result.stream = save_stream def warn(self, message): - print>>sys.stderr,'Warning: %s' % (message) + from numpy.distutils.misc_util import yellow_text + print>>sys.stderr,yellow_text('Warning: %s' % (message)) sys.stderr.flush() def info(self, message): print>>sys.stdout, message @@ -224,10 +226,30 @@ class NumpyTest: """ _check_testcase_name = re.compile(r'test.*').match def check_testcase_name(self, name): - return self._check_testcase_name(name) is not None + """ Return True if name matches TestCase class. + """ + return not not self._check_testcase_name(name) - def get_testfile(self, short_module_name): - return 'test_' + short_module_name + '.py' + testfile_patterns = ['test_%(modulename)s.py'] + def get_testfile(self, module, verbosity = 0): + """ Return path to module test file. + """ + mstr = self._module_str + short_module_name = self._get_short_module_name(module) + d = os.path.split(module.__file__)[0] + test_dir = os.path.join(d,'tests') + local_test_dir = os.path.join(os.getcwd(),'tests') + if os.path.basename(os.path.dirname(local_test_dir)) \ + == os.path.basename(os.path.dirname(test_dir)): + test_dir = local_test_dir + for pat in self.testfile_patterns: + fn = os.path.join(test_dir, pat % {'modulename':short_module_name}) + if os.path.isfile(fn): + return fn + if verbosity>1: + self.warn('No test file found in %s for module %s' \ + % (test_dir, mstr(module))) + return def __init__(self, package=None): if package is None: @@ -276,29 +298,25 @@ class NumpyTest: names.append(n) return names - def _get_module_tests(self,module,level,verbosity): - mstr = self._module_str + def _get_short_module_name(self, module): d,f = os.path.split(module.__file__) - short_module_name = os.path.splitext(os.path.basename(f))[0] if short_module_name=='__init__': short_module_name = module.__name__.split('.')[-1] - short_module_name = self._rename_map.get(short_module_name,short_module_name) + return short_module_name + + def _get_module_tests(self,module,level,verbosity): + mstr = self._module_str + + short_module_name = self._get_short_module_name(module) if short_module_name is None: return [] - full_module_name = module.__name__+'.'+short_module_name - test_dir = os.path.join(d,'tests') - fn = self.get_testfile(short_module_name) - test_file = os.path.join(test_dir,fn) + test_file = self.get_testfile(module, verbosity) - local_test_dir = os.path.join(os.getcwd(),'tests') - local_test_file = os.path.join(local_test_dir, fn) - if os.path.basename(os.path.dirname(local_test_dir)) \ - == os.path.basename(os.path.dirname(test_dir)) \ - and os.path.isfile(local_test_file): - test_file = local_test_file + if test_file is None: + return [] if not os.path.isfile(test_file): if short_module_name[:5]=='info_' \ @@ -317,16 +335,27 @@ class NumpyTest: if test_file in self.test_files: return [] - + + parent_module_name = '.'.join(module.__name__.split('.')[:-1]) + test_module_name,ext = os.path.splitext(os.path.basename(test_file)) + test_dir_module = parent_module_name+'.tests' + test_module_name = test_dir_module+'.'+test_module_name + + if not sys.modules.has_key(test_dir_module): + sys.modules[test_dir_module] = imp.new_module(test_dir_module) + + old_sys_path = sys.path[:] try: f = open(test_file,'r') - test_module = imp.load_module(full_module_name, - f, test_file, ('.py', 'r', 1)) + test_module = imp.load_module(test_module_name, f, + test_file, ('.py', 'r', 1)) f.close() except: - self.warn(' !! FAILURE importing tests for %s' % mstr(module)) + sys.path[:] = old_sys_path + self.warn('FAILURE importing tests for %s' % (mstr(module))) output_exception(sys.stderr) return [] + sys.path[:] = old_sys_path self.test_files.append(test_file) @@ -420,7 +449,8 @@ class NumpyTest: return def warn(self, message): - print>>sys.stderr,'Warning: %s' % (message) + from numpy.distutils.misc_util import yellow_text + print>>sys.stderr,yellow_text('Warning: %s' % (message)) sys.stderr.flush() def info(self, message): print>>sys.stdout, message |