summaryrefslogtreecommitdiff
path: root/weave/blitz_tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'weave/blitz_tools.py')
-rw-r--r--weave/blitz_tools.py176
1 files changed, 176 insertions, 0 deletions
diff --git a/weave/blitz_tools.py b/weave/blitz_tools.py
new file mode 100644
index 000000000..2a578b01b
--- /dev/null
+++ b/weave/blitz_tools.py
@@ -0,0 +1,176 @@
+import parser
+import string
+import copy
+import os,sys
+import ast_tools
+import token,symbol
+import slice_handler
+import size_check
+
+from ast_tools import *
+from Numeric import *
+from fastumath import *
+from types import *
+
+import inline_tools
+from inline_tools import attempt_function_call
+function_catalog = inline_tools.function_catalog
+function_cache = inline_tools.function_cache
+
+# this is pretty much the same as the default factories.
+# We've just replaced the array specification with the blitz
+# specification
+import base_spec
+import scalar_spec
+import sequence_spec
+import common_spec
+from blitz_spec import array_specification
+blitz_type_factories = [sequence_spec.string_specification(),
+ sequence_spec.list_specification(),
+ sequence_spec.dict_specification(),
+ sequence_spec.tuple_specification(),
+ scalar_spec.int_specification(),
+ scalar_spec.float_specification(),
+ scalar_spec.complex_specification(),
+ common_spec.file_specification(),
+ common_spec.callable_specification(),
+ array_specification()]
+ #common_spec.instance_specification(),
+ #common_spec.module_specification()]
+
+try:
+ # this is currently safe because it doesn't import wxPython.
+ import wx_spec
+ default_type_factories.append(wx_spec.wx_specification())
+except:
+ pass
+
+def blitz(expr,local_dict=None, global_dict=None,check_size=1,verbose=0):
+ # this could call inline, but making a copy of the
+ # code here is more efficient for several reasons.
+ global function_catalog
+
+ # this grabs the local variables from the *previous* call
+ # frame -- that is the locals from the function that called
+ # inline.
+ call_frame = sys._getframe().f_back
+ if local_dict is None:
+ local_dict = call_frame.f_locals
+ if global_dict is None:
+ global_dict = call_frame.f_globals
+
+ # 1. Check the sizes of the arrays and make sure they are compatible.
+ # This is expensive, so unsetting the check_size flag can save a lot
+ # of time. It also can cause core-dumps if the sizes of the inputs
+ # aren't compatible.
+ if check_size and not size_check.check_expr(expr,local_dict,global_dict):
+ raise 'inputs failed to pass size check.'
+
+ # 2. try local cache
+ try:
+ results = apply(function_cache[expr],(local_dict,global_dict))
+ return results
+ except:
+ pass
+ try:
+ results = attempt_function_call(expr,local_dict,global_dict)
+ # 3. build the function
+ except ValueError:
+ # This section is pretty much the only difference
+ # between blitz and inline
+ ast = parser.suite(expr)
+ ast_list = ast.tolist()
+ expr_code = ast_to_blitz_expr(ast_list)
+ arg_names = harvest_variables(ast_list)
+ module_dir = global_dict.get('__file__',None)
+ #func = inline_tools.compile_function(expr_code,arg_names,
+ # local_dict,global_dict,
+ # module_dir,auto_downcast = 1)
+ func = inline_tools.compile_function(expr_code,arg_names,local_dict,
+ global_dict,module_dir,
+ compiler='gcc',auto_downcast=1,
+ verbose = verbose,
+ type_factories = blitz_type_factories)
+ function_catalog.add_function(expr,func,module_dir)
+ try:
+ results = attempt_function_call(expr,local_dict,global_dict)
+ except ValueError:
+ print 'warning: compilation failed. Executing as python code'
+ exec expr in global_dict, local_dict
+
+def ast_to_blitz_expr(ast_seq):
+ """ Convert an ast_sequence to a blitz expression.
+ """
+
+ # Don't overwrite orignal sequence in call to transform slices.
+ ast_seq = copy.deepcopy(ast_seq)
+ slice_handler.transform_slices(ast_seq)
+
+ # Build the actual program statement from ast_seq
+ expr = ast_tools.ast_to_string(ast_seq)
+
+ # Now find and replace specific symbols to convert this to
+ # a blitz++ compatible statement.
+ # I'm doing this with string replacement here. It could
+ # also be done on the actual ast tree (and probably should from
+ # a purest standpoint...).
+
+ # this one isn't necessary but it helps code readability
+ # and compactness. It requires that
+ # Range _all = blitz::Range::all();
+ # be included in the generated code.
+ # These could all alternatively be done to the ast in
+ # build_slice_atom()
+ expr = string.replace(expr,'slice(_beg,_end)', '_all' )
+ expr = string.replace(expr,'slice', 'blitz::Range' )
+ expr = string.replace(expr,'[','(')
+ expr = string.replace(expr,']', ')' )
+ expr = string.replace(expr,'_stp', '1' )
+
+ # Instead of blitz::fromStart and blitz::toEnd. This requires
+ # the following in the generated code.
+ # Range _beg = blitz::fromStart;
+ # Range _end = blitz::toEnd;
+ #expr = string.replace(expr,'_beg', 'blitz::fromStart' )
+ #expr = string.replace(expr,'_end', 'blitz::toEnd' )
+
+ return expr + ';\n'
+
+def test_function():
+ from code_blocks import module_header
+
+ expr = "ex[:,1:,1:] = k + ca_x[:,1:,1:] * ex[:,1:,1:]" \
+ "+ cb_y_x[:,1:,1:] * (hz[:,1:,1:] - hz[:,:-1,1:])"\
+ "- cb_z_x[:,1:,1:] * (hy[:,1:,1:] - hy[:,1:,:-1])"
+ #ast = parser.suite('a = (b + c) * sin(d)')
+ ast = parser.suite(expr)
+ k = 1.
+ ex = ones((1,1,1),typecode=Float32)
+ ca_x = ones((1,1,1),typecode=Float32)
+ cb_y_x = ones((1,1,1),typecode=Float32)
+ cb_z_x = ones((1,1,1),typecode=Float32)
+ hz = ones((1,1,1),typecode=Float32)
+ hy = ones((1,1,1),typecode=Float32)
+ blitz(expr)
+ """
+ ast_list = ast.tolist()
+
+ expr_code = ast_to_blitz_expr(ast_list)
+ arg_list = harvest_variables(ast_list)
+ arg_specs = assign_variable_types(arg_list,locals())
+
+ func,template_types = create_function('test_function',expr_code,arg_list,arg_specs)
+ init,used_names = create_module_init('compile_sample','test_function',template_types)
+ #wrapper = create_wrapper(mod_name,func_name,used_names)
+ return string.join( [module_header,func,init],'\n')
+ """
+def test():
+ from scipy_test import module_test
+ module_test(__name__,__file__)
+
+def test_suite():
+ from scipy_test import module_test_suite
+ return module_test_suite(__name__,__file__)
+
+if __name__ == "__main__":
+ test_function() \ No newline at end of file