summaryrefslogtreecommitdiff
path: root/weave/accelerate_tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'weave/accelerate_tools.py')
-rw-r--r--weave/accelerate_tools.py79
1 files changed, 63 insertions, 16 deletions
diff --git a/weave/accelerate_tools.py b/weave/accelerate_tools.py
index c62ee6f97..8b45b16b4 100644
--- a/weave/accelerate_tools.py
+++ b/weave/accelerate_tools.py
@@ -9,11 +9,11 @@ C++ equivalents to Python functions.
"""
#**************************************************************************#
-from types import FunctionType,IntType,FloatType,StringType,TypeType
+from types import FunctionType,IntType,FloatType,StringType,TypeType,XRangeType
import inspect
import md5
import weave
-import bytecodecompiler
+from bytecodecompiler import CXXCoder,Type_Descriptor,Function_Descriptor
def CStr(s):
"Hacky way to get legal C string from Python string"
@@ -22,7 +22,7 @@ def CStr(s):
r = repr('"'+s) # Better for embedded quotes
return '"'+r[2:-1]+'"'
-class Basic(bytecodecompiler.Type_Descriptor):
+class Basic(Type_Descriptor):
def check(self,s):
return "%s(%s)"%(self.checker,s)
def inbound(self,s):
@@ -60,10 +60,12 @@ class String(Basic):
import Numeric
-class Vector(bytecodecompiler.Type_Descriptor):
+class Vector(Type_Descriptor):
cxxtype = 'PyArrayObject*'
refcount = 1
- prerequisites = bytecodecompiler.Type_Descriptor.prerequisites+\
+ module_init_code = 'import_array'
+
+ prerequisites = Type_Descriptor.prerequisites+\
['#include "Numeric/arrayobject.h"']
dims = 1
def check(self,s):
@@ -74,19 +76,53 @@ class Vector(bytecodecompiler.Type_Descriptor):
def outbound(self,s):
return "(PyObject*)(%s)"%s
def binopMixed(self,symbol,a,b):
- return '*((%s*)(%s->data+%s*%s->strides[0]))'%(self.cxxbase,a,b,a),Integer()
+ return '*((%s*)(%s->data+%s*%s->strides[0]))'%(self.cxxbase,a,b,a),Integer
class IntegerVector(Vector):
typecode = 'PyArray_INT'
cxxbase = 'int'
+class XRange(Type_Descriptor):
+ cxxtype = 'XRange'
+ prerequisites = ['''
+ class XRange {
+ public:
+ XRange(long aLow, long aHigh, long aStep=1)
+ : low(aLow),high(aHigh),step(aStep)
+ {
+ }
+ long low;
+ long high;
+ long step;
+ };''']
+
+# -----------------------------------------------
+# Singletonize the type names
+# -----------------------------------------------
+Integer = Integer()
+Double = Double()
+String = String()
+IntegerVector = IntegerVector()
+XRange = XRange()
+
typedefs = {
- IntType: Integer(),
- FloatType: Double(),
- StringType: String(),
- (Numeric.ArrayType,1,'l'): IntegerVector(),
+ IntType: Integer,
+ FloatType: Double,
+ StringType: String,
+ (Numeric.ArrayType,1,'l'): IntegerVector,
+ XRangeType : XRange,
}
+import math
+functiondefs = {
+ (range,(Integer,Integer)):
+ Function_Descriptor(code='XRange(%s)',return_type=XRange),
+
+ (math.sin,(Double,)):
+ Function_Descriptor(code='sin(%s)',return_type=Double),
+ }
+
+
##################################################################
# FUNCTION LOOKUP_TYPE #
@@ -189,30 +225,41 @@ class accelerate:
##################################################################
# CLASS PYTHON2CXX #
##################################################################
-class Python2CXX(bytecodecompiler.CXXCoder):
+class Python2CXX(CXXCoder):
def typedef_by_value(self,v):
T = lookup_type(v)
- if T not in self.used: self.used.append(T)
+ if T not in self.used:
+ self.used.append(T)
return T
+ def function_by_signature(self,signature):
+ descriptor = functiondefs[signature]
+ if descriptor.return_type not in self.used:
+ self.used.append(descriptor.return_type)
+ return descriptor
+
def __init__(self,f,signature,name=None):
# Make sure function is a function
import types
assert type(f) == FunctionType
# and check the input type signature
assert reduce(lambda x,y: x and y,
- map(lambda x: isinstance(x,bytecodecompiler.Type_Descriptor),
+ map(lambda x: isinstance(x,Type_Descriptor),
signature),
1),'%s not all type objects'%signature
self.arg_specs = []
self.customize = weave.base_info.custom_info()
- self.customize.add_module_init_code('import_array()\n')
- bytecodecompiler.CXXCoder.__init__(self,f,signature,name)
+ CXXCoder.__init__(self,f,signature,name)
+
return
def function_code(self):
- return self.wrapped_code()
+ code = self.wrapped_code()
+ for T in self.used:
+ if T != None and T.module_init_code:
+ self.customize.add_module_init_code(T.module_init_code)
+ return code
def python_function_definition_code(self):
return '{ "%s", wrapper_%s, METH_VARARGS, %s },\n'%(