diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/ctors.c | 3 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 67 |
2 files changed, 60 insertions, 10 deletions
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c index 41f6b5834..7802b8114 100644 --- a/numpy/core/src/multiarray/ctors.c +++ b/numpy/core/src/multiarray/ctors.c @@ -2365,7 +2365,8 @@ PyArray_CopyInto(PyArrayObject *dst, PyArrayObject *src) } } - if (PyArray_TRIVIALLY_ITERABLE_PAIR(dst, src)) { + if (PyArray_NDIM(dst) >= PyArray_NDIM(src) && + PyArray_TRIVIALLY_ITERABLE_PAIR(dst, src)) { char *dst_data, *src_data; npy_intp count, dst_stride, src_stride, src_itemsize; diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 1f78fb6f6..c42a71f9e 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -3500,6 +3500,36 @@ _find_array_wrap(PyObject *args, PyObject **output_wrap, int nin, int nout) return; } +static void +trivial_two_operand_loop(PyArrayObject **op, + PyUFuncGenericFunction innerloop, + void *innerloopdata) +{ + char *data[2]; + npy_intp count, stride[2]; + + PyArray_PREPARE_TRIVIAL_PAIR_ITERATION(op[0], op[1], + count, + data[0], data[1], + stride[0], stride[1]); + innerloop(data, &count, stride, innerloopdata); +} + +static void +trivial_three_operand_loop(PyArrayObject **op, + PyUFuncGenericFunction innerloop, + void *innerloopdata) +{ + char *data[3]; + npy_intp count, stride[3]; + + PyArray_PREPARE_TRIVIAL_TRIPLE_ITERATION(op[0], op[1], op[2], + count, + data[0], data[1], data[2], + stride[0], stride[1], stride[2]); + innerloop(data, &count, stride, innerloopdata); +} + static PyObject * ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) { @@ -3514,7 +3544,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) PyUFuncGenericFunction innerloop = NULL; void *innerloopdata = NULL; - int no_castable_output, all_inputs_scalar; + int no_castable_output, all_inputs_scalar, all_dtypes_match; /* TODO: For 1.6, the default should probably be NPY_CORDER */ NPY_ORDER order = NPY_KEEPORDER; @@ -3749,9 +3779,15 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) } /* If all the tests passed, we found our function */ if (j == niter) { + all_dtypes_match = 1; /* Fill the dtypes array */ for (j = 0; j < niter; ++j) { dtype[j] = PyArray_DescrFromType(types[j]); + /* If the dtype doesn't match, indicate so */ + if (all_dtypes_match && op[j] != NULL && + !PyArray_EquivTypes(dtype[j], PyArray_DESCR(op[j]))) { + all_dtypes_match = 0; + } } /* Save the inner loop and its data */ innerloop = self->functions[i]; @@ -3763,11 +3799,15 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) if (i == self->ntypes) { if (no_castable_output) { PyErr_Format(PyExc_TypeError, - "function output could not be coerced to a provided " + "function output could not be coerced to provided " "output parameter according to the specified casting " "rule"); } else { + /* + * TODO: We should try again if the casting rule is same_kind + * or unsafe, and look for a function more liberally. + */ PyErr_Format(PyExc_TypeError, "function not supported for the input types, and the " "inputs could not be safely coerced to any supported " @@ -3790,7 +3830,7 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) printf("\n"); /* First check for the trivial cases that don't need an iterator */ - if (niter == 2 && op[1] == NULL && + if (all_dtypes_match && niter == 2 && op[1] == NULL && (order == NPY_ANYORDER || order == NPY_KEEPORDER) && PyArray_TRIVIALLY_ITERABLE(op[0])) { Py_INCREF(dtype[1]); @@ -3802,13 +3842,17 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) PyArray_ISFORTRAN(op[0]) ? NPY_F_CONTIGUOUS : 0, NULL); printf("trivial 1 input with allocated output\n"); + + trivial_two_operand_loop(op, innerloop, innerloopdata); } - else if (niter == 2 && op[1] != NULL && + else if (all_dtypes_match && niter == 2 && op[1] != NULL && PyArray_NDIM(op[1]) >= PyArray_NDIM(op[0]) && PyArray_TRIVIALLY_ITERABLE_PAIR(op[0], op[1])) { printf("trivial 1 input\n"); + + trivial_two_operand_loop(op, innerloop, innerloopdata); } - else if (nin == 2 && nout == 1 && op[2] == NULL && + else if (all_dtypes_match && nin == 2 && nout == 1 && op[2] == NULL && (order == NPY_ANYORDER || order == NPY_KEEPORDER) && PyArray_TRIVIALLY_ITERABLE_PAIR(op[0], op[1])) { PyArrayObject *tmp; @@ -3831,12 +3875,17 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds) PyArray_ISFORTRAN(tmp) ? NPY_F_CONTIGUOUS : 0, NULL); printf("trivial 2 input with allocated output\n"); + + trivial_three_operand_loop(op, innerloop, innerloopdata); } - else if (niter == 3 && op[1] != NULL && op[2] != NULL && - PyArray_NDIM(op[2]) >= PyArray_NDIM(op[0]) && - PyArray_NDIM(op[2]) >= PyArray_NDIM(op[1]) && - PyArray_TRIVIALLY_ITERABLE_TRIPLE(op[0], op[1], op[2])) { + else if (all_dtypes_match && niter == 3 && + op[1] != NULL && op[2] != NULL && + PyArray_NDIM(op[2]) >= PyArray_NDIM(op[0]) && + PyArray_NDIM(op[2]) >= PyArray_NDIM(op[1]) && + PyArray_TRIVIALLY_ITERABLE_TRIPLE(op[0], op[1], op[2])) { printf("trivial 2 input\n"); + + trivial_three_operand_loop(op, innerloop, innerloopdata); } /* Otherwise an iterator is required to resolve broadcasting, etc */ else { |