summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c1
-rw-r--r--numpy/core/src/umath/ufunc_object.c49
2 files changed, 31 insertions, 19 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index cd0e21098..b8f24394c 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -393,6 +393,7 @@ _get_cast_safety_from_castingimpl(PyArrayMethodObject *castingimpl,
Py_DECREF(out_descrs[1]);
/* NPY_NO_CASTING has to be used for (NPY_EQUIV_CASTING|_NPY_CAST_IS_VIEW) */
assert(casting != (NPY_EQUIV_CASTING|_NPY_CAST_IS_VIEW));
+ assert(casting != NPY_NO_CASTING); /* forgot to set the view flag */
return casting;
}
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 067de6990..e94b3be4a 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -49,6 +49,7 @@
#include "common.h"
#include "dtypemeta.h"
#include "numpyos.h"
+#include "convert_datatype.h"
/********** PRINTF DEBUG TRACING **************/
#define NPY_UF_DBG_TRACING 0
@@ -991,33 +992,45 @@ fail:
*/
static int
check_for_trivial_loop(PyUFuncObject *ufunc,
- PyArrayObject **op,
- PyArray_Descr **dtype,
- npy_intp buffersize)
+ PyArrayObject **op, PyArray_Descr **dtypes,
+ npy_intp buffersize)
{
- npy_intp i, nin = ufunc->nin, nop = nin + ufunc->nout;
+ int i, nin = ufunc->nin, nop = nin + ufunc->nout;
for (i = 0; i < nop; ++i) {
/*
* If the dtype doesn't match, or the array isn't aligned,
* indicate that the trivial loop can't be done.
*/
- if (op[i] != NULL &&
- (!PyArray_ISALIGNED(op[i]) ||
- !PyArray_EquivTypes(dtype[i], PyArray_DESCR(op[i]))
- )) {
+ if (op[i] == NULL) {
+ continue;
+ }
+ int must_copy = !PyArray_ISALIGNED(op[i]);
+
+ if (dtypes[i] != PyArray_DESCR(op[i])) {
+ NPY_CASTING safety = PyArray_GetCastSafety(
+ PyArray_DESCR(op[i]), dtypes[i], NULL);
+ if (safety < 0) {
+ /* A proper error during a cast check should be rare */
+ return -1;
+ }
+ if (!(safety & _NPY_CAST_IS_VIEW)) {
+ must_copy = 1;
+ }
+ }
+ if (must_copy) {
/*
* If op[j] is a scalar or small one dimensional
* array input, make a copy to keep the opportunity
- * for a trivial loop.
+ * for a trivial loop. Outputs are not copied here.
*/
- if (i < nin && (PyArray_NDIM(op[i]) == 0 ||
- (PyArray_NDIM(op[i]) == 1 &&
- PyArray_DIM(op[i],0) <= buffersize))) {
+ if (i < nin && (PyArray_NDIM(op[i]) == 0 || (
+ PyArray_NDIM(op[i]) == 1 &&
+ PyArray_DIM(op[i], 0) <= buffersize))) {
PyArrayObject *tmp;
- Py_INCREF(dtype[i]);
+ Py_INCREF(dtypes[i]);
tmp = (PyArrayObject *)
- PyArray_CastToType(op[i], dtype[i], 0);
+ PyArray_CastToType(op[i], dtypes[i], 0);
if (tmp == NULL) {
return -1;
}
@@ -2476,8 +2489,6 @@ PyUFunc_GenericFunctionInternal(PyUFuncObject *ufunc,
/* These parameters come from extobj= or from a TLS global */
int buffersize = 0, errormask = 0;
- int trivial_loop_ok = 0;
-
NPY_UF_DBG_PRINT1("\nEvaluating ufunc %s\n", ufunc_name);
/* Get the buffersize and errormask */
@@ -2532,16 +2543,16 @@ PyUFunc_GenericFunctionInternal(PyUFuncObject *ufunc,
* Since it requires dtypes, it can only be called after
* ufunc->type_resolver
*/
- trivial_loop_ok = check_for_trivial_loop(ufunc,
+ int trivial_ok = check_for_trivial_loop(ufunc,
op, operation_descrs, buffersize);
- if (trivial_loop_ok < 0) {
+ if (trivial_ok < 0) {
return -1;
}
/* check_for_trivial_loop on half-floats can overflow */
npy_clear_floatstatus_barrier((char*)&ufunc);
- retval = execute_legacy_ufunc_loop(ufunc, trivial_loop_ok,
+ retval = execute_legacy_ufunc_loop(ufunc, trivial_ok,
op, operation_descrs, order,
buffersize, output_array_prepare,
full_args, op_flags);