summaryrefslogtreecommitdiff
path: root/weave/standard_array_spec.py
diff options
context:
space:
mode:
Diffstat (limited to 'weave/standard_array_spec.py')
-rw-r--r--weave/standard_array_spec.py215
1 files changed, 137 insertions, 78 deletions
diff --git a/weave/standard_array_spec.py b/weave/standard_array_spec.py
index 181b72f18..0509835a7 100644
--- a/weave/standard_array_spec.py
+++ b/weave/standard_array_spec.py
@@ -1,93 +1,152 @@
-from base_spec import base_converter
-from scalar_spec import numeric_to_c_type_mapping
+from c_spec import common_base_converter
+from c_spec import num_to_c_types
from Numeric import *
from types import *
import os
-import standard_array_info
-class array_converter(base_converter):
- _build_information = [standard_array_info.array_info()]
+num_typecode = {}
+num_typecode['c'] = 'PyArray_CHAR'
+num_typecode['1'] = 'PyArray_SBYTE'
+num_typecode['b'] = 'PyArray_UBYTE'
+num_typecode['s'] = 'PyArray_SHORT'
+num_typecode['i'] = 'PyArray_INT' # PyArray_INT has troubles ?? What does this note mean ??
+num_typecode['l'] = 'PyArray_LONG'
+num_typecode['f'] = 'PyArray_FLOAT'
+num_typecode['d'] = 'PyArray_DOUBLE'
+num_typecode['F'] = 'PyArray_CFLOAT'
+num_typecode['D'] = 'PyArray_CDOUBLE'
+
+type_check_code = \
+"""
+class numpy_type_handler
+{
+public:
+ void conversion_numpy_check_type(PyArrayObject* arr_obj, int numeric_type,
+ const char* name)
+ {
+ // Make sure input has correct numeric type.
+ // allow character and byte to match
+ // also allow int and long to match
+ int arr_type = arr_obj->descr->type_num;
+ if ( arr_type != numeric_type &&
+ !(numeric_type == PyArray_CHAR && arr_type == PyArray_SBYTE) &&
+ !(numeric_type == PyArray_SBYTE && arr_type == PyArray_CHAR) &&
+ !(numeric_type == PyArray_INT && arr_type == PyArray_LONG) &&
+ !(numeric_type == PyArray_LONG && arr_type == PyArray_INT))
+ {
+ char* type_names[13] = {"char","unsigned byte","byte", "short", "int",
+ "long", "float", "double", "complex float",
+ "complex double", "object","ntype","unkown"};
+ char msg[500];
+ sprintf(msg,"Conversion Error: received '%s' typed array instead of '%s' typed array for variable '%s'",
+ type_names[arr_type],type_names[numeric_type],name);
+ throw_error(PyExc_TypeError,msg);
+ }
+ }
- def type_match(self,value):
- return type(value) is ArrayType
+ void numpy_check_type(PyArrayObject* arr_obj, int numeric_type, const char* name)
+ {
+ // Make sure input has correct numeric type.
+ int arr_type = arr_obj->descr->type_num;
+ if ( arr_type != numeric_type &&
+ !(numeric_type == PyArray_CHAR && arr_type == PyArray_SBYTE) &&
+ !(numeric_type == PyArray_SBYTE && arr_type == PyArray_CHAR) &&
+ !(numeric_type == PyArray_INT && arr_type == PyArray_LONG) &&
+ !(numeric_type == PyArray_LONG && arr_type == PyArray_INT))
+ {
+ char* type_names[13] = {"char","unsigned byte","byte", "short", "int",
+ "long", "float", "double", "complex float",
+ "complex double", "object","ntype","unkown"};
+ char msg[500];
+ sprintf(msg,"received '%s' typed array instead of '%s' typed array for variable '%s'",
+ type_names[arr_type],type_names[numeric_type],name);
+ throw_error(PyExc_TypeError,msg);
+ }
+ }
+};
- def type_spec(self,name,value):
- # factory
- new_spec = array_converter()
- new_spec.name = name
- new_spec.numeric_type = value.typecode()
- # dims not used, but here for compatibility with blitz_spec
- new_spec.dims = len(shape(value))
- return new_spec
+numpy_type_handler x__numpy_type_handler = numpy_type_handler();
+#define conversion_numpy_check_type x__numpy_type_handler.conversion_numpy_check_type
+#define numpy_check_type x__numpy_type_handler.numpy_check_type
- def declaration_code(self,templatize = 0,inline=0):
- if inline:
- code = self.inline_decl_code()
- else:
- code = self.standard_decl_code()
- return code
+"""
+
+size_check_code = \
+"""
+class numpy_size_handler
+{
+public:
+ void conversion_numpy_check_size(PyArrayObject* arr_obj, int Ndims,
+ const char* name)
+ {
+ if (arr_obj->nd != Ndims)
+ {
+ char msg[500];
+ sprintf(msg,"Conversion Error: received '%d' dimensional array instead of '%d' dimensional array for variable '%s'",
+ arr_obj->nd,Ndims,name);
+ throw_error(PyExc_TypeError,msg);
+ }
+ }
- def inline_decl_code(self):
- type = numeric_to_c_type_mapping[self.numeric_type]
- name = self.name
- #dims = self.dims
- var_name = self.retrieve_py_variable(inline=1)
- templ = '// %(name)s array declaration\n' \
- 'py_%(name)s= %(var_name)s;\n' \
- 'PyArrayObject* %(name)s = convert_to_numpy(py_%(name)s,"%(name)s");\n' \
- 'conversion_numpy_check_type(%(name)s,py_type<%(type)s>::code,"%(name)s");\n' \
- 'int* _N%(name)s = %(name)s->dimensions;\n' \
- 'int* _S%(name)s = %(name)s->strides;\n' \
- 'int _D%(name)s = %(name)s->nd;\n' \
- '%(type)s* %(name)s_data = (%(type)s*) %(name)s->data;\n'
- code = templ % locals()
- return code
+ void numpy_check_size(PyArrayObject* arr_obj, int Ndims, const char* name)
+ {
+ if (arr_obj->nd != Ndims)
+ {
+ char msg[500];
+ sprintf(msg,"received '%d' dimensional array instead of '%d' dimensional array for variable '%s'",
+ arr_obj->nd,Ndims,name);
+ throw_error(PyExc_TypeError,msg);
+ }
+ }
+};
- def standard_decl_code(self):
- type = numeric_to_c_type_mapping[self.numeric_type]
- name = self.name
- templ = '// %(name)s array declaration\n' \
- 'PyArrayObject* %(name)s = convert_to_numpy(py_%(name)s,"%(name)s");\n' \
- 'conversion_numpy_check_type(%(name)s,py_type<%(type)s>::code,"%(name)s");\n' \
- 'int* _N%(name)s = %(name)s->dimensions;\n' \
- 'int* _S%(name)s = %(name)s->strides;\n' \
- 'int _D%(name)s = %(name)s->nd;\n' \
- '%(type)s* %(name)s_data = (%(type)s*) %(name)s->data;\n'
- code = templ % locals()
- return code
- #def c_function_declaration_code(self):
- # """
- # This doesn't pass the size through. That info is gonna have to
- # be redone in the c function.
- # """
- # templ_dict = {}
- # templ_dict['type'] = numeric_to_c_type_mapping[self.numeric_type]
- # templ_dict['dims'] = self.dims
- # templ_dict['name'] = self.name
- # code = 'blitz::Array<%(type)s,%(dims)d> &%(name)s' % templ_dict
- # return code
-
- def local_dict_code(self):
- code = '// for now, array "%s" is not returned as arryas are edited' \
- ' in place (should this change?)\n' % (self.name)
- return code
+numpy_size_handler x__numpy_size_handler = numpy_size_handler();
+#define conversion_numpy_check_size x__numpy_size_handler.conversion_numpy_check_size
+#define numpy_check_size x__numpy_size_handler.numpy_check_size
- def cleanup_code(self):
- # could use Py_DECREF here I think and save NULL test.
- code = "Py_XDECREF(py_%s);\n" % self.name
- return code
+"""
- def __repr__(self):
- msg = "(array:: name: %s, type: %s)" % \
- (self.name, self.numeric_type)
- return msg
+numeric_init_code = \
+"""
+Py_Initialize();
+import_array();
+PyImport_ImportModule("Numeric");
+"""
+
+class array_converter(common_base_converter):
- def __cmp__(self,other):
- #only works for equal
- return cmp(self.name,other.name) or \
- cmp(self.numeric_type,other.numeric_type) or \
- cmp(self.dims, other.dims) or \
- cmp(self.__class__, other.__class__)
+ def init_info(self):
+ common_base_converter.init_info(self)
+ self.type_name = 'numpy'
+ self.check_func = 'PyArray_Check'
+ self.c_type = 'PyArrayObject*'
+ self.to_c_return = '(PyArrayObject*) py_obj'
+ self.matching_types = [ArrayType]
+ self.headers = ['"Numeric/arrayobject.h"','<complex>','<math.h>']
+ self.support_code = [size_check_code, type_check_code]
+ self.module_init_code = [numeric_init_code]
+
+ def get_var_type(self,value):
+ return value.typecode()
+
+ def template_vars(self,inline=0):
+ res = common_base_converter.template_vars(self,inline)
+ if hasattr(self,'var_type'):
+ res['num_type'] = num_to_c_types[self.var_type]
+ res['num_typecode'] = num_typecode[self.var_type]
+ res['array_name'] = self.name + "_array"
+ return res
+
+ def declaration_code(self,templatize = 0,inline=0):
+ code = '%(py_var)s = %(var_lookup)s;\n' \
+ '%(c_type)s %(array_name)s = %(var_convert)s;\n' \
+ 'conversion_numpy_check_type(%(array_name)s,%(num_typecode)s,"%(name)s");\n' \
+ 'int* N%(name)s = %(array_name)s->dimensions;\n' \
+ 'int* S%(name)s = %(array_name)s->strides;\n' \
+ 'int D%(name)s = %(array_name)s->nd;\n' \
+ '%(num_type)s* %(name)s = (%(num_type)s*) %(array_name)s->data;\n'
+ code = code % self.template_vars(inline=inline)
+ return code
def test(level=10):
from scipy_base.testing import module_test