summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/ctors.c3
-rw-r--r--numpy/core/src/umath/ufunc_object.c67
2 files changed, 60 insertions, 10 deletions
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c
index 41f6b5834..7802b8114 100644
--- a/numpy/core/src/multiarray/ctors.c
+++ b/numpy/core/src/multiarray/ctors.c
@@ -2365,7 +2365,8 @@ PyArray_CopyInto(PyArrayObject *dst, PyArrayObject *src)
}
}
- if (PyArray_TRIVIALLY_ITERABLE_PAIR(dst, src)) {
+ if (PyArray_NDIM(dst) >= PyArray_NDIM(src) &&
+ PyArray_TRIVIALLY_ITERABLE_PAIR(dst, src)) {
char *dst_data, *src_data;
npy_intp count, dst_stride, src_stride, src_itemsize;
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 1f78fb6f6..c42a71f9e 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -3500,6 +3500,36 @@ _find_array_wrap(PyObject *args, PyObject **output_wrap, int nin, int nout)
return;
}
+static void
+trivial_two_operand_loop(PyArrayObject **op,
+ PyUFuncGenericFunction innerloop,
+ void *innerloopdata)
+{
+ char *data[2];
+ npy_intp count, stride[2];
+
+ PyArray_PREPARE_TRIVIAL_PAIR_ITERATION(op[0], op[1],
+ count,
+ data[0], data[1],
+ stride[0], stride[1]);
+ innerloop(data, &count, stride, innerloopdata);
+}
+
+static void
+trivial_three_operand_loop(PyArrayObject **op,
+ PyUFuncGenericFunction innerloop,
+ void *innerloopdata)
+{
+ char *data[3];
+ npy_intp count, stride[3];
+
+ PyArray_PREPARE_TRIVIAL_TRIPLE_ITERATION(op[0], op[1], op[2],
+ count,
+ data[0], data[1], data[2],
+ stride[0], stride[1], stride[2]);
+ innerloop(data, &count, stride, innerloopdata);
+}
+
static PyObject *
ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
{
@@ -3514,7 +3544,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
PyUFuncGenericFunction innerloop = NULL;
void *innerloopdata = NULL;
- int no_castable_output, all_inputs_scalar;
+ int no_castable_output, all_inputs_scalar, all_dtypes_match;
/* TODO: For 1.6, the default should probably be NPY_CORDER */
NPY_ORDER order = NPY_KEEPORDER;
@@ -3749,9 +3779,15 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
}
/* If all the tests passed, we found our function */
if (j == niter) {
+ all_dtypes_match = 1;
/* Fill the dtypes array */
for (j = 0; j < niter; ++j) {
dtype[j] = PyArray_DescrFromType(types[j]);
+ /* If the dtype doesn't match, indicate so */
+ if (all_dtypes_match && op[j] != NULL &&
+ !PyArray_EquivTypes(dtype[j], PyArray_DESCR(op[j]))) {
+ all_dtypes_match = 0;
+ }
}
/* Save the inner loop and its data */
innerloop = self->functions[i];
@@ -3763,11 +3799,15 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
if (i == self->ntypes) {
if (no_castable_output) {
PyErr_Format(PyExc_TypeError,
- "function output could not be coerced to a provided "
+ "function output could not be coerced to provided "
"output parameter according to the specified casting "
"rule");
}
else {
+ /*
+ * TODO: We should try again if the casting rule is same_kind
+ * or unsafe, and look for a function more liberally.
+ */
PyErr_Format(PyExc_TypeError,
"function not supported for the input types, and the "
"inputs could not be safely coerced to any supported "
@@ -3790,7 +3830,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
printf("\n");
/* First check for the trivial cases that don't need an iterator */
- if (niter == 2 && op[1] == NULL &&
+ if (all_dtypes_match && niter == 2 && op[1] == NULL &&
(order == NPY_ANYORDER || order == NPY_KEEPORDER) &&
PyArray_TRIVIALLY_ITERABLE(op[0])) {
Py_INCREF(dtype[1]);
@@ -3802,13 +3842,17 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
PyArray_ISFORTRAN(op[0]) ? NPY_F_CONTIGUOUS : 0,
NULL);
printf("trivial 1 input with allocated output\n");
+
+ trivial_two_operand_loop(op, innerloop, innerloopdata);
}
- else if (niter == 2 && op[1] != NULL &&
+ else if (all_dtypes_match && niter == 2 && op[1] != NULL &&
PyArray_NDIM(op[1]) >= PyArray_NDIM(op[0]) &&
PyArray_TRIVIALLY_ITERABLE_PAIR(op[0], op[1])) {
printf("trivial 1 input\n");
+
+ trivial_two_operand_loop(op, innerloop, innerloopdata);
}
- else if (nin == 2 && nout == 1 && op[2] == NULL &&
+ else if (all_dtypes_match && nin == 2 && nout == 1 && op[2] == NULL &&
(order == NPY_ANYORDER || order == NPY_KEEPORDER) &&
PyArray_TRIVIALLY_ITERABLE_PAIR(op[0], op[1])) {
PyArrayObject *tmp;
@@ -3831,12 +3875,17 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
PyArray_ISFORTRAN(tmp) ? NPY_F_CONTIGUOUS : 0,
NULL);
printf("trivial 2 input with allocated output\n");
+
+ trivial_three_operand_loop(op, innerloop, innerloopdata);
}
- else if (niter == 3 && op[1] != NULL && op[2] != NULL &&
- PyArray_NDIM(op[2]) >= PyArray_NDIM(op[0]) &&
- PyArray_NDIM(op[2]) >= PyArray_NDIM(op[1]) &&
- PyArray_TRIVIALLY_ITERABLE_TRIPLE(op[0], op[1], op[2])) {
+ else if (all_dtypes_match && niter == 3 &&
+ op[1] != NULL && op[2] != NULL &&
+ PyArray_NDIM(op[2]) >= PyArray_NDIM(op[0]) &&
+ PyArray_NDIM(op[2]) >= PyArray_NDIM(op[1]) &&
+ PyArray_TRIVIALLY_ITERABLE_TRIPLE(op[0], op[1], op[2])) {
printf("trivial 2 input\n");
+
+ trivial_three_operand_loop(op, innerloop, innerloopdata);
}
/* Otherwise an iterator is required to resolve broadcasting, etc */
else {