diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2021-06-28 18:44:55 -0500 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2021-06-28 18:45:39 -0500 |
commit | b98c51660bc1a01a7b974b9fb841ad2bec9cbf3a (patch) | |
tree | 06931a6db0efb00f4291f23919db5491b1032195 /numpy/core | |
parent | 164682330d70ffe84e706a67955435b76e1cf19d (diff) | |
download | numpy-b98c51660bc1a01a7b974b9fb841ad2bec9cbf3a.tar.gz |
MAINT: Use cast-is-view flag for the ufunc trivial-loop check
This also would avoid duplicating work if we pull in the casting
safety check to here, since checking for equivalent types also
has to get the cast safety internally.
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 1 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 49 |
2 files changed, 31 insertions, 19 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index cd0e21098..b8f24394c 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -393,6 +393,7 @@ _get_cast_safety_from_castingimpl(PyArrayMethodObject *castingimpl, Py_DECREF(out_descrs[1]); /* NPY_NO_CASTING has to be used for (NPY_EQUIV_CASTING|_NPY_CAST_IS_VIEW) */ assert(casting != (NPY_EQUIV_CASTING|_NPY_CAST_IS_VIEW)); + assert(casting != NPY_NO_CASTING); /* forgot to set the view flag */ return casting; } diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 067de6990..e94b3be4a 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -49,6 +49,7 @@ #include "common.h" #include "dtypemeta.h" #include "numpyos.h" +#include "convert_datatype.h" /********** PRINTF DEBUG TRACING **************/ #define NPY_UF_DBG_TRACING 0 @@ -991,33 +992,45 @@ fail: */ static int check_for_trivial_loop(PyUFuncObject *ufunc, - PyArrayObject **op, - PyArray_Descr **dtype, - npy_intp buffersize) + PyArrayObject **op, PyArray_Descr **dtypes, + npy_intp buffersize) { - npy_intp i, nin = ufunc->nin, nop = nin + ufunc->nout; + int i, nin = ufunc->nin, nop = nin + ufunc->nout; for (i = 0; i < nop; ++i) { /* * If the dtype doesn't match, or the array isn't aligned, * indicate that the trivial loop can't be done. */ - if (op[i] != NULL && - (!PyArray_ISALIGNED(op[i]) || - !PyArray_EquivTypes(dtype[i], PyArray_DESCR(op[i])) - )) { + if (op[i] == NULL) { + continue; + } + int must_copy = !PyArray_ISALIGNED(op[i]); + + if (dtypes[i] != PyArray_DESCR(op[i])) { + NPY_CASTING safety = PyArray_GetCastSafety( + PyArray_DESCR(op[i]), dtypes[i], NULL); + if (safety < 0) { + /* A proper error during a cast check should be rare */ + return -1; + } + if (!(safety & _NPY_CAST_IS_VIEW)) { + must_copy = 1; + } + } + if (must_copy) { /* * If op[j] is a scalar or small one dimensional * array input, make a copy to keep the opportunity - * for a trivial loop. + * for a trivial loop. Outputs are not copied here. */ - if (i < nin && (PyArray_NDIM(op[i]) == 0 || - (PyArray_NDIM(op[i]) == 1 && - PyArray_DIM(op[i],0) <= buffersize))) { + if (i < nin && (PyArray_NDIM(op[i]) == 0 || ( + PyArray_NDIM(op[i]) == 1 && + PyArray_DIM(op[i], 0) <= buffersize))) { PyArrayObject *tmp; - Py_INCREF(dtype[i]); + Py_INCREF(dtypes[i]); tmp = (PyArrayObject *) - PyArray_CastToType(op[i], dtype[i], 0); + PyArray_CastToType(op[i], dtypes[i], 0); if (tmp == NULL) { return -1; } @@ -2476,8 +2489,6 @@ PyUFunc_GenericFunctionInternal(PyUFuncObject *ufunc, /* These parameters come from extobj= or from a TLS global */ int buffersize = 0, errormask = 0; - int trivial_loop_ok = 0; - NPY_UF_DBG_PRINT1("\nEvaluating ufunc %s\n", ufunc_name); /* Get the buffersize and errormask */ @@ -2532,16 +2543,16 @@ PyUFunc_GenericFunctionInternal(PyUFuncObject *ufunc, * Since it requires dtypes, it can only be called after * ufunc->type_resolver */ - trivial_loop_ok = check_for_trivial_loop(ufunc, + int trivial_ok = check_for_trivial_loop(ufunc, op, operation_descrs, buffersize); - if (trivial_loop_ok < 0) { + if (trivial_ok < 0) { return -1; } /* check_for_trivial_loop on half-floats can overflow */ npy_clear_floatstatus_barrier((char*)&ufunc); - retval = execute_legacy_ufunc_loop(ufunc, trivial_loop_ok, + retval = execute_legacy_ufunc_loop(ufunc, trivial_ok, op, operation_descrs, order, buffersize, output_array_prepare, full_args, op_flags); |