diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 1 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 49 |
2 files changed, 31 insertions, 19 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index cd0e21098..b8f24394c 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -393,6 +393,7 @@ _get_cast_safety_from_castingimpl(PyArrayMethodObject *castingimpl, Py_DECREF(out_descrs[1]); /* NPY_NO_CASTING has to be used for (NPY_EQUIV_CASTING|_NPY_CAST_IS_VIEW) */ assert(casting != (NPY_EQUIV_CASTING|_NPY_CAST_IS_VIEW)); + assert(casting != NPY_NO_CASTING); /* forgot to set the view flag */ return casting; } diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 067de6990..e94b3be4a 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -49,6 +49,7 @@ #include "common.h" #include "dtypemeta.h" #include "numpyos.h" +#include "convert_datatype.h" /********** PRINTF DEBUG TRACING **************/ #define NPY_UF_DBG_TRACING 0 @@ -991,33 +992,45 @@ fail: */ static int check_for_trivial_loop(PyUFuncObject *ufunc, - PyArrayObject **op, - PyArray_Descr **dtype, - npy_intp buffersize) + PyArrayObject **op, PyArray_Descr **dtypes, + npy_intp buffersize) { - npy_intp i, nin = ufunc->nin, nop = nin + ufunc->nout; + int i, nin = ufunc->nin, nop = nin + ufunc->nout; for (i = 0; i < nop; ++i) { /* * If the dtype doesn't match, or the array isn't aligned, * indicate that the trivial loop can't be done. */ - if (op[i] != NULL && - (!PyArray_ISALIGNED(op[i]) || - !PyArray_EquivTypes(dtype[i], PyArray_DESCR(op[i])) - )) { + if (op[i] == NULL) { + continue; + } + int must_copy = !PyArray_ISALIGNED(op[i]); + + if (dtypes[i] != PyArray_DESCR(op[i])) { + NPY_CASTING safety = PyArray_GetCastSafety( + PyArray_DESCR(op[i]), dtypes[i], NULL); + if (safety < 0) { + /* A proper error during a cast check should be rare */ + return -1; + } + if (!(safety & _NPY_CAST_IS_VIEW)) { + must_copy = 1; + } + } + if (must_copy) { /* * If op[j] is a scalar or small one dimensional * array input, make a copy to keep the opportunity - * for a trivial loop. + * for a trivial loop. Outputs are not copied here. */ - if (i < nin && (PyArray_NDIM(op[i]) == 0 || - (PyArray_NDIM(op[i]) == 1 && - PyArray_DIM(op[i],0) <= buffersize))) { + if (i < nin && (PyArray_NDIM(op[i]) == 0 || ( + PyArray_NDIM(op[i]) == 1 && + PyArray_DIM(op[i], 0) <= buffersize))) { PyArrayObject *tmp; - Py_INCREF(dtype[i]); + Py_INCREF(dtypes[i]); tmp = (PyArrayObject *) - PyArray_CastToType(op[i], dtype[i], 0); + PyArray_CastToType(op[i], dtypes[i], 0); if (tmp == NULL) { return -1; } @@ -2476,8 +2489,6 @@ PyUFunc_GenericFunctionInternal(PyUFuncObject *ufunc, /* These parameters come from extobj= or from a TLS global */ int buffersize = 0, errormask = 0; - int trivial_loop_ok = 0; - NPY_UF_DBG_PRINT1("\nEvaluating ufunc %s\n", ufunc_name); /* Get the buffersize and errormask */ @@ -2532,16 +2543,16 @@ PyUFunc_GenericFunctionInternal(PyUFuncObject *ufunc, * Since it requires dtypes, it can only be called after * ufunc->type_resolver */ - trivial_loop_ok = check_for_trivial_loop(ufunc, + int trivial_ok = check_for_trivial_loop(ufunc, op, operation_descrs, buffersize); - if (trivial_loop_ok < 0) { + if (trivial_ok < 0) { return -1; } /* check_for_trivial_loop on half-floats can overflow */ npy_clear_floatstatus_barrier((char*)&ufunc); - retval = execute_legacy_ufunc_loop(ufunc, trivial_loop_ok, + retval = execute_legacy_ufunc_loop(ufunc, trivial_ok, op, operation_descrs, order, buffersize, output_array_prepare, full_args, op_flags); |