summaryrefslogtreecommitdiff
path: root/numpy/testing/numpytest.py
diff options
context:
space:
mode:
authorcookedm <cookedm@localhost>2006-03-07 22:02:23 +0000
committercookedm <cookedm@localhost>2006-03-07 22:02:23 +0000
commitc9d2cdc913171d079eabb6b71405d7101041356b (patch)
treea66f5fe7b0d197b434e83fe3e17dbd87076f3839 /numpy/testing/numpytest.py
parente3a1d502e5d08a755dd1d91eb74341c7617adbdd (diff)
parent5bb7342c6c2fa9757edc28df0dbbc8d433ac50d8 (diff)
downloadnumpy-c9d2cdc913171d079eabb6b71405d7101041356b.tar.gz
Merge trunk (r2142:2204) to power_optimization branch
Diffstat (limited to 'numpy/testing/numpytest.py')
-rw-r--r--numpy/testing/numpytest.py111
1 files changed, 79 insertions, 32 deletions
diff --git a/numpy/testing/numpytest.py b/numpy/testing/numpytest.py
index 127a5d3e2..4b43c428e 100644
--- a/numpy/testing/numpytest.py
+++ b/numpy/testing/numpytest.py
@@ -3,6 +3,7 @@ import os
import re
import sys
import imp
+import glob
import types
import unittest
import traceback
@@ -50,7 +51,7 @@ def set_package_path(level=1):
if not os.path.isdir(d1):
d1 = os.path.dirname(d)
if DEBUG:
- print 'Inserting %r to sys.path' % (d1)
+ print 'Inserting %r to sys.path for test_file %r' % (d1, testfile)
sys.path.insert(0,d1)
return
@@ -156,7 +157,8 @@ class NumpyTestCase (unittest.TestCase):
result.stream = save_stream
def warn(self, message):
- print>>sys.stderr,'Warning: %s' % (message)
+ from numpy.distutils.misc_util import yellow_text
+ print>>sys.stderr,yellow_text('Warning: %s' % (message))
sys.stderr.flush()
def info(self, message):
print>>sys.stdout, message
@@ -224,10 +226,30 @@ class NumpyTest:
"""
_check_testcase_name = re.compile(r'test.*').match
def check_testcase_name(self, name):
- return self._check_testcase_name(name) is not None
+ """ Return True if name matches TestCase class.
+ """
+ return not not self._check_testcase_name(name)
- def get_testfile(self, short_module_name):
- return 'test_' + short_module_name + '.py'
+ testfile_patterns = ['test_%(modulename)s.py']
+ def get_testfile(self, module, verbosity = 0):
+ """ Return path to module test file.
+ """
+ mstr = self._module_str
+ short_module_name = self._get_short_module_name(module)
+ d = os.path.split(module.__file__)[0]
+ test_dir = os.path.join(d,'tests')
+ local_test_dir = os.path.join(os.getcwd(),'tests')
+ if os.path.basename(os.path.dirname(local_test_dir)) \
+ == os.path.basename(os.path.dirname(test_dir)):
+ test_dir = local_test_dir
+ for pat in self.testfile_patterns:
+ fn = os.path.join(test_dir, pat % {'modulename':short_module_name})
+ if os.path.isfile(fn):
+ return fn
+ if verbosity>1:
+ self.warn('No test file found in %s for module %s' \
+ % (test_dir, mstr(module)))
+ return
def __init__(self, package=None):
if package is None:
@@ -276,29 +298,25 @@ class NumpyTest:
names.append(n)
return names
- def _get_module_tests(self,module,level,verbosity):
- mstr = self._module_str
+ def _get_short_module_name(self, module):
d,f = os.path.split(module.__file__)
-
short_module_name = os.path.splitext(os.path.basename(f))[0]
if short_module_name=='__init__':
short_module_name = module.__name__.split('.')[-1]
-
short_module_name = self._rename_map.get(short_module_name,short_module_name)
+ return short_module_name
+
+ def _get_module_tests(self,module,level,verbosity):
+ mstr = self._module_str
+
+ short_module_name = self._get_short_module_name(module)
if short_module_name is None:
return []
- full_module_name = module.__name__+'.'+short_module_name
- test_dir = os.path.join(d,'tests')
- fn = self.get_testfile(short_module_name)
- test_file = os.path.join(test_dir,fn)
+ test_file = self.get_testfile(module, verbosity)
- local_test_dir = os.path.join(os.getcwd(),'tests')
- local_test_file = os.path.join(local_test_dir, fn)
- if os.path.basename(os.path.dirname(local_test_dir)) \
- == os.path.basename(os.path.dirname(test_dir)) \
- and os.path.isfile(local_test_file):
- test_file = local_test_file
+ if test_file is None:
+ return []
if not os.path.isfile(test_file):
if short_module_name[:5]=='info_' \
@@ -317,22 +335,34 @@ class NumpyTest:
if test_file in self.test_files:
return []
-
+
+ parent_module_name = '.'.join(module.__name__.split('.')[:-1])
+ test_module_name,ext = os.path.splitext(os.path.basename(test_file))
+ test_dir_module = parent_module_name+'.tests'
+ test_module_name = test_dir_module+'.'+test_module_name
+
+ if not sys.modules.has_key(test_dir_module):
+ sys.modules[test_dir_module] = imp.new_module(test_dir_module)
+
+ old_sys_path = sys.path[:]
try:
f = open(test_file,'r')
- test_module = imp.load_module(full_module_name,
- f, test_file, ('.py', 'r', 1))
+ test_module = imp.load_module(test_module_name, f,
+ test_file, ('.py', 'r', 1))
f.close()
except:
- self.warn(' !! FAILURE importing tests for %s' % mstr(module))
+ sys.path[:] = old_sys_path
+ self.warn('FAILURE importing tests for %s' % (mstr(module)))
output_exception(sys.stderr)
return []
+ sys.path[:] = old_sys_path
self.test_files.append(test_file)
return self._get_suite_list(test_module, level, module.__name__)
- def _get_suite_list(self, test_module, level, module_name='__main__'):
+ def _get_suite_list(self, test_module, level, module_name='__main__',
+ verbosity=1):
mstr = self._module_str
suite_list = []
if hasattr(test_module,'test_suite'):
@@ -347,13 +377,28 @@ class NumpyTest:
suite = obj(mthname)
if getattr(suite,'isrunnable',lambda mthname:1)(mthname):
suite_list.append(suite)
- self.info(' Found %s tests for %s' % (len(suite_list),module_name))
+ if verbosity>=0:
+ self.info(' Found %s tests for %s' % (len(suite_list),module_name))
return suite_list
def test(self,level=1,verbosity=1):
""" Run Numpy module test suite with level and verbosity.
+
+ level:
+ None --- do nothing, return None
+ < 0 --- scan for tests of level=abs(level),
+ don't run them, return TestSuite-list
+ > 0 --- scan for tests of level, run them,
+ return TestRunner
+
+ verbosity:
+ >= 0 --- show information messages
+ > 1 --- show warnings on missing tests
"""
- if type(self.package) is type(''):
+ if level is None: # Do nothing.
+ return
+
+ if isinstance(self.package, str):
exec 'import %s as this_package' % (self.package)
else:
this_package = self.package
@@ -374,14 +419,15 @@ class NumpyTest:
self.test_files = []
suites = []
for module in modules:
- suites.extend(self._get_module_tests(module, level, verbosity))
+ suites.extend(self._get_module_tests(module, abs(level), verbosity))
- suites.extend(self._get_suite_list(sys.modules[package_name], level))
+ suites.extend(self._get_suite_list(sys.modules[package_name],
+ abs(level), verbosity=verbosity))
all_tests = unittest.TestSuite(suites)
- #if hasattr(sys,'getobjects'):
- # runner = SciPyTextTestRunner(verbosity=verbosity)
- #else:
+ if level<0:
+ return all_tests
+
runner = unittest.TextTestRunner(verbosity=verbosity)
# Use the builtin displayhook. If the tests are being run
# under IPython (for instance), any doctest test suites will
@@ -420,7 +466,8 @@ class NumpyTest:
return
def warn(self, message):
- print>>sys.stderr,'Warning: %s' % (message)
+ from numpy.distutils.misc_util import yellow_text
+ print>>sys.stderr,yellow_text('Warning: %s' % (message))
sys.stderr.flush()
def info(self, message):
print>>sys.stdout, message