diff options
| -rw-r--r-- | numpy/core/src/umath/dispatching.c | 18 | ||||
| -rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 22 | ||||
| -rw-r--r-- | numpy/core/tests/test_ufunc.py | 6 |
3 files changed, 26 insertions, 20 deletions
diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c index 772e46d64..07292be5a 100644 --- a/numpy/core/src/umath/dispatching.c +++ b/numpy/core/src/umath/dispatching.c @@ -53,6 +53,9 @@ #include "ufunc_type_resolution.h" +#define PROMOTION_DEBUG_TRACING 0 + + /* forward declaration */ static NPY_INLINE PyObject * promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc, @@ -185,6 +188,17 @@ resolve_implementation_info(PyUFuncObject *ufunc, PyObject *best_dtypes = NULL; PyObject *best_resolver_info = NULL; +#if PROMOTION_DEBUG_TRACING + printf("Promoting for '%s' promoters only: %d\n", + ufunc->name ? ufunc->name : "<unknown>", (int)only_promoters); + printf(" DTypes: "); + PyObject *tmp = PyArray_TupleFromItems(ufunc->nargs, op_dtypes, 1); + PyObject_Print(tmp, stdout, 0); + Py_DECREF(tmp); + printf("\n"); + Py_DECREF(tmp); +#endif + for (Py_ssize_t res_idx = 0; res_idx < size; res_idx++) { /* Test all resolvers */ PyObject *resolver_info = PySequence_Fast_GET_ITEM( @@ -774,7 +788,7 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc, * either by the `signature` or by an `operand`. * (outputs and the second input can be NULL for reductions). * NOTE: In some cases, the promotion machinery may currently modify - * these. + * these including clearing the output. * @param force_legacy_promotion If set, we have to use the old type resolution * to implement value-based promotion/casting. */ @@ -802,7 +816,7 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc, Py_XSETREF(op_dtypes[i], signature[i]); assert(i >= ufunc->nin || !NPY_DT_is_abstract(signature[i])); } - else if (i > nin) { + else if (i >= nin) { /* * We currently just ignore outputs if not in signature, this will * always give the/a correct result (limits registering specialized diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 5863a2b83..3f8f61431 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -2738,29 +2738,15 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc, Py_INCREF(operation_DTypes[0]); operation_DTypes[2] = operation_DTypes[0]; Py_INCREF(operation_DTypes[2]); - /* - * We have to force the dtype, because otherwise value based casting - * may ignore the request entirely. This means that `out=` the same - * as `dtype=`, which should probably not be forced normally, so we - * only do it for legacy DTypes... - */ - if (NPY_DT_is_legacy(operation_DTypes[0]) - && NPY_DT_is_legacy(operation_DTypes[1])) { - if (signature[0] == NULL) { - Py_INCREF(operation_DTypes[0]); - signature[0] = operation_DTypes[0]; - Py_INCREF(operation_DTypes[0]); - signature[2] = operation_DTypes[0]; - } - } } PyArrayMethodObject *ufuncimpl = promote_and_get_ufuncimpl(ufunc, ops, signature, operation_DTypes, NPY_FALSE, NPY_FALSE); - Py_DECREF(operation_DTypes[1]); + /* Output can currently get cleared, others XDECREF in case of error */ + Py_XDECREF(operation_DTypes[1]); if (out != NULL) { - Py_DECREF(operation_DTypes[0]); - Py_DECREF(operation_DTypes[2]); + Py_XDECREF(operation_DTypes[0]); + Py_XDECREF(operation_DTypes[2]); } if (ufuncimpl == NULL) { return NULL; diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index d2bbbc181..3b9b8ebc3 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -2126,6 +2126,12 @@ class TestUfunc: c = np.array([1., 2.]) assert_array_equal(ufunc(a, c), ufunc([True, True], True)) assert ufunc.reduce(a) == True + # check that the output has no effect: + out = np.zeros(2, dtype=np.int32) + expected = ufunc([True, True], True).astype(out.dtype) + assert_array_equal(ufunc(a, c, out=out), expected) + out = np.zeros((), dtype=np.int32) + assert ufunc.reduce(a, out=out) == True @pytest.mark.parametrize("ufunc", [np.logical_and, np.logical_or, np.logical_xor]) |
