summaryrefslogtreecommitdiff
path: root/weave/ext_tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'weave/ext_tools.py')
-rw-r--r--weave/ext_tools.py445
1 files changed, 445 insertions, 0 deletions
diff --git a/weave/ext_tools.py b/weave/ext_tools.py
new file mode 100644
index 000000000..3c18d2fb9
--- /dev/null
+++ b/weave/ext_tools.py
@@ -0,0 +1,445 @@
+import os, sys
+import string, re
+
+import build_tools
+
+import base_spec
+import scalar_spec
+import sequence_spec
+import common_spec
+
+default_type_factories = [scalar_spec.int_specification(),
+ scalar_spec.float_specification(),
+ scalar_spec.complex_specification(),
+ sequence_spec.string_specification(),
+ sequence_spec.list_specification(),
+ sequence_spec.dict_specification(),
+ sequence_spec.tuple_specification(),
+ common_spec.file_specification(),
+ common_spec.callable_specification()]
+ #common_spec.instance_specification(),
+ #common_spec.module_specification()]
+
+try:
+ from standard_array_spec import array_specification
+ default_type_factories.append(array_specification())
+except:
+ pass
+
+try:
+ # this is currently safe because it doesn't import wxPython.
+ import wx_spec
+ default_type_factories.append(wx_spec.wx_specification())
+except:
+ pass
+
+class ext_function_from_specs:
+ def __init__(self,name,code_block,arg_specs):
+ self.name = name
+ self.arg_specs = base_spec.arg_spec_list(arg_specs)
+ self.code_block = code_block
+ self.compiler = ''
+ self.customize = base_info.custom_info()
+
+ def header_code(self):
+ pass
+
+ def function_declaration_code(self):
+ code = 'static PyObject* %s(PyObject*self, PyObject* args,' \
+ ' PyObject* kywds)\n{\n'
+ return code % self.name
+
+ def template_declaration_code(self):
+ code = 'template<class T>\n' \
+ 'static PyObject* %s(PyObject*self, PyObject* args,' \
+ ' PyObject* kywds)\n{\n'
+ return code % self.name
+
+ #def cpp_function_declaration_code(self):
+ # pass
+ #def cpp_function_call_code(self):
+ #s pass
+
+ def parse_tuple_code(self):
+ """ Create code block for PyArg_ParseTuple. Variable declarations
+ for all PyObjects are done also.
+
+ This code got a lot uglier when I added local_dict...
+ """
+ join = string.join
+
+ declare_return = 'PyObject *return_val = NULL;\n' \
+ 'int exception_occured = 0;\n' \
+ 'PyObject *py_local_dict = NULL;\n'
+ arg_string_list = self.arg_specs.variable_as_strings() + ['"local_dict"']
+ arg_strings = join(arg_string_list,',')
+ if arg_strings: arg_strings += ','
+ declare_kwlist = 'static char *kwlist[] = {%s NULL};\n' % arg_strings
+
+ py_objects = join(self.arg_specs.py_pointers(),', ')
+ if py_objects:
+ declare_py_objects = 'PyObject ' + py_objects +';\n'
+ else:
+ declare_py_objects = ''
+
+ py_vars = join(self.arg_specs.py_variables(),' = ')
+ if py_vars:
+ init_values = py_vars + ' = NULL;\n\n'
+ else:
+ init_values = ''
+
+ #Each variable is in charge of its own cleanup now.
+ #cnt = len(arg_list)
+ #declare_cleanup = "blitz::TinyVector<PyObject*,%d> clean_up(0);\n" % cnt
+
+ ref_string = join(self.arg_specs.py_references(),', ')
+ if ref_string:
+ ref_string += ', &py_local_dict'
+ else:
+ ref_string = '&py_local_dict'
+
+ format = "O"* len(self.arg_specs) + "|O" + ':' + self.name
+ parse_tuple = 'if(!PyArg_ParseTupleAndKeywords(args,' \
+ 'kywds,"%s",kwlist,%s))\n' % (format,ref_string)
+ parse_tuple += ' return NULL;\n'
+
+ return declare_return + declare_kwlist + declare_py_objects \
+ + init_values + parse_tuple
+
+ def arg_declaration_code(self):
+ arg_strings = []
+ for arg in self.arg_specs:
+ arg_strings.append(arg.declaration_code())
+ code = string.join(arg_strings,"")
+ return code
+
+ def arg_cleanup_code(self):
+ arg_strings = []
+ for arg in self.arg_specs:
+ arg_strings.append(arg.cleanup_code())
+ code = string.join(arg_strings,"")
+ return code
+
+ def arg_local_dict_code(self):
+ arg_strings = []
+ for arg in self.arg_specs:
+ arg_strings.append(arg.local_dict_code())
+ code = string.join(arg_strings,"")
+ return code
+
+ def function_code(self):
+ decl_code = indent(self.arg_declaration_code(),4)
+ cleanup_code = indent(self.arg_cleanup_code(),4)
+ function_code = indent(self.code_block,4)
+ local_dict_code = indent(self.arg_local_dict_code(),4)
+
+ dict_code = "if(py_local_dict) \n" \
+ "{ \n" \
+ " Py::Dict local_dict = Py::Dict(py_local_dict); \n" + \
+ local_dict_code + \
+ "} \n"
+
+ try_code = "try \n" \
+ "{ \n" + \
+ decl_code + \
+ " /*<function call here>*/ \n" + \
+ function_code + \
+ indent(dict_code,4) + \
+ "\n} \n"
+ catch_code = "catch( Py::Exception& e) \n" \
+ "{ \n" + \
+ " return_val = Py::Null(); \n" \
+ " exception_occured = 1; \n" \
+ "} \n"
+
+ return_code = " /*cleanup code*/ \n" + \
+ cleanup_code + \
+ " if(!return_val && !exception_occured)\n" \
+ " {\n \n" \
+ " Py_INCREF(Py_None); \n" \
+ " return_val = Py_None; \n" \
+ " }\n \n" \
+ " return return_val; \n" \
+ "} \n"
+
+ all_code = self.function_declaration_code() + \
+ indent(self.parse_tuple_code(),4) + \
+ indent(try_code,4) + \
+ indent(catch_code,4) + \
+ return_code
+
+ return all_code
+
+ def python_function_definition_code(self):
+ args = (self.name, self.name)
+ function_decls = '{"%s",(PyCFunction)%s , METH_VARARGS|' \
+ 'METH_KEYWORDS},\n' % args
+ return function_decls
+
+ def set_compiler(self,compiler):
+ self.compiler = compiler
+ for arg in self.arg_specs:
+ arg.set_compiler(compiler)
+
+
+class ext_function(ext_function_from_specs):
+ def __init__(self,name,code_block, args, local_dict=None, global_dict=None,
+ auto_downcast=1, type_factories=None):
+
+ call_frame = sys._getframe().f_back
+ if local_dict is None:
+ local_dict = call_frame.f_locals
+ if global_dict is None:
+ global_dict = call_frame.f_globals
+ if type_factories is None:
+ type_factories = default_type_factories
+ arg_specs = assign_variable_types(args,local_dict, global_dict,
+ auto_downcast, type_factories)
+ ext_function_from_specs.__init__(self,name,code_block,arg_specs)
+
+
+import base_info, common_info, cxx_info, scalar_info
+
+class ext_module:
+ def __init__(self,name,compiler=''):
+ standard_info = [common_info.basic_module_info(),
+ common_info.file_info(),
+ common_info.instance_info(),
+ common_info.callable_info(),
+ common_info.module_info(),
+ cxx_info.cxx_info(),
+ scalar_info.scalar_info()]
+ self.name = name
+ self.functions = []
+ self.compiler = compiler
+ self.customize = base_info.custom_info()
+ self._build_information = base_info.info_list(standard_info)
+
+ def add_function(self,func):
+ self.functions.append(func)
+ def module_code(self):
+ code = self.warning_code() + \
+ self.header_code() + \
+ self.support_code() + \
+ self.function_code() + \
+ self.python_function_definition_code() + \
+ self.module_init_code()
+ return code
+
+ def arg_specs(self):
+ all_arg_specs = base_spec.arg_spec_list()
+ for func in self.functions:
+ all_arg_specs += func.arg_specs
+ return all_arg_specs
+
+ def build_information(self):
+ info = [self.customize] + self._build_information + \
+ self.arg_specs().build_information()
+ for func in self.functions:
+ info.append(func.customize)
+ #redundant, but easiest place to make sure compiler is set
+ for i in info:
+ i.set_compiler(self.compiler)
+ return info
+
+ def get_headers(self):
+ all_headers = self.build_information().headers()
+
+ # blitz/array.h always needs to be first so we hack that here...
+ if '"blitz/array.h"' in all_headers:
+ all_headers.remove('"blitz/array.h"')
+ all_headers.insert(0,'"blitz/array.h"')
+ return all_headers
+
+ def warning_code(self):
+ all_warnings = self.build_information().warnings()
+ w=map(lambda x: "#pragma warning(%s)\n" % x,all_warnings)
+ return ''.join(w)
+
+ def header_code(self):
+ h = self.get_headers()
+ h= map(lambda x: '#include ' + x + '\n',h)
+ return ''.join(h)
+
+ def support_code(self):
+ code = self.build_information().support_code()
+ return ''.join(code)
+
+ def function_code(self):
+ all_function_code = ""
+ for func in self.functions:
+ all_function_code += func.function_code()
+ return ''.join(all_function_code)
+
+ def python_function_definition_code(self):
+ all_definition_code = ""
+ for func in self.functions:
+ all_definition_code += func.python_function_definition_code()
+ all_definition_code = indent(''.join(all_definition_code),4)
+ code = 'static PyMethodDef compiled_methods[] = \n' \
+ '{\n' \
+ '%s' \
+ ' {NULL, NULL} /* Sentinel */\n' \
+ '};\n'
+ return code % (all_definition_code)
+
+ def module_init_code(self):
+ init_code_list = self.build_information().module_init_code()
+ init_code = indent(''.join(init_code_list),4)
+ code = 'extern "C" void init%s()\n' \
+ '{\n' \
+ '%s' \
+ ' (void) Py_InitModule("%s", compiled_methods);\n' \
+ '}\n' % (self.name,init_code,self.name)
+ return code
+
+ def generate_file(self,file_name="",location='.'):
+ code = self.module_code()
+ if not file_name:
+ file_name = self.name + '.cpp'
+ name = generate_file_name(file_name,location)
+ #return name
+ return generate_module(code,name)
+
+ def set_compiler(self,compiler):
+ #for i in self.arg_specs()
+ # i.set_compiler(compiler)
+ for i in self.build_information():
+ i.set_compiler(compiler)
+ for i in self.functions:
+ i.set_compiler(compiler)
+ self.compiler = compiler
+
+ def compile(self,location='.',compiler=None, verbose = 0, **kw):
+
+ if compiler is not None:
+ self.compiler = compiler
+ # hmm. Is there a cleaner way to do this? Seems like
+ # choosing the compiler spagettis around a little.
+ compiler = build_tools.choose_compiler(self.compiler)
+ self.set_compiler(compiler)
+ arg_specs = self.arg_specs()
+ info = self.build_information()
+ _source_files = info.sources()
+ # remove duplicates
+ source_files = {}
+ for i in _source_files:
+ source_files[i] = None
+ source_files = source_files.keys()
+
+ # add internally specified macros, includes, etc. to the key words
+ # values of the same names so that distutils will use them.
+ kw['define_macros'] = kw.get('define_macros',[]) + info.define_macros()
+ kw['include_dirs'] = kw.get('include_dirs',[]) + info.include_dirs()
+ kw['libraries'] = kw.get('libraries',[]) + info.libraries()
+ kw['library_dirs'] = kw.get('library_dirs',[]) + info.library_dirs()
+
+ file = self.generate_file(location=location)
+ # This is needed so that files build correctly even when different
+ # versions of Python are running around.
+ import catalog
+ temp = catalog.default_temp_dir()
+ success = build_tools.build_extension(file, temp_dir = temp,
+ sources = source_files,
+ compiler_name = compiler,
+ verbose = verbose, **kw)
+ if not success:
+ raise SystemError, 'Compilation failed'
+
+def generate_file_name(module_name,module_location):
+ module_file = os.path.join(module_location,module_name)
+ return os.path.abspath(module_file)
+
+def generate_module(module_string, module_file):
+ f =open(module_file,'w')
+ f.write(module_string)
+ f.close()
+ return module_file
+
+def assign_variable_types(variables,local_dict = {}, global_dict = {},
+ auto_downcast = 1,
+ type_factories = default_type_factories):
+ incoming_vars = {}
+ incoming_vars.update(global_dict)
+ incoming_vars.update(local_dict)
+ variable_specs = []
+ errors={}
+ for var in variables:
+ try:
+ example_type = incoming_vars[var]
+
+ # look through possible type specs to find which one
+ # should be used to for example_type
+ spec = None
+ for factory in type_factories:
+ if factory.type_match(example_type):
+ spec = factory.type_spec(var,example_type)
+ break
+ if not spec:
+ # should really define our own type.
+ raise IndexError
+ else:
+ variable_specs.append(spec)
+ except KeyError:
+ errors[var] = ("The type and dimensionality specifications" +
+ "for variable '" + var + "' are missing.")
+ except IndexError:
+ errors[var] = ("Unable to convert variable '"+ var +
+ "' to a C++ type.")
+ if errors:
+ raise TypeError, format_error_msg(errors)
+
+ if auto_downcast:
+ variable_specs = downcast(variable_specs)
+ return variable_specs
+
+def downcast(var_specs):
+ """ Cast python scalars down to most common type of
+ arrays used.
+
+ Right now, focus on complex and float types. Ignore int types.
+ Require all arrays to have same type before forcing downcasts.
+
+ Note: var_specs are currently altered in place (horrors...!)
+ """
+ numeric_types = []
+
+ #grab all the numeric types associated with a variables.
+ for var in var_specs:
+ if hasattr(var,'numeric_type'):
+ numeric_types.append(var.numeric_type)
+
+ # if arrays are present, but none of them are double precision,
+ # make all numeric types float or complex(float)
+ if ( ('f' in numeric_types or 'F' in numeric_types) and
+ not ('d' in numeric_types or 'D' in numeric_types) ):
+ for var in var_specs:
+ if hasattr(var,'numeric_type'):
+ # really should do this some other way...
+ if var.numeric_type == type(1+1j):
+ var.numeric_type = 'F'
+ elif var.numeric_type == type(1.):
+ var.numeric_type = 'f'
+ return var_specs
+
+def indent(st,spaces):
+ indention = ' '*spaces
+ indented = indention + string.replace(st,'\n','\n'+indention)
+ # trim off any trailing spaces
+ indented = re.sub(r' +$',r'',indented)
+ return indented
+
+def format_error_msg(errors):
+ #minimum effort right now...
+ import pprint,cStringIO
+ msg = cStringIO.StringIO()
+ pprint.pprint(errors,msg)
+ return msg.getvalue()
+
+def test():
+ from scipy_test import module_test
+ module_test(__name__,__file__)
+
+def test_suite():
+ from scipy_test import module_test_suite
+ return module_test_suite(__name__,__file__)