diff options
author | Mark Wiebe <mwwiebe@gmail.com> | 2011-01-20 00:29:12 -0800 |
---|---|---|
committer | Mark Wiebe <mwwiebe@gmail.com> | 2011-01-20 00:29:12 -0800 |
commit | fa08ee2832e68f91a4ea38aa8d1d55dbb143cfd4 (patch) | |
tree | 6b53141afca75cb2cc95c49a1d3502e609bd5cde /numpy | |
parent | 58f33f8805f2aebf78e977dbe89298ad33e1060d (diff) | |
download | numpy-fa08ee2832e68f91a4ea38aa8d1d55dbb143cfd4.tar.gz |
ENH: ufunc: Simple iterator-based ufunc execution is working
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/lowlevel_strided_loops.h | 4 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 209 |
2 files changed, 173 insertions, 40 deletions
diff --git a/numpy/core/src/multiarray/lowlevel_strided_loops.h b/numpy/core/src/multiarray/lowlevel_strided_loops.h index e0745c4ea..6e17e3fd7 100644 --- a/numpy/core/src/multiarray/lowlevel_strided_loops.h +++ b/numpy/core/src/multiarray/lowlevel_strided_loops.h @@ -373,10 +373,10 @@ PyArray_TransferStridedToNDim(npy_intp ndim, npy_intp size2 = PyArray_SIZE(arr2); \ npy_intp size3 = PyArray_SIZE(arr3); \ count = size1 > size2 ? size1 : size2; \ - count = size3 > count ? size3 : count; \ + if (size3 > count) count = size3; \ data1 = PyArray_BYTES(arr1); \ data2 = PyArray_BYTES(arr2); \ - data2 = PyArray_BYTES(arr3); \ + data3 = PyArray_BYTES(arr3); \ stride1 = (size1 == 1 ? 0 : \ (PyArray_CHKFLAGS(arr1, NPY_FORTRAN) ? \ PyArray_STRIDE(arr1, 0) : \ diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index c42a71f9e..a7a0d307d 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -3506,13 +3506,15 @@ trivial_two_operand_loop(PyArrayObject **op, void *innerloopdata) { char *data[2]; - npy_intp count, stride[2]; + npy_intp count[2], stride[2]; PyArray_PREPARE_TRIVIAL_PAIR_ITERATION(op[0], op[1], - count, + count[0], data[0], data[1], stride[0], stride[1]); - innerloop(data, &count, stride, innerloopdata); + count[1] = count[0]; + //printf ("loop count %d\n", (int)count[0]); + innerloop(data, count, stride, innerloopdata); } static void @@ -3521,13 +3523,97 @@ trivial_three_operand_loop(PyArrayObject **op, void *innerloopdata) { char *data[3]; - npy_intp count, stride[3]; + npy_intp count[3], stride[3]; PyArray_PREPARE_TRIVIAL_TRIPLE_ITERATION(op[0], op[1], op[2], - count, + count[0], data[0], data[1], data[2], stride[0], stride[1], stride[2]); - innerloop(data, &count, stride, innerloopdata); + count[1] = count[0]; + count[2] = count[0]; + //printf ("loop count %d\n", (int)count[0]); + innerloop(data, count, stride, innerloopdata); +} + +static int +iterator_loop(npy_intp nin, npy_intp nout, + PyArrayObject **op, + PyArray_Descr **dtype, + NPY_ORDER order, + NPY_CASTING casting, + npy_intp buffersize, + PyUFuncGenericFunction innerloop, + void *innerloopdata) +{ + npy_intp i, niter = nin + nout; + npy_uint32 op_flags[NPY_MAXARGS]; + NpyIter *iter; + + NpyIter_IterNext_Fn iternext; + char **dataptr; + npy_intp *stride; + npy_intp *count_ptr; + npy_intp copied_counts[NPY_MAXARGS]; + + PyArrayObject **op_it; + + /* Set up the flags */ + for (i = 0; i < nin; ++i) { + op_flags[i] = NPY_ITER_READONLY| + NPY_ITER_ALIGNED; + } + for (i = nin; i < niter; ++i) { + op_flags[i] = NPY_ITER_WRITEONLY| + NPY_ITER_ALIGNED| + NPY_ITER_ALLOCATE| + NPY_ITER_NO_BROADCAST| + NPY_ITER_NO_SUBTYPE; + } + + /* Allocate the iterator */ + iter = NpyIter_MultiNew(niter, op, + NPY_ITER_NO_INNER_ITERATION| + NPY_ITER_BUFFERED| + NPY_ITER_GROWINNER, + order, casting, + op_flags, dtype, + 0, NULL, buffersize); + if (iter == NULL) { + return NPY_FAIL; + } + + /* Prepare for the loop */ + iternext = NpyIter_GetIterNext(iter, NULL); + if (iternext == NULL) { + NpyIter_Deallocate(iter); + return NPY_FAIL; + } + dataptr = NpyIter_GetDataPtrArray(iter); + stride = NpyIter_GetInnerStrideArray(iter); + count_ptr = NpyIter_GetInnerLoopSizePtr(iter); + + /* Execute the loop */ + do { + /* Copy the loop count a few times... :( */ + npy_intp count = *count_ptr; + for (i = 0; i < niter; ++i) { + copied_counts[i] = count; + } + //printf ("loop count %d\n", (int)count); + innerloop(dataptr, count_ptr, stride, innerloopdata); + } while (iternext(iter)); + + /* Copy any allocated outputs */ + op_it = NpyIter_GetObjectArray(iter); + for (i = nin; i < niter; ++i) { + if (op[i] == NULL) { + op[i] = op_it[i]; + Py_INCREF(op[i]); + } + } + + NpyIter_Deallocate(iter); + return NPY_SUCCEED; } static PyObject * @@ -3535,7 +3621,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) { npy_intp nargs, nin = self->nin, nout = self->nout; npy_intp i, j, niter = nin + nout; - PyObject *obj, *context; + PyObject *obj, *context, *ret = NULL; char *ufunc_name; /* This contains the all the inputs and outputs */ PyArrayObject *op[NPY_MAXARGS]; @@ -3544,7 +3630,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) PyUFuncGenericFunction innerloop = NULL; void *innerloopdata = NULL; - int no_castable_output, all_inputs_scalar, all_dtypes_match; + int no_castable_output, all_inputs_scalar, trivial_loop_ok = 0; /* TODO: For 1.6, the default should probably be NPY_CORDER */ NPY_ORDER order = NPY_KEEPORDER; @@ -3592,7 +3678,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) op[i] = (PyArrayObject *)PyArray_FromAny(obj, NULL, 0, 0, 0, context); Py_XDECREF(context); if (op[i] == NULL) { - goto fail; + goto finish; } /* If there's a non-scalar input, indicate so */ if (PyArray_NDIM(op[i]) > 0) { @@ -3612,7 +3698,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) if (!PyArray_ISWRITEABLE(obj)) { PyErr_SetString(PyExc_ValueError, "return array is not writeable"); - goto fail; + goto finish; } Py_INCREF(obj); op[i] = (PyArrayObject *)obj; @@ -3621,7 +3707,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) PyErr_SetString(PyExc_TypeError, "return arrays must be " "of ArrayType"); - goto fail; + goto finish; } } @@ -3639,14 +3725,14 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) if (PyString_AsStringAndSize(key, &str, &length) == -1) { PyErr_SetString(PyExc_TypeError, "invalid keyword argument"); - goto fail; + goto finish; } switch (str[0]) { case 'c': if (strncmp(str,"casting",7) == 0) { if (!PyArray_CastingConverter(value, &casting)) { - goto fail; + goto finish; } bad_arg = 0; } @@ -3663,14 +3749,14 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) PyErr_SetString(PyExc_ValueError, "cannot specify 'out' as both a positional " "and keyword argument"); - goto fail; + goto finish; } else { if (PyArray_Check(value)) { if (!PyArray_ISWRITEABLE(value)) { PyErr_SetString(PyExc_ValueError, "return array is not writeable"); - goto fail; + goto finish; } Py_INCREF(value); op[nin] = (PyArrayObject *)value; @@ -3679,14 +3765,14 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) PyErr_SetString(PyExc_TypeError, "return arrays must be " "of ArrayType"); - goto fail; + goto finish; } } } else if (strncmp(str,"order",5) == 0) { if (!PyArray_OrderConverter(value, &order)) { - goto fail; + goto finish; } bad_arg = 0; } @@ -3702,7 +3788,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) if (bad_arg) { char *format = "'%s' is an invalid keyword to %s"; PyErr_Format(PyExc_TypeError, format, str, ufunc_name); - goto fail; + goto finish; } } } @@ -3767,6 +3853,9 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) for (j = nin; j < niter; ++j) { if (op[j] != NULL) { PyArray_Descr *tmp = PyArray_DescrFromType(types[j]); + if (tmp == NULL) { + goto finish; + } if (!PyArray_CanCastTypeTo(tmp, PyArray_DESCR(op[j]), output_casting)) { Py_DECREF(tmp); @@ -3779,14 +3868,22 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) } /* If all the tests passed, we found our function */ if (j == niter) { - all_dtypes_match = 1; + trivial_loop_ok = 1; /* Fill the dtypes array */ for (j = 0; j < niter; ++j) { dtype[j] = PyArray_DescrFromType(types[j]); - /* If the dtype doesn't match, indicate so */ - if (all_dtypes_match && op[j] != NULL && - !PyArray_EquivTypes(dtype[j], PyArray_DESCR(op[j]))) { - all_dtypes_match = 0; + if (dtype[j] == NULL) { + goto finish; + } + /* + * If the dtype doesn't match, or the array isn't aligned, + * indicate that the trivial loop can't be done. + */ + if (trivial_loop_ok && op[j] != NULL && + (!PyArray_ISALIGNED(op[j]) || + !PyArray_EquivTypes(dtype[j], PyArray_DESCR(op[j])) + )) { + trivial_loop_ok = 0; } } /* Save the inner loop and its data */ @@ -3813,10 +3910,10 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) "inputs could not be safely coerced to any supported " "types"); } - return NULL; + goto finish; } - +#if 0 printf("input types:\n"); for (i = 0; i < nin; ++i) { PyObject_Print((PyObject *)dtype[i], stdout, 0); @@ -3828,9 +3925,10 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) printf(" "); } printf("\n"); +#endif /* First check for the trivial cases that don't need an iterator */ - if (all_dtypes_match && niter == 2 && op[1] == NULL && + if (trivial_loop_ok && niter == 2 && op[1] == NULL && (order == NPY_ANYORDER || order == NPY_KEEPORDER) && PyArray_TRIVIALLY_ITERABLE(op[0])) { Py_INCREF(dtype[1]); @@ -3841,18 +3939,18 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) NULL, NULL, PyArray_ISFORTRAN(op[0]) ? NPY_F_CONTIGUOUS : 0, NULL); - printf("trivial 1 input with allocated output\n"); + //printf("trivial 1 input with allocated output\n"); trivial_two_operand_loop(op, innerloop, innerloopdata); } - else if (all_dtypes_match && niter == 2 && op[1] != NULL && + else if (trivial_loop_ok && niter == 2 && op[1] != NULL && PyArray_NDIM(op[1]) >= PyArray_NDIM(op[0]) && PyArray_TRIVIALLY_ITERABLE_PAIR(op[0], op[1])) { - printf("trivial 1 input\n"); + //printf("trivial 1 input\n"); trivial_two_operand_loop(op, innerloop, innerloopdata); } - else if (all_dtypes_match && nin == 2 && nout == 1 && op[2] == NULL && + else if (trivial_loop_ok && nin == 2 && nout == 1 && op[2] == NULL && (order == NPY_ANYORDER || order == NPY_KEEPORDER) && PyArray_TRIVIALLY_ITERABLE_PAIR(op[0], op[1])) { PyArrayObject *tmp; @@ -3874,32 +3972,67 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) NULL, NULL, PyArray_ISFORTRAN(tmp) ? NPY_F_CONTIGUOUS : 0, NULL); - printf("trivial 2 input with allocated output\n"); + //printf("trivial 2 input with allocated output\n"); trivial_three_operand_loop(op, innerloop, innerloopdata); } - else if (all_dtypes_match && niter == 3 && + else if (trivial_loop_ok && niter == 3 && op[1] != NULL && op[2] != NULL && PyArray_NDIM(op[2]) >= PyArray_NDIM(op[0]) && PyArray_NDIM(op[2]) >= PyArray_NDIM(op[1]) && PyArray_TRIVIALLY_ITERABLE_TRIPLE(op[0], op[1], op[2])) { - printf("trivial 2 input\n"); + //printf("trivial 2 input\n"); trivial_three_operand_loop(op, innerloop, innerloopdata); } /* Otherwise an iterator is required to resolve broadcasting, etc */ else { - printf("iterator case\n"); + npy_intp buffersize = 0; + + //printf("iterator loop\n"); + + if (iterator_loop(nin, nout, op, dtype, order, + casting, buffersize, innerloop, innerloopdata) + != NPY_SUCCEED) { + goto finish; + } } +#if 0 + for (i = 0; i < nin; ++i) { + printf("input %d: (%p)\n", (int)i, op[i]); + PyObject_Print((PyObject *)op[i], stdout, 0); + printf("\n"); + } + for (i = nin; i < niter; ++i) { + printf("output %d: (%p)\n", (int)i, op[i]); + PyObject_Print((PyObject *)op[i], stdout, 0); + printf("\n"); + } + printf("\n"); +#endif - PyErr_SetString(PyExc_RuntimeError, "implementation is not finished!"); -fail: + if (nout == 1) { + ret = op[nin]; + Py_INCREF(ret); + } + else { + ret = PyTuple_New(nout); + if (ret == NULL) { + goto finish; + } + for (i = 0; i < nout; ++i) { + PyTuple_SET_ITEM(ret, i, op[nin+i]); + Py_INCREF(op[nin+i]); + } + } + +finish: for (i = 0; i < niter; ++i) { Py_XDECREF(op[i]); Py_XDECREF(dtype[i]); } - return NULL; + return ret; } static PyObject * @@ -4595,7 +4728,7 @@ static struct PyMethodDef ufunc_methods[] = { {"outer", (PyCFunction)ufunc_outer, METH_VARARGS | METH_KEYWORDS, NULL}, - {"gennew", /* new generic call */ + {"f", /* new generic call */ (PyCFunction)ufunc_generic_call_iter, METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL, 0, NULL} /* sentinel */ |