summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/lowlevel_strided_loops.h4
-rw-r--r--numpy/core/src/umath/ufunc_object.c209
2 files changed, 173 insertions, 40 deletions
diff --git a/numpy/core/src/multiarray/lowlevel_strided_loops.h b/numpy/core/src/multiarray/lowlevel_strided_loops.h
index e0745c4ea..6e17e3fd7 100644
--- a/numpy/core/src/multiarray/lowlevel_strided_loops.h
+++ b/numpy/core/src/multiarray/lowlevel_strided_loops.h
@@ -373,10 +373,10 @@ PyArray_TransferStridedToNDim(npy_intp ndim,
npy_intp size2 = PyArray_SIZE(arr2); \
npy_intp size3 = PyArray_SIZE(arr3); \
count = size1 > size2 ? size1 : size2; \
- count = size3 > count ? size3 : count; \
+ if (size3 > count) count = size3; \
data1 = PyArray_BYTES(arr1); \
data2 = PyArray_BYTES(arr2); \
- data2 = PyArray_BYTES(arr3); \
+ data3 = PyArray_BYTES(arr3); \
stride1 = (size1 == 1 ? 0 : \
(PyArray_CHKFLAGS(arr1, NPY_FORTRAN) ? \
PyArray_STRIDE(arr1, 0) : \
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index c42a71f9e..a7a0d307d 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -3506,13 +3506,15 @@ trivial_two_operand_loop(PyArrayObject **op,
void *innerloopdata)
{
char *data[2];
- npy_intp count, stride[2];
+ npy_intp count[2], stride[2];
PyArray_PREPARE_TRIVIAL_PAIR_ITERATION(op[0], op[1],
- count,
+ count[0],
data[0], data[1],
stride[0], stride[1]);
- innerloop(data, &count, stride, innerloopdata);
+ count[1] = count[0];
+ //printf ("loop count %d\n", (int)count[0]);
+ innerloop(data, count, stride, innerloopdata);
}
static void
@@ -3521,13 +3523,97 @@ trivial_three_operand_loop(PyArrayObject **op,
void *innerloopdata)
{
char *data[3];
- npy_intp count, stride[3];
+ npy_intp count[3], stride[3];
PyArray_PREPARE_TRIVIAL_TRIPLE_ITERATION(op[0], op[1], op[2],
- count,
+ count[0],
data[0], data[1], data[2],
stride[0], stride[1], stride[2]);
- innerloop(data, &count, stride, innerloopdata);
+ count[1] = count[0];
+ count[2] = count[0];
+ //printf ("loop count %d\n", (int)count[0]);
+ innerloop(data, count, stride, innerloopdata);
+}
+
+static int
+iterator_loop(npy_intp nin, npy_intp nout,
+ PyArrayObject **op,
+ PyArray_Descr **dtype,
+ NPY_ORDER order,
+ NPY_CASTING casting,
+ npy_intp buffersize,
+ PyUFuncGenericFunction innerloop,
+ void *innerloopdata)
+{
+ npy_intp i, niter = nin + nout;
+ npy_uint32 op_flags[NPY_MAXARGS];
+ NpyIter *iter;
+
+ NpyIter_IterNext_Fn iternext;
+ char **dataptr;
+ npy_intp *stride;
+ npy_intp *count_ptr;
+ npy_intp copied_counts[NPY_MAXARGS];
+
+ PyArrayObject **op_it;
+
+ /* Set up the flags */
+ for (i = 0; i < nin; ++i) {
+ op_flags[i] = NPY_ITER_READONLY|
+ NPY_ITER_ALIGNED;
+ }
+ for (i = nin; i < niter; ++i) {
+ op_flags[i] = NPY_ITER_WRITEONLY|
+ NPY_ITER_ALIGNED|
+ NPY_ITER_ALLOCATE|
+ NPY_ITER_NO_BROADCAST|
+ NPY_ITER_NO_SUBTYPE;
+ }
+
+ /* Allocate the iterator */
+ iter = NpyIter_MultiNew(niter, op,
+ NPY_ITER_NO_INNER_ITERATION|
+ NPY_ITER_BUFFERED|
+ NPY_ITER_GROWINNER,
+ order, casting,
+ op_flags, dtype,
+ 0, NULL, buffersize);
+ if (iter == NULL) {
+ return NPY_FAIL;
+ }
+
+ /* Prepare for the loop */
+ iternext = NpyIter_GetIterNext(iter, NULL);
+ if (iternext == NULL) {
+ NpyIter_Deallocate(iter);
+ return NPY_FAIL;
+ }
+ dataptr = NpyIter_GetDataPtrArray(iter);
+ stride = NpyIter_GetInnerStrideArray(iter);
+ count_ptr = NpyIter_GetInnerLoopSizePtr(iter);
+
+ /* Execute the loop */
+ do {
+ /* Copy the loop count a few times... :( */
+ npy_intp count = *count_ptr;
+ for (i = 0; i < niter; ++i) {
+ copied_counts[i] = count;
+ }
+ //printf ("loop count %d\n", (int)count);
+ innerloop(dataptr, count_ptr, stride, innerloopdata);
+ } while (iternext(iter));
+
+ /* Copy any allocated outputs */
+ op_it = NpyIter_GetObjectArray(iter);
+ for (i = nin; i < niter; ++i) {
+ if (op[i] == NULL) {
+ op[i] = op_it[i];
+ Py_INCREF(op[i]);
+ }
+ }
+
+ NpyIter_Deallocate(iter);
+ return NPY_SUCCEED;
}
static PyObject *
@@ -3535,7 +3621,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
{
npy_intp nargs, nin = self->nin, nout = self->nout;
npy_intp i, j, niter = nin + nout;
- PyObject *obj, *context;
+ PyObject *obj, *context, *ret = NULL;
char *ufunc_name;
/* This contains the all the inputs and outputs */
PyArrayObject *op[NPY_MAXARGS];
@@ -3544,7 +3630,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
PyUFuncGenericFunction innerloop = NULL;
void *innerloopdata = NULL;
- int no_castable_output, all_inputs_scalar, all_dtypes_match;
+ int no_castable_output, all_inputs_scalar, trivial_loop_ok = 0;
/* TODO: For 1.6, the default should probably be NPY_CORDER */
NPY_ORDER order = NPY_KEEPORDER;
@@ -3592,7 +3678,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
op[i] = (PyArrayObject *)PyArray_FromAny(obj, NULL, 0, 0, 0, context);
Py_XDECREF(context);
if (op[i] == NULL) {
- goto fail;
+ goto finish;
}
/* If there's a non-scalar input, indicate so */
if (PyArray_NDIM(op[i]) > 0) {
@@ -3612,7 +3698,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
if (!PyArray_ISWRITEABLE(obj)) {
PyErr_SetString(PyExc_ValueError,
"return array is not writeable");
- goto fail;
+ goto finish;
}
Py_INCREF(obj);
op[i] = (PyArrayObject *)obj;
@@ -3621,7 +3707,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
PyErr_SetString(PyExc_TypeError,
"return arrays must be "
"of ArrayType");
- goto fail;
+ goto finish;
}
}
@@ -3639,14 +3725,14 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
if (PyString_AsStringAndSize(key, &str, &length) == -1) {
PyErr_SetString(PyExc_TypeError, "invalid keyword argument");
- goto fail;
+ goto finish;
}
switch (str[0]) {
case 'c':
if (strncmp(str,"casting",7) == 0) {
if (!PyArray_CastingConverter(value, &casting)) {
- goto fail;
+ goto finish;
}
bad_arg = 0;
}
@@ -3663,14 +3749,14 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
PyErr_SetString(PyExc_ValueError,
"cannot specify 'out' as both a positional "
"and keyword argument");
- goto fail;
+ goto finish;
}
else {
if (PyArray_Check(value)) {
if (!PyArray_ISWRITEABLE(value)) {
PyErr_SetString(PyExc_ValueError,
"return array is not writeable");
- goto fail;
+ goto finish;
}
Py_INCREF(value);
op[nin] = (PyArrayObject *)value;
@@ -3679,14 +3765,14 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
PyErr_SetString(PyExc_TypeError,
"return arrays must be "
"of ArrayType");
- goto fail;
+ goto finish;
}
}
}
else if (strncmp(str,"order",5) == 0) {
if (!PyArray_OrderConverter(value, &order)) {
- goto fail;
+ goto finish;
}
bad_arg = 0;
}
@@ -3702,7 +3788,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
if (bad_arg) {
char *format = "'%s' is an invalid keyword to %s";
PyErr_Format(PyExc_TypeError, format, str, ufunc_name);
- goto fail;
+ goto finish;
}
}
}
@@ -3767,6 +3853,9 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
for (j = nin; j < niter; ++j) {
if (op[j] != NULL) {
PyArray_Descr *tmp = PyArray_DescrFromType(types[j]);
+ if (tmp == NULL) {
+ goto finish;
+ }
if (!PyArray_CanCastTypeTo(tmp, PyArray_DESCR(op[j]),
output_casting)) {
Py_DECREF(tmp);
@@ -3779,14 +3868,22 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
}
/* If all the tests passed, we found our function */
if (j == niter) {
- all_dtypes_match = 1;
+ trivial_loop_ok = 1;
/* Fill the dtypes array */
for (j = 0; j < niter; ++j) {
dtype[j] = PyArray_DescrFromType(types[j]);
- /* If the dtype doesn't match, indicate so */
- if (all_dtypes_match && op[j] != NULL &&
- !PyArray_EquivTypes(dtype[j], PyArray_DESCR(op[j]))) {
- all_dtypes_match = 0;
+ if (dtype[j] == NULL) {
+ goto finish;
+ }
+ /*
+ * If the dtype doesn't match, or the array isn't aligned,
+ * indicate that the trivial loop can't be done.
+ */
+ if (trivial_loop_ok && op[j] != NULL &&
+ (!PyArray_ISALIGNED(op[j]) ||
+ !PyArray_EquivTypes(dtype[j], PyArray_DESCR(op[j]))
+ )) {
+ trivial_loop_ok = 0;
}
}
/* Save the inner loop and its data */
@@ -3813,10 +3910,10 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
"inputs could not be safely coerced to any supported "
"types");
}
- return NULL;
+ goto finish;
}
-
+#if 0
printf("input types:\n");
for (i = 0; i < nin; ++i) {
PyObject_Print((PyObject *)dtype[i], stdout, 0);
@@ -3828,9 +3925,10 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
printf(" ");
}
printf("\n");
+#endif
/* First check for the trivial cases that don't need an iterator */
- if (all_dtypes_match && niter == 2 && op[1] == NULL &&
+ if (trivial_loop_ok && niter == 2 && op[1] == NULL &&
(order == NPY_ANYORDER || order == NPY_KEEPORDER) &&
PyArray_TRIVIALLY_ITERABLE(op[0])) {
Py_INCREF(dtype[1]);
@@ -3841,18 +3939,18 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
NULL, NULL,
PyArray_ISFORTRAN(op[0]) ? NPY_F_CONTIGUOUS : 0,
NULL);
- printf("trivial 1 input with allocated output\n");
+ //printf("trivial 1 input with allocated output\n");
trivial_two_operand_loop(op, innerloop, innerloopdata);
}
- else if (all_dtypes_match && niter == 2 && op[1] != NULL &&
+ else if (trivial_loop_ok && niter == 2 && op[1] != NULL &&
PyArray_NDIM(op[1]) >= PyArray_NDIM(op[0]) &&
PyArray_TRIVIALLY_ITERABLE_PAIR(op[0], op[1])) {
- printf("trivial 1 input\n");
+ //printf("trivial 1 input\n");
trivial_two_operand_loop(op, innerloop, innerloopdata);
}
- else if (all_dtypes_match && nin == 2 && nout == 1 && op[2] == NULL &&
+ else if (trivial_loop_ok && nin == 2 && nout == 1 && op[2] == NULL &&
(order == NPY_ANYORDER || order == NPY_KEEPORDER) &&
PyArray_TRIVIALLY_ITERABLE_PAIR(op[0], op[1])) {
PyArrayObject *tmp;
@@ -3874,32 +3972,67 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
NULL, NULL,
PyArray_ISFORTRAN(tmp) ? NPY_F_CONTIGUOUS : 0,
NULL);
- printf("trivial 2 input with allocated output\n");
+ //printf("trivial 2 input with allocated output\n");
trivial_three_operand_loop(op, innerloop, innerloopdata);
}
- else if (all_dtypes_match && niter == 3 &&
+ else if (trivial_loop_ok && niter == 3 &&
op[1] != NULL && op[2] != NULL &&
PyArray_NDIM(op[2]) >= PyArray_NDIM(op[0]) &&
PyArray_NDIM(op[2]) >= PyArray_NDIM(op[1]) &&
PyArray_TRIVIALLY_ITERABLE_TRIPLE(op[0], op[1], op[2])) {
- printf("trivial 2 input\n");
+ //printf("trivial 2 input\n");
trivial_three_operand_loop(op, innerloop, innerloopdata);
}
/* Otherwise an iterator is required to resolve broadcasting, etc */
else {
- printf("iterator case\n");
+ npy_intp buffersize = 0;
+
+ //printf("iterator loop\n");
+
+ if (iterator_loop(nin, nout, op, dtype, order,
+ casting, buffersize, innerloop, innerloopdata)
+ != NPY_SUCCEED) {
+ goto finish;
+ }
}
+#if 0
+ for (i = 0; i < nin; ++i) {
+ printf("input %d: (%p)\n", (int)i, op[i]);
+ PyObject_Print((PyObject *)op[i], stdout, 0);
+ printf("\n");
+ }
+ for (i = nin; i < niter; ++i) {
+ printf("output %d: (%p)\n", (int)i, op[i]);
+ PyObject_Print((PyObject *)op[i], stdout, 0);
+ printf("\n");
+ }
+ printf("\n");
+#endif
- PyErr_SetString(PyExc_RuntimeError, "implementation is not finished!");
-fail:
+ if (nout == 1) {
+ ret = op[nin];
+ Py_INCREF(ret);
+ }
+ else {
+ ret = PyTuple_New(nout);
+ if (ret == NULL) {
+ goto finish;
+ }
+ for (i = 0; i < nout; ++i) {
+ PyTuple_SET_ITEM(ret, i, op[nin+i]);
+ Py_INCREF(op[nin+i]);
+ }
+ }
+
+finish:
for (i = 0; i < niter; ++i) {
Py_XDECREF(op[i]);
Py_XDECREF(dtype[i]);
}
- return NULL;
+ return ret;
}
static PyObject *
@@ -4595,7 +4728,7 @@ static struct PyMethodDef ufunc_methods[] = {
{"outer",
(PyCFunction)ufunc_outer,
METH_VARARGS | METH_KEYWORDS, NULL},
- {"gennew", /* new generic call */
+ {"f", /* new generic call */
(PyCFunction)ufunc_generic_call_iter,
METH_VARARGS | METH_KEYWORDS, NULL},
{NULL, NULL, 0, NULL} /* sentinel */