diff options
Diffstat (limited to 'weave/accelerate_tools.py')
-rw-r--r-- | weave/accelerate_tools.py | 79 |
1 files changed, 63 insertions, 16 deletions
diff --git a/weave/accelerate_tools.py b/weave/accelerate_tools.py index c62ee6f97..8b45b16b4 100644 --- a/weave/accelerate_tools.py +++ b/weave/accelerate_tools.py @@ -9,11 +9,11 @@ C++ equivalents to Python functions. """ #**************************************************************************# -from types import FunctionType,IntType,FloatType,StringType,TypeType +from types import FunctionType,IntType,FloatType,StringType,TypeType,XRangeType import inspect import md5 import weave -import bytecodecompiler +from bytecodecompiler import CXXCoder,Type_Descriptor,Function_Descriptor def CStr(s): "Hacky way to get legal C string from Python string" @@ -22,7 +22,7 @@ def CStr(s): r = repr('"'+s) # Better for embedded quotes return '"'+r[2:-1]+'"' -class Basic(bytecodecompiler.Type_Descriptor): +class Basic(Type_Descriptor): def check(self,s): return "%s(%s)"%(self.checker,s) def inbound(self,s): @@ -60,10 +60,12 @@ class String(Basic): import Numeric -class Vector(bytecodecompiler.Type_Descriptor): +class Vector(Type_Descriptor): cxxtype = 'PyArrayObject*' refcount = 1 - prerequisites = bytecodecompiler.Type_Descriptor.prerequisites+\ + module_init_code = 'import_array' + + prerequisites = Type_Descriptor.prerequisites+\ ['#include "Numeric/arrayobject.h"'] dims = 1 def check(self,s): @@ -74,19 +76,53 @@ class Vector(bytecodecompiler.Type_Descriptor): def outbound(self,s): return "(PyObject*)(%s)"%s def binopMixed(self,symbol,a,b): - return '*((%s*)(%s->data+%s*%s->strides[0]))'%(self.cxxbase,a,b,a),Integer() + return '*((%s*)(%s->data+%s*%s->strides[0]))'%(self.cxxbase,a,b,a),Integer class IntegerVector(Vector): typecode = 'PyArray_INT' cxxbase = 'int' +class XRange(Type_Descriptor): + cxxtype = 'XRange' + prerequisites = [''' + class XRange { + public: + XRange(long aLow, long aHigh, long aStep=1) + : low(aLow),high(aHigh),step(aStep) + { + } + long low; + long high; + long step; + };'''] + +# ----------------------------------------------- +# Singletonize the type names +# ----------------------------------------------- +Integer = Integer() +Double = Double() +String = String() +IntegerVector = IntegerVector() +XRange = XRange() + typedefs = { - IntType: Integer(), - FloatType: Double(), - StringType: String(), - (Numeric.ArrayType,1,'l'): IntegerVector(), + IntType: Integer, + FloatType: Double, + StringType: String, + (Numeric.ArrayType,1,'l'): IntegerVector, + XRangeType : XRange, } +import math +functiondefs = { + (range,(Integer,Integer)): + Function_Descriptor(code='XRange(%s)',return_type=XRange), + + (math.sin,(Double,)): + Function_Descriptor(code='sin(%s)',return_type=Double), + } + + ################################################################## # FUNCTION LOOKUP_TYPE # @@ -189,30 +225,41 @@ class accelerate: ################################################################## # CLASS PYTHON2CXX # ################################################################## -class Python2CXX(bytecodecompiler.CXXCoder): +class Python2CXX(CXXCoder): def typedef_by_value(self,v): T = lookup_type(v) - if T not in self.used: self.used.append(T) + if T not in self.used: + self.used.append(T) return T + def function_by_signature(self,signature): + descriptor = functiondefs[signature] + if descriptor.return_type not in self.used: + self.used.append(descriptor.return_type) + return descriptor + def __init__(self,f,signature,name=None): # Make sure function is a function import types assert type(f) == FunctionType # and check the input type signature assert reduce(lambda x,y: x and y, - map(lambda x: isinstance(x,bytecodecompiler.Type_Descriptor), + map(lambda x: isinstance(x,Type_Descriptor), signature), 1),'%s not all type objects'%signature self.arg_specs = [] self.customize = weave.base_info.custom_info() - self.customize.add_module_init_code('import_array()\n') - bytecodecompiler.CXXCoder.__init__(self,f,signature,name) + CXXCoder.__init__(self,f,signature,name) + return def function_code(self): - return self.wrapped_code() + code = self.wrapped_code() + for T in self.used: + if T != None and T.module_init_code: + self.customize.add_module_init_code(T.module_init_code) + return code def python_function_definition_code(self): return '{ "%s", wrapper_%s, METH_VARARGS, %s },\n'%( |