diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-09-11 23:22:54 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-09-11 23:22:54 +0000 |
commit | c4608ed1c831143e63c5df338c1c64acfa8a5343 (patch) | |
tree | 111b8a344fc530a390a8ea0a9e295366765ca71e | |
parent | 6686ee2a647f71d68cc6c85173ec7d701f366293 (diff) | |
download | numpy-c4608ed1c831143e63c5df338c1c64acfa8a5343.tar.gz |
Improve the getting and setting of ufunc loops for user-defined types.
-rw-r--r-- | numpy/core/include/numpy/ufuncobject.h | 10 | ||||
-rw-r--r-- | numpy/core/src/multiarraymodule.c | 1 | ||||
-rw-r--r-- | numpy/core/src/ufuncobject.c | 265 |
3 files changed, 192 insertions, 84 deletions
diff --git a/numpy/core/include/numpy/ufuncobject.h b/numpy/core/include/numpy/ufuncobject.h index 544e1d68c..735d1ea1c 100644 --- a/numpy/core/include/numpy/ufuncobject.h +++ b/numpy/core/include/numpy/ufuncobject.h @@ -187,6 +187,16 @@ typedef struct { PyObject *callable; } PyUFunc_PyFuncData; +/* A linked-list of function information for + user-defined 1-d loops. + */ +typedef struct _loop1d_info { + PyUFuncGenericFunction func; + void *data; + int *arg_types; + struct _loop1d_info *next; +} PyUFunc_Loop1d; + #include "__ufunc_api.h" diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c index 4f2f697a1..ec3cd88d2 100644 --- a/numpy/core/src/multiarraymodule.c +++ b/numpy/core/src/multiarraymodule.c @@ -1731,7 +1731,6 @@ PyArray_CanCoerceScalar(int thistype, int neededtype, } - /*OBJECT_API*/ static PyArrayObject ** PyArray_ConvertToCommonType(PyObject *op, int *retn) diff --git a/numpy/core/src/ufuncobject.c b/numpy/core/src/ufuncobject.c index 56b31ac8b..56b95cff1 100644 --- a/numpy/core/src/ufuncobject.c +++ b/numpy/core/src/ufuncobject.c @@ -616,7 +616,49 @@ _lowest_type(char intype) } } +static char *_types_msg = "function not supported for these types, " \ + "and can't coerce safely to supported types"; + +/* Called for non-NULL user-defined functions. + The object should be a CObject pointing to a linked-list of functions + storing the function, data, and signature of all user-defined functions. + There must be a match with the input argument types or an error + will occur. + */ +static int +_find_matching_userloop(PyObject *obj, int *arg_types, + PyArray_SCALARKIND *scalars, + PyUFuncGenericFunction *function, void **data, + int nargs) +{ + PyUFunc_Loop1d *funcdata; + int i; + funcdata = (PyUFunc_Loop1d *)PyCObject_AsVoidPtr(obj); + while (funcdata != NULL) { + for (i=0; i<nargs; i++) { + if (!PyArray_CanCoerceScalar(arg_types[i], + funcdata->arg_types[i], + scalars[i])) + break; + } + if (i==nargs) { /* match found */ + *function = funcdata->func; + *data = funcdata->data; + /* Make sure actual arg_types supported + by the loop are used */ + for (i=0; i<nargs; i++) { + arg_types[i] = funcdata->arg_types[i]; + } + return 0; + } + funcdata = funcdata->next; + } + PyErr_SetString(PyExc_TypeError, _types_msg); + return -1; +} + /* Called to determine coercion + Can change arg_types. */ static int @@ -639,8 +681,7 @@ select_types(PyUFuncObject *self, int *arg_types, if (userdef > 0) { PyObject *key, *obj; - int *this_types=NULL; - + int ret; obj = NULL; key = PyInt_FromLong((long) userdef); if (key == NULL) return -1; @@ -652,37 +693,13 @@ select_types(PyUFuncObject *self, int *arg_types, " with no registered loops"); return -1; } - if PyTuple_Check(obj) { - PyObject *item; - *function = (PyUFuncGenericFunction) \ - PyCObject_AsVoidPtr(PyTuple_GET_ITEM(obj,0)); - item = PyTuple_GET_ITEM(obj, 2); - if (PyCObject_Check(item)) { - *data = PyCObject_AsVoidPtr(item); - } - item = PyTuple_GET_ITEM(obj, 1); - if (PyCObject_Check(item)) { - this_types = PyCObject_AsVoidPtr(item); - } - } - else { - *function = (PyUFuncGenericFunction) \ - PyCObject_AsVoidPtr(obj); - *data = NULL; - } - - if (this_types == NULL) { - for (i=1; i<self->nargs; i++) { - arg_types[i] = userdef; - } - } - else { - for (i=1; i<self->nargs; i++) { - arg_types[i] = this_types[i]; - } - } - Py_DECREF(obj); - return 0; + /* extract the correct function + data and argtypes + */ + ret = _find_matching_userloop(obj, arg_types, scalars, + function, data, self->nargs); + Py_DECREF(obj); + return ret; } start_type = arg_types[0]; @@ -707,9 +724,7 @@ select_types(PyUFuncObject *self, int *arg_types, if (j == self->nin) break; } if(i>=self->ntypes) { - PyErr_SetString(PyExc_TypeError, - "function not supported for these types, "\ - "and can't coerce safely to supported types"); + PyErr_SetString(PyExc_TypeError, _types_msg); return -1; } for(j=0; j<self->nargs; j++) @@ -876,7 +891,7 @@ _has_reflected_op(PyObject *op, char *name) #undef _GETATTR_ static int -construct_matrices(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps) +construct_arrays(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps) { int nargs, i, maxsize; int arg_types[NPY_MAXARGS]; @@ -963,7 +978,8 @@ construct_matrices(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps) } } - /* Create copies for some of the arrays if appropriate */ + /* Create copies for some of the arrays if they are small + enough and not already contiguous */ if (_create_copies(loop, arg_types, mps) < 0) return -1; /* Create Iterators for the Inputs */ @@ -1344,8 +1360,8 @@ construct_loop(PyUFuncObject *self, PyObject *args, PyArrayObject **mps) &(loop->bufsize), &(loop->errormask), &(loop->errobj)) < 0) goto fail; - /* Setup the matrices */ - if (construct_matrices(loop, args, mps) < 0) goto fail; + /* Setup the arrays */ + if (construct_arrays(loop, args, mps) < 0) goto fail; PyUFunc_clearfperr(); @@ -1421,7 +1437,7 @@ _printcastbuf(PyUFuncLoopObject *loop, int bufnum) /* This generic function is called with the ufunc object, the arguments to it, and an array of (pointers to) PyArrayObjects which are NULL. The - arguments are parsed and placed in mps in construct_loop (construct_matrices) + arguments are parsed and placed in mps in construct_loop (construct_arrays) */ /*UFUNC_API*/ @@ -3087,6 +3103,45 @@ PyUFunc_FromFuncAndData(PyUFuncGenericFunction *func, void **data, return (PyObject *)self; } +typedef struct { + PyObject_HEAD + void *c_obj; +} _simple_cobj; + +#define _SETCPTR(cobj, val) ((_simple_cobj *)(cobj))->c_obj = (val) + +/* return 1 if arg1 > arg2, 0 if arg1 == arg2, and -1 if arg1 < arg2 + */ +static int +cmp_arg_types(int *arg1, int *arg2, int n) +{ + while (n--) { + if (*arg1 > *arg2) + return 1; + else if (*arg1 < *arg2) + return -1; + arg1++; arg2++; + } + return 0; +} + +/* This frees the linked-list structure + when the CObject is destroyed (removed + from the internal dictionary) +*/ +static void +_loop1d_list_free(void *ptr) +{ + PyUFunc_Loop1d *funcdata; + if (ptr == NULL) return; + funcdata = (PyUFunc_Loop1d *)ptr; + if (funcdata == NULL) return; + _pya_free(funcdata->arg_types); + _loop1d_list_free(funcdata->next); + _pya_free(funcdata); +} + + /*UFUNC_API*/ static int PyUFunc_RegisterLoopForType(PyUFuncObject *ufunc, @@ -3096,8 +3151,11 @@ PyUFunc_RegisterLoopForType(PyUFuncObject *ufunc, void *data) { PyArray_Descr *descr; + PyUFunc_Loop1d *funcdata; PyObject *key, *cobj; - int ret; + int i; + int *newtypes=NULL; + descr=PyArray_DescrFromType(usertype); if ((usertype < PyArray_USERDEF) || (descr==NULL)) { @@ -3112,49 +3170,90 @@ PyUFunc_RegisterLoopForType(PyUFuncObject *ufunc, } key = PyInt_FromLong((long) usertype); if (key == NULL) return -1; - cobj = PyCObject_FromVoidPtr((void *)function, NULL); - if (cobj == NULL) {Py_DECREF(key); return -1;} - if (data == NULL && arg_types == NULL) { - ret = PyDict_SetItem(ufunc->userloops, key, cobj); - Py_DECREF(cobj); - Py_DECREF(key); - return ret; - } - else { - PyObject *cobj2, *cobj3, *tmp; - if (arg_types == NULL) { - cobj2 = Py_None; - Py_INCREF(cobj2); - } - else { - cobj2 = PyCObject_FromVoidPtr((void *)arg_types, NULL); - if (cobj2 == NULL) { - Py_DECREF(cobj); - Py_DECREF(key); - return -1; - } - } - if (data == NULL) { - cobj3 = Py_None; - Py_INCREF(cobj3); - } - else { - cobj3 = PyCObject_FromVoidPtr(data, NULL); - if (cobj3 == NULL) { - Py_DECREF(cobj2); - Py_DECREF(cobj); - Py_DECREF(key); - return -1; - } - } - tmp=Py_BuildValue("NNN", cobj, cobj2, cobj3); - ret = PyDict_SetItem(ufunc->userloops, key, tmp); - Py_DECREF(tmp); - Py_DECREF(key); - return ret; - } + funcdata = _pya_malloc(sizeof(PyUFunc_Loop1d)); + if (funcdata == NULL) goto fail; + newtypes = _pya_malloc(sizeof(int)*ufunc->nargs); + if (newtypes == NULL) goto fail; + if (arg_types != NULL) { + for (i=0; i<ufunc->nargs; i++) { + newtypes[i] = arg_types[i]; + } + } + else { + for (i=0; i<ufunc->nargs; i++) { + newtypes[i] = usertype; + } + } + + funcdata->func = function; + funcdata->arg_types = newtypes; + funcdata->data = data; + funcdata->next = NULL; + + /* Get entry for this user-defined type*/ + cobj = PyDict_GetItem(ufunc->userloops, key); + + /* If it's not there, then make one and return. */ + if (cobj == NULL) { + cobj = PyCObject_FromVoidPtr((void *)function, + _loop1d_list_free); + if (cobj == NULL) goto fail; + PyDict_SetItem(ufunc->userloops, key, cobj); + Py_DECREF(cobj); + Py_DECREF(key); + return 0; + } + else { + PyUFunc_Loop1d *current, *prev=NULL; + int cmp; + /* There is already at least 1 loop. Place this one in + lexicographic order. If the next one signature + is exactly like this one, then just replace. + Otherwise insert. + */ + current = (PyUFunc_Loop1d *)PyCObject_AsVoidPtr(cobj); + while (current != NULL) { + cmp = cmp_arg_types(current->arg_types, newtypes, + ufunc->nargs); + if (cmp >= 0) break; + prev = current; + current = current->next; + } + if (cmp == 0) { /* just replace it with new function */ + current->func = function; + current->data = data; + _pya_free(newtypes); + _pya_free(funcdata); + } + else { /* insert it before the current one + by hacking the internals of cobject to + replace the function pointer --- + can't use API because destructor is set. + */ + funcdata->next = current; + if (prev == NULL) { /* place this at front */ + _SETCPTR(cobj, funcdata); + } + else { + prev->next = funcdata; + } + } + } + Py_DECREF(key); + return 0; + + + fail: + Py_DECREF(key); + _pya_free(funcdata); + _pya_free(newtypes); + if (!PyErr_Occurred()) PyErr_NoMemory(); + return -1; } +#undef _SETCPTR + + static void ufunc_dealloc(PyUFuncObject *self) { |