summaryrefslogtreecommitdiff
path: root/numpy/testing/numpytest.py
diff options
context:
space:
mode:
authorPearu Peterson <pearu.peterson@gmail.com>2006-02-23 12:37:17 +0000
committerPearu Peterson <pearu.peterson@gmail.com>2006-02-23 12:37:17 +0000
commit7adb2e2d0abc93ca9572ff451ff3ca1acdfce4cb (patch)
treea7700e5fea1364cea2564a02f4296032b8349913 /numpy/testing/numpytest.py
parent97342eff4ae87caeb83f4e419270b0da88a161e1 (diff)
downloadnumpy-7adb2e2d0abc93ca9572ff451ff3ca1acdfce4cb.tar.gz
Fixed a bug in importing zzz/tests/test_zzz.py when zzz/zzz.py exists. Added check_testcase_name and testfile_patterns to NumpyTest so that different test suite conventions can be used.
Diffstat (limited to 'numpy/testing/numpytest.py')
-rw-r--r--numpy/testing/numpytest.py78
1 files changed, 54 insertions, 24 deletions
diff --git a/numpy/testing/numpytest.py b/numpy/testing/numpytest.py
index 127a5d3e2..5d6df6f51 100644
--- a/numpy/testing/numpytest.py
+++ b/numpy/testing/numpytest.py
@@ -3,6 +3,7 @@ import os
import re
import sys
import imp
+import glob
import types
import unittest
import traceback
@@ -50,7 +51,7 @@ def set_package_path(level=1):
if not os.path.isdir(d1):
d1 = os.path.dirname(d)
if DEBUG:
- print 'Inserting %r to sys.path' % (d1)
+ print 'Inserting %r to sys.path for test_file %r' % (d1, testfile)
sys.path.insert(0,d1)
return
@@ -156,7 +157,8 @@ class NumpyTestCase (unittest.TestCase):
result.stream = save_stream
def warn(self, message):
- print>>sys.stderr,'Warning: %s' % (message)
+ from numpy.distutils.misc_util import yellow_text
+ print>>sys.stderr,yellow_text('Warning: %s' % (message))
sys.stderr.flush()
def info(self, message):
print>>sys.stdout, message
@@ -224,10 +226,30 @@ class NumpyTest:
"""
_check_testcase_name = re.compile(r'test.*').match
def check_testcase_name(self, name):
- return self._check_testcase_name(name) is not None
+ """ Return True if name matches TestCase class.
+ """
+ return not not self._check_testcase_name(name)
- def get_testfile(self, short_module_name):
- return 'test_' + short_module_name + '.py'
+ testfile_patterns = ['test_%(modulename)s.py']
+ def get_testfile(self, module, verbosity = 0):
+ """ Return path to module test file.
+ """
+ mstr = self._module_str
+ short_module_name = self._get_short_module_name(module)
+ d = os.path.split(module.__file__)[0]
+ test_dir = os.path.join(d,'tests')
+ local_test_dir = os.path.join(os.getcwd(),'tests')
+ if os.path.basename(os.path.dirname(local_test_dir)) \
+ == os.path.basename(os.path.dirname(test_dir)):
+ test_dir = local_test_dir
+ for pat in self.testfile_patterns:
+ fn = os.path.join(test_dir, pat % {'modulename':short_module_name})
+ if os.path.isfile(fn):
+ return fn
+ if verbosity>1:
+ self.warn('No test file found in %s for module %s' \
+ % (test_dir, mstr(module)))
+ return
def __init__(self, package=None):
if package is None:
@@ -276,29 +298,25 @@ class NumpyTest:
names.append(n)
return names
- def _get_module_tests(self,module,level,verbosity):
- mstr = self._module_str
+ def _get_short_module_name(self, module):
d,f = os.path.split(module.__file__)
-
short_module_name = os.path.splitext(os.path.basename(f))[0]
if short_module_name=='__init__':
short_module_name = module.__name__.split('.')[-1]
-
short_module_name = self._rename_map.get(short_module_name,short_module_name)
+ return short_module_name
+
+ def _get_module_tests(self,module,level,verbosity):
+ mstr = self._module_str
+
+ short_module_name = self._get_short_module_name(module)
if short_module_name is None:
return []
- full_module_name = module.__name__+'.'+short_module_name
- test_dir = os.path.join(d,'tests')
- fn = self.get_testfile(short_module_name)
- test_file = os.path.join(test_dir,fn)
+ test_file = self.get_testfile(module, verbosity)
- local_test_dir = os.path.join(os.getcwd(),'tests')
- local_test_file = os.path.join(local_test_dir, fn)
- if os.path.basename(os.path.dirname(local_test_dir)) \
- == os.path.basename(os.path.dirname(test_dir)) \
- and os.path.isfile(local_test_file):
- test_file = local_test_file
+ if test_file is None:
+ return []
if not os.path.isfile(test_file):
if short_module_name[:5]=='info_' \
@@ -317,16 +335,27 @@ class NumpyTest:
if test_file in self.test_files:
return []
-
+
+ parent_module_name = '.'.join(module.__name__.split('.')[:-1])
+ test_module_name,ext = os.path.splitext(os.path.basename(test_file))
+ test_dir_module = parent_module_name+'.tests'
+ test_module_name = test_dir_module+'.'+test_module_name
+
+ if not sys.modules.has_key(test_dir_module):
+ sys.modules[test_dir_module] = imp.new_module(test_dir_module)
+
+ old_sys_path = sys.path[:]
try:
f = open(test_file,'r')
- test_module = imp.load_module(full_module_name,
- f, test_file, ('.py', 'r', 1))
+ test_module = imp.load_module(test_module_name, f,
+ test_file, ('.py', 'r', 1))
f.close()
except:
- self.warn(' !! FAILURE importing tests for %s' % mstr(module))
+ sys.path[:] = old_sys_path
+ self.warn('FAILURE importing tests for %s' % (mstr(module)))
output_exception(sys.stderr)
return []
+ sys.path[:] = old_sys_path
self.test_files.append(test_file)
@@ -420,7 +449,8 @@ class NumpyTest:
return
def warn(self, message):
- print>>sys.stderr,'Warning: %s' % (message)
+ from numpy.distutils.misc_util import yellow_text
+ print>>sys.stderr,yellow_text('Warning: %s' % (message))
sys.stderr.flush()
def info(self, message):
print>>sys.stdout, message