summaryrefslogtreecommitdiff
path: root/scipy/weave/tests/test_c_spec.py
diff options
context:
space:
mode:
Diffstat (limited to 'scipy/weave/tests/test_c_spec.py')
-rw-r--r--scipy/weave/tests/test_c_spec.py678
1 files changed, 678 insertions, 0 deletions
diff --git a/scipy/weave/tests/test_c_spec.py b/scipy/weave/tests/test_c_spec.py
new file mode 100644
index 000000000..ce63595e2
--- /dev/null
+++ b/scipy/weave/tests/test_c_spec.py
@@ -0,0 +1,678 @@
+import unittest
+import time
+import os,sys
+
+# Note: test_dir is global to this file.
+# It is made by setup_test_location()
+
+#globals
+global test_dir
+test_dir = ''
+
+from scipy_test.testing import *
+set_package_path()
+from weave import inline_tools,ext_tools,c_spec
+from weave.build_tools import msvc_exists, gcc_exists
+from weave.catalog import unique_file
+restore_path()
+
+
+def unique_mod(d,file_name):
+ f = os.path.basename(unique_file(d,file_name))
+ m = os.path.splitext(f)[0]
+ return m
+
+def remove_whitespace(in_str):
+ import string
+ out = string.replace(in_str," ","")
+ out = string.replace(out,"\t","")
+ out = string.replace(out,"\n","")
+ return out
+
+def print_assert_equal(test_string,actual,desired):
+ """this should probably be in scipy_test.testing
+ """
+ import pprint
+ try:
+ assert(actual == desired)
+ except AssertionError:
+ import cStringIO
+ msg = cStringIO.StringIO()
+ msg.write(test_string)
+ msg.write(' failed\nACTUAL: \n')
+ pprint.pprint(actual,msg)
+ msg.write('DESIRED: \n')
+ pprint.pprint(desired,msg)
+ raise AssertionError, msg.getvalue()
+
+#----------------------------------------------------------------------------
+# Scalar conversion test classes
+# int, float, complex
+#----------------------------------------------------------------------------
+class test_int_converter(unittest.TestCase):
+ compiler = ''
+ def check_type_match_string(self,level=5):
+ s = c_spec.int_converter()
+ assert( not s.type_match('string') )
+ def check_type_match_int(self,level=5):
+ s = c_spec.int_converter()
+ assert(s.type_match(5))
+ def check_type_match_float(self,level=5):
+ s = c_spec.int_converter()
+ assert(not s.type_match(5.))
+ def check_type_match_complex(self,level=5):
+ s = c_spec.int_converter()
+ assert(not s.type_match(5.+1j))
+ def check_var_in(self,level=5):
+ mod_name = 'int_var_in' + self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = 1
+ code = "a=2;"
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import test'
+ b=1
+ test(b)
+ try:
+ b = 1.
+ test(b)
+ except TypeError:
+ pass
+ try:
+ b = 'abc'
+ test(b)
+ except TypeError:
+ pass
+
+ def check_int_return(self,level=5):
+ mod_name = sys._getframe().f_code.co_name + self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = 1
+ code = """
+ a=a+2;
+ return_val = PyInt_FromLong(a);
+ """
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import test'
+ b=1
+ c = test(b)
+
+ assert( c == 3)
+
+class test_float_converter(unittest.TestCase):
+ compiler = ''
+ def check_type_match_string(self,level=5):
+ s = c_spec.float_converter()
+ assert( not s.type_match('string') )
+ def check_type_match_int(self,level=5):
+ s = c_spec.float_converter()
+ assert(not s.type_match(5))
+ def check_type_match_float(self,level=5):
+ s = c_spec.float_converter()
+ assert(s.type_match(5.))
+ def check_type_match_complex(self,level=5):
+ s = c_spec.float_converter()
+ assert(not s.type_match(5.+1j))
+ def check_float_var_in(self,level=5):
+ mod_name = sys._getframe().f_code.co_name + self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = 1.
+ code = "a=2.;"
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import test'
+ b=1.
+ test(b)
+ try:
+ b = 1.
+ test(b)
+ except TypeError:
+ pass
+ try:
+ b = 'abc'
+ test(b)
+ except TypeError:
+ pass
+
+
+ def check_float_return(self,level=5):
+ mod_name = sys._getframe().f_code.co_name + self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = 1.
+ code = """
+ a=a+2.;
+ return_val = PyFloat_FromDouble(a);
+ """
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import test'
+ b=1.
+ c = test(b)
+ assert( c == 3.)
+
+class test_complex_converter(unittest.TestCase):
+ compiler = ''
+ def check_type_match_string(self,level=5):
+ s = c_spec.complex_converter()
+ assert( not s.type_match('string') )
+ def check_type_match_int(self,level=5):
+ s = c_spec.complex_converter()
+ assert(not s.type_match(5))
+ def check_type_match_float(self,level=5):
+ s = c_spec.complex_converter()
+ assert(not s.type_match(5.))
+ def check_type_match_complex(self,level=5):
+ s = c_spec.complex_converter()
+ assert(s.type_match(5.+1j))
+ def check_complex_var_in(self,level=5):
+ mod_name = sys._getframe().f_code.co_name + self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = 1.+1j
+ code = "a=std::complex<double>(2.,2.);"
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import test'
+ b=1.+1j
+ test(b)
+ try:
+ b = 1.
+ test(b)
+ except TypeError:
+ pass
+ try:
+ b = 'abc'
+ test(b)
+ except TypeError:
+ pass
+
+ def check_complex_return(self,level=5):
+ mod_name = sys._getframe().f_code.co_name + self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = 1.+1j
+ code = """
+ a= a + std::complex<double>(2.,2.);
+ return_val = PyComplex_FromDoubles(a.real(),a.imag());
+ """
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import test'
+ b=1.+1j
+ c = test(b)
+ assert( c == 3.+3j)
+
+#----------------------------------------------------------------------------
+# File conversion tests
+#----------------------------------------------------------------------------
+
+class test_file_converter(unittest.TestCase):
+ compiler = ''
+ def check_py_to_file(self,level=5):
+ import tempfile
+ file_name = tempfile.mktemp()
+ file = open(file_name,'w')
+ code = """
+ fprintf(file,"hello bob");
+ """
+ inline_tools.inline(code,['file'],compiler=self.compiler,force=1)
+ file.close()
+ file = open(file_name,'r')
+ assert(file.read() == "hello bob")
+ def check_file_to_py(self,level=5):
+ import tempfile
+ file_name = tempfile.mktemp()
+ # not sure I like Py::String as default -- might move to std::sting
+ # or just plain char*
+ code = """
+ char* _file_name = (char*) file_name.c_str();
+ FILE* file = fopen(_file_name,"w");
+ return_val = file_to_py(file,_file_name,"w");
+ """
+ file = inline_tools.inline(code,['file_name'], compiler=self.compiler,
+ force=1)
+ file.write("hello fred")
+ file.close()
+ file = open(file_name,'r')
+ assert(file.read() == "hello fred")
+
+#----------------------------------------------------------------------------
+# Instance conversion tests
+#----------------------------------------------------------------------------
+
+class test_instance_converter(unittest.TestCase):
+ pass
+
+#----------------------------------------------------------------------------
+# Callable object conversion tests
+#----------------------------------------------------------------------------
+
+class test_callable_converter(unittest.TestCase):
+ compiler=''
+ def check_call_function(self,level=5):
+ import string
+ func = string.find
+ search_str = "hello world hello"
+ sub_str = "world"
+ # * Not sure about ref counts on search_str and sub_str.
+ # * Is the Py::String necessary? (it works anyways...)
+ code = """
+ py::tuple args(2);
+ args[0] = search_str;
+ args[1] = sub_str;
+ return_val = func.call(args);
+ """
+ actual = inline_tools.inline(code,['func','search_str','sub_str'],
+ compiler=self.compiler,force=1)
+ desired = func(search_str,sub_str)
+ assert(desired == actual)
+
+class test_sequence_converter(unittest.TestCase):
+ compiler = ''
+ def check_convert_to_dict(self,level=5):
+ d = {}
+ inline_tools.inline("",['d'],compiler=self.compiler,force=1)
+ def check_convert_to_list(self,level=5):
+ l = []
+ inline_tools.inline("",['l'],compiler=self.compiler,force=1)
+ def check_convert_to_string(self,level=5):
+ s = 'hello'
+ inline_tools.inline("",['s'],compiler=self.compiler,force=1)
+ def check_convert_to_tuple(self,level=5):
+ t = ()
+ inline_tools.inline("",['t'],compiler=self.compiler,force=1)
+
+class test_string_converter(unittest.TestCase):
+ compiler = ''
+ def check_type_match_string(self,level=5):
+ s = c_spec.string_converter()
+ assert( s.type_match('string') )
+ def check_type_match_int(self,level=5):
+ s = c_spec.string_converter()
+ assert(not s.type_match(5))
+ def check_type_match_float(self,level=5):
+ s = c_spec.string_converter()
+ assert(not s.type_match(5.))
+ def check_type_match_complex(self,level=5):
+ s = c_spec.string_converter()
+ assert(not s.type_match(5.+1j))
+ def check_var_in(self,level=5):
+ mod_name = 'string_var_in'+self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = 'string'
+ code = 'a=std::string("hello");'
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+
+ exec 'from ' + mod_name + ' import test'
+ b='bub'
+ test(b)
+ try:
+ b = 1.
+ test(b)
+ except TypeError:
+ pass
+ try:
+ b = 1
+ test(b)
+ except TypeError:
+ pass
+
+ def check_return(self,level=5):
+ mod_name = 'string_return'+self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = 'string'
+ code = """
+ a= std::string("hello");
+ return_val = PyString_FromString(a.c_str());
+ """
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import test'
+ b='bub'
+ c = test(b)
+ assert( c == 'hello')
+
+class test_list_converter(unittest.TestCase):
+ compiler = ''
+ def check_type_match_bad(self,level=5):
+ s = c_spec.list_converter()
+ objs = [{},(),'',1,1.,1+1j]
+ for i in objs:
+ assert( not s.type_match(i) )
+ def check_type_match_good(self,level=5):
+ s = c_spec.list_converter()
+ assert(s.type_match([]))
+ def check_var_in(self,level=5):
+ mod_name = 'list_var_in'+self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = [1]
+ code = 'a=py::list();'
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import test'
+ b=[1,2]
+ test(b)
+ try:
+ b = 1.
+ test(b)
+ except TypeError:
+ pass
+ try:
+ b = 'string'
+ test(b)
+ except TypeError:
+ pass
+
+ def check_return(self,level=5):
+ mod_name = 'list_return'+self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = [1]
+ code = """
+ a=py::list();
+ a.append("hello");
+ return_val = a;
+ """
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import test'
+ b=[1,2]
+ c = test(b)
+ assert( c == ['hello'])
+
+ def check_speed(self,level=5):
+ mod_name = 'list_speed'+self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = range(1000000);
+ code = """
+ int v, sum = 0;
+ for(int i = 0; i < a.len(); i++)
+ {
+ v = a[i];
+ if (v % 2)
+ sum += v;
+ else
+ sum -= v;
+ }
+ return_val = sum;
+ """
+ with_cxx = ext_tools.ext_function('with_cxx',code,['a'])
+ mod.add_function(with_cxx)
+ code = """
+ int vv, sum = 0;
+ PyObject *v;
+ for(int i = 0; i < a.len(); i++)
+ {
+ v = PyList_GetItem(py_a,i);
+ //didn't set error here -- just speed test
+ vv = py_to_int(v,"list item");
+ if (vv % 2)
+ sum += vv;
+ else
+ sum -= vv;
+ }
+ return_val = sum;
+ """
+ no_checking = ext_tools.ext_function('no_checking',code,['a'])
+ mod.add_function(no_checking)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import with_cxx, no_checking'
+ import time
+ t1 = time.time()
+ sum1 = with_cxx(a)
+ t2 = time.time()
+ print 'speed test for list access'
+ print 'compiler:', self.compiler
+ print 'scxx:', t2 - t1
+ t1 = time.time()
+ sum2 = no_checking(a)
+ t2 = time.time()
+ print 'C, no checking:', t2 - t1
+ sum3 = 0
+ t1 = time.time()
+ for i in a:
+ if i % 2:
+ sum3 += i
+ else:
+ sum3 -= i
+ t2 = time.time()
+ print 'python:', t2 - t1
+ assert( sum1 == sum2 and sum1 == sum3)
+
+class test_tuple_converter(unittest.TestCase):
+ compiler = ''
+ def check_type_match_bad(self,level=5):
+ s = c_spec.tuple_converter()
+ objs = [{},[],'',1,1.,1+1j]
+ for i in objs:
+ assert( not s.type_match(i) )
+ def check_type_match_good(self,level=5):
+ s = c_spec.tuple_converter()
+ assert(s.type_match((1,)))
+ def check_var_in(self,level=5):
+ mod_name = 'tuple_var_in'+self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = (1,)
+ code = 'a=py::tuple();'
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import test'
+ b=(1,2)
+ test(b)
+ try:
+ b = 1.
+ test(b)
+ except TypeError:
+ pass
+ try:
+ b = 'string'
+ test(b)
+ except TypeError:
+ pass
+
+ def check_return(self,level=5):
+ mod_name = 'tuple_return'+self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = (1,)
+ code = """
+ a=py::tuple(2);
+ a[0] = "hello";
+ a.set_item(1,py::None);
+ return_val = a;
+ """
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import test'
+ b=(1,2)
+ c = test(b)
+ assert( c == ('hello',None))
+
+
+class test_dict_converter(unittest.TestCase):
+ def check_type_match_bad(self,level=5):
+ s = c_spec.dict_converter()
+ objs = [[],(),'',1,1.,1+1j]
+ for i in objs:
+ assert( not s.type_match(i) )
+ def check_type_match_good(self,level=5):
+ s = c_spec.dict_converter()
+ assert(s.type_match({}))
+ def check_var_in(self,level=5):
+ mod_name = 'dict_var_in'+self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = {'z':1}
+ code = 'a=py::dict();' # This just checks to make sure the type is correct
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import test'
+ b={'y':2}
+ test(b)
+ try:
+ b = 1.
+ test(b)
+ except TypeError:
+ pass
+ try:
+ b = 'string'
+ test(b)
+ except TypeError:
+ pass
+
+ def check_return(self,level=5):
+ mod_name = 'dict_return'+self.compiler
+ mod_name = unique_mod(test_dir,mod_name)
+ mod = ext_tools.ext_module(mod_name)
+ a = {'z':1}
+ code = """
+ a=py::dict();
+ a["hello"] = 5;
+ return_val = a;
+ """
+ test = ext_tools.ext_function('test',code,['a'])
+ mod.add_function(test)
+ mod.compile(location = test_dir, compiler = self.compiler)
+ exec 'from ' + mod_name + ' import test'
+ b = {'z':2}
+ c = test(b)
+ assert( c['hello'] == 5)
+
+class test_msvc_int_converter(test_int_converter):
+ compiler = 'msvc'
+class test_unix_int_converter(test_int_converter):
+ compiler = ''
+class test_gcc_int_converter(test_int_converter):
+ compiler = 'gcc'
+
+class test_msvc_float_converter(test_float_converter):
+ compiler = 'msvc'
+
+class test_msvc_float_converter(test_float_converter):
+ compiler = 'msvc'
+class test_unix_float_converter(test_float_converter):
+ compiler = ''
+class test_gcc_float_converter(test_float_converter):
+ compiler = 'gcc'
+
+class test_msvc_complex_converter(test_complex_converter):
+ compiler = 'msvc'
+class test_unix_complex_converter(test_complex_converter):
+ compiler = ''
+class test_gcc_complex_converter(test_complex_converter):
+ compiler = 'gcc'
+
+class test_msvc_file_converter(test_file_converter):
+ compiler = 'msvc'
+class test_unix_file_converter(test_file_converter):
+ compiler = ''
+class test_gcc_file_converter(test_file_converter):
+ compiler = 'gcc'
+
+class test_msvc_callable_converter(test_callable_converter):
+ compiler = 'msvc'
+class test_unix_callable_converter(test_callable_converter):
+ compiler = ''
+class test_gcc_callable_converter(test_callable_converter):
+ compiler = 'gcc'
+
+class test_msvc_sequence_converter(test_sequence_converter):
+ compiler = 'msvc'
+class test_unix_sequence_converter(test_sequence_converter):
+ compiler = ''
+class test_gcc_sequence_converter(test_sequence_converter):
+ compiler = 'gcc'
+
+class test_msvc_string_converter(test_string_converter):
+ compiler = 'msvc'
+class test_unix_string_converter(test_string_converter):
+ compiler = ''
+class test_gcc_string_converter(test_string_converter):
+ compiler = 'gcc'
+
+class test_msvc_list_converter(test_list_converter):
+ compiler = 'msvc'
+class test_unix_list_converter(test_list_converter):
+ compiler = ''
+class test_gcc_list_converter(test_list_converter):
+ compiler = 'gcc'
+
+class test_msvc_tuple_converter(test_tuple_converter):
+ compiler = 'msvc'
+class test_unix_tuple_converter(test_tuple_converter):
+ compiler = ''
+class test_gcc_tuple_converter(test_tuple_converter):
+ compiler = 'gcc'
+
+class test_msvc_dict_converter(test_dict_converter):
+ compiler = 'msvc'
+class test_unix_dict_converter(test_dict_converter):
+ compiler = ''
+class test_gcc_dict_converter(test_dict_converter):
+ compiler = 'gcc'
+
+class test_msvc_instance_converter(test_instance_converter):
+ compiler = 'msvc'
+class test_unix_instance_converter(test_instance_converter):
+ compiler = ''
+class test_gcc_instance_converter(test_instance_converter):
+ compiler = 'gcc'
+
+def setup_test_location():
+ import tempfile
+ #test_dir = os.path.join(tempfile.gettempdir(),'test_files')
+ test_dir = tempfile.mktemp()
+ if not os.path.exists(test_dir):
+ os.mkdir(test_dir)
+ sys.path.insert(0,test_dir)
+ return test_dir
+
+test_dir = setup_test_location()
+
+def teardown_test_location():
+ import tempfile
+ test_dir = os.path.join(tempfile.gettempdir(),'test_files')
+ if sys.path[0] == test_dir:
+ sys.path = sys.path[1:]
+ return test_dir
+
+def remove_file(name):
+ test_dir = os.path.abspath(name)
+
+if not msvc_exists():
+ for _n in dir():
+ if _n[:10]=='test_msvc_': exec 'del '+_n
+else:
+ for _n in dir():
+ if _n[:10]=='test_unix_': exec 'del '+_n
+
+if not (gcc_exists() and msvc_exists() and sys.platform == 'win32'):
+ for _n in dir():
+ if _n[:9]=='test_gcc_': exec 'del '+_n
+
+if __name__ == "__main__":
+ ScipyTest('weave.c_spec').run()