summaryrefslogtreecommitdiff
path: root/scipy/weave/accelerate_tools.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2005-09-26 20:20:16 +0000
committerTravis Oliphant <oliphant@enthought.com>2005-09-26 20:20:16 +0000
commit45d01a4be1c4221132ba46d687e6af3a8df3329b (patch)
treece3be5290e918def7c7187e747c5460193b0ca85 /scipy/weave/accelerate_tools.py
parentccd1c3db37672627aa4fe0fdb5437f5dddc0fe86 (diff)
downloadnumpy-45d01a4be1c4221132ba46d687e6af3a8df3329b.tar.gz
Moved weave
Diffstat (limited to 'scipy/weave/accelerate_tools.py')
-rw-r--r--scipy/weave/accelerate_tools.py412
1 files changed, 412 insertions, 0 deletions
diff --git a/scipy/weave/accelerate_tools.py b/scipy/weave/accelerate_tools.py
new file mode 100644
index 000000000..0b6fc581a
--- /dev/null
+++ b/scipy/weave/accelerate_tools.py
@@ -0,0 +1,412 @@
+#**************************************************************************#
+#* FILE ************** accelerate_tools.py ************************#
+#**************************************************************************#
+#* Author: Patrick Miller February 9 2002 *#
+#**************************************************************************#
+"""
+accelerate_tools contains the interface for on-the-fly building of
+C++ equivalents to Python functions.
+"""
+#**************************************************************************#
+
+from types import InstanceType,FunctionType,IntType,FloatType,StringType,TypeType,XRangeType
+import inspect
+import md5
+import weave
+import imp
+from bytecodecompiler import CXXCoder,Type_Descriptor,Function_Descriptor
+
+def CStr(s):
+ "Hacky way to get legal C string from Python string"
+ if s is None: return '""'
+ assert type(s) == StringType,"Only None and string allowed"
+ r = repr('"'+s) # Better for embedded quotes
+ return '"'+r[2:-1]+'"'
+
+
+##################################################################
+# CLASS INSTANCE #
+##################################################################
+class Instance(Type_Descriptor):
+ cxxtype = 'PyObject*'
+
+ def __init__(self,prototype):
+ self.prototype = prototype
+ return
+
+ def check(self,s):
+ return "PyInstance_Check(%s)"%s
+
+ def inbound(self,s):
+ return s
+
+ def outbound(self,s):
+ return s,0
+
+ def get_attribute(self,name):
+ proto = getattr(self.prototype,name)
+ T = lookup_type(proto)
+ code = 'tempPY = PyObject_GetAttrString(%%(rhs)s,"%s");\n'%name
+ convert = T.inbound('tempPY')
+ code += '%%(lhsType)s %%(lhs)s = %s;\n'%convert
+ return T,code
+
+ def set_attribute(self,name):
+ proto = getattr(self.prototype,name)
+ T = lookup_type(proto)
+ convert,owned = T.outbound('%(rhs)s')
+ code = 'tempPY = %s;'%convert
+ if not owned:
+ code += ' Py_INCREF(tempPY);'
+ code += ' PyObject_SetAttrString(%%(lhs)s,"%s",tempPY);'%name
+ code += ' Py_DECREF(tempPY);\n'
+ return T,code
+
+##################################################################
+# CLASS BASIC #
+##################################################################
+class Basic(Type_Descriptor):
+ owned = 1
+ def check(self,s):
+ return "%s(%s)"%(self.checker,s)
+ def inbound(self,s):
+ return "%s(%s)"%(self.inbounder,s)
+ def outbound(self,s):
+ return "%s(%s)"%(self.outbounder,s),self.owned
+
+class Basic_Number(Basic):
+ def literalizer(self,s):
+ return str(s)
+ def binop(self,symbol,a,b):
+ assert symbol in ['+','-','*','/'],symbol
+ return '%s %s %s'%(a,symbol,b),self
+
+class Integer(Basic_Number):
+ cxxtype = "long"
+ checker = "PyInt_Check"
+ inbounder = "PyInt_AsLong"
+ outbounder = "PyInt_FromLong"
+
+class Double(Basic_Number):
+ cxxtype = "double"
+ checker = "PyFloat_Check"
+ inbounder = "PyFloat_AsDouble"
+ outbounder = "PyFloat_FromDouble"
+
+class String(Basic):
+ cxxtype = "char*"
+ checker = "PyString_Check"
+ inbounder = "PyString_AsString"
+ outbounder = "PyString_FromString"
+
+ def literalizer(self,s):
+ return CStr(s)
+
+# -----------------------------------------------
+# Singletonize the type names
+# -----------------------------------------------
+Integer = Integer()
+Double = Double()
+String = String()
+
+import scipy.base as nx
+
+class Vector(Type_Descriptor):
+ cxxtype = 'PyArrayObject*'
+ refcount = 1
+ dims = 1
+ module_init_code = 'import_array();\n'
+ inbounder = "(PyArrayObject*)"
+ outbounder = "(PyObject*)"
+ owned = 0 # Convertion is by casting!
+
+ prerequisites = Type_Descriptor.prerequisites+\
+ ['#include "scipy/arrayobject.h"']
+ dims = 1
+ def check(self,s):
+ return "PyArray_Check(%s) && ((PyArrayObject*)%s)->nd == %d && ((PyArrayObject*)%s)->descr->type_num == %s"%(
+ s,s,self.dims,s,self.typecode)
+
+ def inbound(self,s):
+ return "%s(%s)"%(self.inbounder,s)
+ def outbound(self,s):
+ return "%s(%s)"%(self.outbounder,s),self.owned
+
+ def getitem(self,A,v,t):
+ assert self.dims == len(v),'Expect dimension %d'%self.dims
+ code = '*((%s*)(%s->data'%(self.cxxbase,A)
+ for i in range(self.dims):
+ # assert that ''t[i]'' is an integer
+ code += '+%s*%s->strides[%d]'%(v[i],A,i)
+ code += '))'
+ return code,self.pybase
+ def setitem(self,A,v,t):
+ return self.getitem(A,v,t)
+
+class matrix(Vector):
+ dims = 2
+
+class IntegerVector(Vector):
+ typecode = 'PyArray_INT'
+ cxxbase = 'int'
+ pybase = Integer
+
+class Integermatrix(matrix):
+ typecode = 'PyArray_INT'
+ cxxbase = 'int'
+ pybase = Integer
+
+class LongVector(Vector):
+ typecode = 'PyArray_LONG'
+ cxxbase = 'long'
+ pybase = Integer
+
+class Longmatrix(matrix):
+ typecode = 'PyArray_LONG'
+ cxxbase = 'long'
+ pybase = Integer
+
+class DoubleVector(Vector):
+ typecode = 'PyArray_DOUBLE'
+ cxxbase = 'double'
+ pybase = Double
+
+class Doublematrix(matrix):
+ typecode = 'PyArray_DOUBLE'
+ cxxbase = 'double'
+ pybase = Double
+
+
+##################################################################
+# CLASS XRANGE #
+##################################################################
+class XRange(Type_Descriptor):
+ cxxtype = 'XRange'
+ prerequisites = ['''
+ class XRange {
+ public:
+ XRange(long aLow, long aHigh, long aStep=1)
+ : low(aLow),high(aHigh),step(aStep)
+ {
+ }
+ XRange(long aHigh)
+ : low(0),high(aHigh),step(1)
+ {
+ }
+ long low;
+ long high;
+ long step;
+ };''']
+
+# -----------------------------------------------
+# Singletonize the type names
+# -----------------------------------------------
+IntegerVector = IntegerVector()
+Integermatrix = Integermatrix()
+LongVector = LongVector()
+Longmatrix = Longmatrix()
+DoubleVector = DoubleVector()
+Doublematrix = Doublematrix()
+XRange = XRange()
+
+
+typedefs = {
+ IntType: Integer,
+ FloatType: Double,
+ StringType: String,
+ (nx.ArrayType,1,'i'): IntegerVector,
+ (nx.ArrayType,2,'i'): Integermatrix,
+ (nx.ArrayType,1,'l'): LongVector,
+ (nx.ArrayType,2,'l'): Longmatrix,
+ (nx.ArrayType,1,'d'): DoubleVector,
+ (nx.ArrayType,2,'d'): Doublematrix,
+ XRangeType : XRange,
+ }
+
+import math
+functiondefs = {
+ (len,(String,)):
+ Function_Descriptor(code='strlen(%s)',return_type=Integer),
+
+ (len,(LongVector,)):
+ Function_Descriptor(code='PyArray_Size((PyObject*)%s)',return_type=Integer),
+
+ (float,(Integer,)):
+ Function_Descriptor(code='(double)(%s)',return_type=Double),
+
+ (range,(Integer,Integer)):
+ Function_Descriptor(code='XRange(%s)',return_type=XRange),
+
+ (range,(Integer)):
+ Function_Descriptor(code='XRange(%s)',return_type=XRange),
+
+ (math.sin,(Double,)):
+ Function_Descriptor(code='sin(%s)',return_type=Double),
+
+ (math.cos,(Double,)):
+ Function_Descriptor(code='cos(%s)',return_type=Double),
+
+ (math.sqrt,(Double,)):
+ Function_Descriptor(code='sqrt(%s)',return_type=Double),
+ }
+
+
+
+##################################################################
+# FUNCTION LOOKUP_TYPE #
+##################################################################
+def lookup_type(x):
+ T = type(x)
+ try:
+ return typedefs[T]
+ except:
+ import scipy.base as nx
+ if isinstance(T,nx.ArrayType):
+ return typedefs[(T,len(x.shape),x.typecode())]
+ elif T == InstanceType:
+ return Instance(x)
+ else:
+ raise NotImplementedError,T
+
+##################################################################
+# class ACCELERATE #
+##################################################################
+class accelerate:
+
+ def __init__(self, function, *args, **kw):
+ assert type(function) == FunctionType
+ self.function = function
+ self.module = inspect.getmodule(function)
+ if self.module is None:
+ import __main__
+ self.module = __main__
+ self.__call_map = {}
+
+ def __cache(self,*args):
+ raise TypeError
+
+ def __call__(self,*args):
+ try:
+ return self.__cache(*args)
+ except TypeError:
+ # Figure out type info -- Do as tuple so its hashable
+ signature = tuple( map(lookup_type,args) )
+
+ # If we know the function, call it
+ try:
+ fast = self.__call_map[signature]
+ except:
+ fast = self.singleton(signature)
+ self.__cache = fast
+ self.__call_map[signature] = fast
+ return fast(*args)
+
+ def signature(self,*args):
+ # Figure out type info -- Do as tuple so its hashable
+ signature = tuple( map(lookup_type,args) )
+ return self.singleton(signature)
+
+
+ def singleton(self,signature):
+ identifier = self.identifier(signature)
+
+ # Generate a new function, then call it
+ f = self.function
+
+ # See if we have an accelerated version of module
+ try:
+ print 'lookup',self.module.__name__+'_weave'
+ accelerated_module = __import__(self.module.__name__+'_weave')
+ print 'have accelerated',self.module.__name__+'_weave'
+ fast = getattr(accelerated_module,identifier)
+ return fast
+ except ImportError:
+ accelerated_module = None
+ except AttributeError:
+ pass
+
+ P = self.accelerate(signature,identifier)
+
+ E = weave.ext_tools.ext_module(self.module.__name__+'_weave')
+ E.add_function(P)
+ E.generate_file()
+ weave.build_tools.build_extension(self.module.__name__+'_weave.cpp',verbose=2)
+
+ if accelerated_module:
+ raise NotImplementedError,'Reload'
+ else:
+ accelerated_module = __import__(self.module.__name__+'_weave')
+
+ fast = getattr(accelerated_module,identifier)
+ return fast
+
+ def identifier(self,signature):
+ # Build an MD5 checksum
+ f = self.function
+ co = f.func_code
+ identifier = str(signature)+\
+ str(co.co_argcount)+\
+ str(co.co_consts)+\
+ str(co.co_varnames)+\
+ co.co_code
+ return 'F'+md5.md5(identifier).hexdigest()
+
+ def accelerate(self,signature,identifier):
+ P = Python2CXX(self.function,signature,name=identifier)
+ return P
+
+ def code(self,*args):
+ if len(args) != self.function.func_code.co_argcount:
+ raise TypeError,'%s() takes exactly %d arguments (%d given)'%(
+ self.function.__name__,
+ self.function.func_code.co_argcount,
+ len(args))
+ signature = tuple( map(lookup_type,args) )
+ ident = self.function.__name__
+ return self.accelerate(signature,ident).function_code()
+
+
+##################################################################
+# CLASS PYTHON2CXX #
+##################################################################
+class Python2CXX(CXXCoder):
+ def typedef_by_value(self,v):
+ T = lookup_type(v)
+ if T not in self.used:
+ self.used.append(T)
+ return T
+
+ def function_by_signature(self,signature):
+ descriptor = functiondefs[signature]
+ if descriptor.return_type not in self.used:
+ self.used.append(descriptor.return_type)
+ return descriptor
+
+ def __init__(self,f,signature,name=None):
+ # Make sure function is a function
+ import types
+ assert type(f) == FunctionType
+ # and check the input type signature
+ assert reduce(lambda x,y: x and y,
+ map(lambda x: isinstance(x,Type_Descriptor),
+ signature),
+ 1),'%s not all type objects'%signature
+ self.arg_specs = []
+ self.customize = weave.base_info.custom_info()
+
+ CXXCoder.__init__(self,f,signature,name)
+
+ return
+
+ def function_code(self):
+ code = self.wrapped_code()
+ for T in self.used:
+ if T != None and T.module_init_code:
+ self.customize.add_module_init_code(T.module_init_code)
+ return code
+
+ def python_function_definition_code(self):
+ return '{ "%s", wrapper_%s, METH_VARARGS, %s },\n'%(
+ self.name,
+ self.name,
+ CStr(self.function.__doc__))