diff options
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/src/umath/dispatching.c | 38 | ||||
| -rw-r--r-- | numpy/core/src/umath/dispatching.h | 3 | ||||
| -rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 7 | ||||
| -rw-r--r-- | numpy/core/tests/test_ufunc.py | 5 |
4 files changed, 47 insertions, 6 deletions
diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c index 07292be5a..2504220e7 100644 --- a/numpy/core/src/umath/dispatching.c +++ b/numpy/core/src/umath/dispatching.c @@ -779,6 +779,14 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc, * only work with DType (classes/types). This is because it has to ensure * that legacy (value-based promotion) is used when necessary. * + * NOTE: The machinery here currently ignores output arguments unless + * they are part of the signature. This slightly limits unsafe loop + * specializations, which is important for the `ensure_reduce_compatible` + * fallback mode. + * To fix this, the caching mechanism (and dispatching) can be extended. + * When/if that happens, the `ensure_reduce_compatible` could be + * deprecated (it should never kick in because promotion kicks in first). + * * @param ufunc The ufunc object, used mainly for the fallback. * @param ops The array operands (used only for the fallback). * @param signature As input, the DType signature fixed explicitly by the user. @@ -791,6 +799,13 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc, * these including clearing the output. * @param force_legacy_promotion If set, we have to use the old type resolution * to implement value-based promotion/casting. + * @param ensure_reduce_compatible Must be set for reductions, in which case + * the found implementation is checked for reduce-like compatibility. + * If it is *not* compatible and `signature[2] != NULL`, we assume its + * output DType is correct (see NOTE above). + * If removed, promotion may require information about whether this + * is a reduction, so the more likely case is to always keep fixing this + * when necessary, but push down the handling so it can be cached. 
*/ NPY_NO_EXPORT PyArrayMethodObject * promote_and_get_ufuncimpl(PyUFuncObject *ufunc, @@ -798,7 +813,8 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc, PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *op_dtypes[], npy_bool force_legacy_promotion, - npy_bool allow_legacy_promotion) + npy_bool allow_legacy_promotion, + npy_bool ensure_reduce_compatible) { int nin = ufunc->nin, nargs = ufunc->nargs; @@ -852,8 +868,26 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc, PyArrayMethodObject *method = (PyArrayMethodObject *)PyTuple_GET_ITEM(info, 1); - /* Fill `signature` with final DTypes used by the ArrayMethod/inner-loop */ + /* + * In certain cases (only the logical ufuncs really), the loop we found may + * not be reduce-compatible. Since the machinery can't distinguish a + * reduction with an output from a normal ufunc call, we have to assume + * the result DType is correct and force it for the input (if not forced + * already). + * NOTE: This does assume that all loops are "safe" see the NOTE in this + * comment. That could be relaxed, in which case we may need to + * cache if a call was for a reduction. 
+ */ PyObject *all_dtypes = PyTuple_GET_ITEM(info, 0); + if (ensure_reduce_compatible && signature[0] == NULL && + PyTuple_GET_ITEM(all_dtypes, 0) != PyTuple_GET_ITEM(all_dtypes, 2)) { + signature[0] = (PyArray_DTypeMeta *)PyTuple_GET_ITEM(all_dtypes, 2); + Py_INCREF(signature[0]); + return promote_and_get_ufuncimpl(ufunc, + ops, signature, op_dtypes, + force_legacy_promotion, allow_legacy_promotion, NPY_FALSE); + } + for (int i = 0; i < nargs; i++) { if (signature[i] == NULL) { signature[i] = (PyArray_DTypeMeta *)PyTuple_GET_ITEM(all_dtypes, i); diff --git a/numpy/core/src/umath/dispatching.h b/numpy/core/src/umath/dispatching.h index 305b0549f..a7e9e88d0 100644 --- a/numpy/core/src/umath/dispatching.h +++ b/numpy/core/src/umath/dispatching.h @@ -20,7 +20,8 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc, PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *op_dtypes[], npy_bool force_legacy_promotion, - npy_bool allow_legacy_promotion); + npy_bool allow_legacy_promotion, + npy_bool ensure_reduce_compatible); NPY_NO_EXPORT PyObject * add_and_return_legacy_wrapping_ufunc_loop(PyUFuncObject *ufunc, diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 3f8f61431..97dc9f00f 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -2741,7 +2741,7 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc, } PyArrayMethodObject *ufuncimpl = promote_and_get_ufuncimpl(ufunc, - ops, signature, operation_DTypes, NPY_FALSE, NPY_FALSE); + ops, signature, operation_DTypes, NPY_FALSE, NPY_FALSE, NPY_TRUE); /* Output can currently get cleared, others XDECREF in case of error */ Py_XDECREF(operation_DTypes[1]); if (out != NULL) { @@ -4855,7 +4855,8 @@ ufunc_generic_fastcall(PyUFuncObject *ufunc, */ PyArrayMethodObject *ufuncimpl = promote_and_get_ufuncimpl(ufunc, operands, signature, - operand_DTypes, force_legacy_promotion, allow_legacy_promotion); + operand_DTypes, force_legacy_promotion, 
allow_legacy_promotion, + NPY_FALSE); if (ufuncimpl == NULL) { goto fail; } @@ -6021,7 +6022,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args) PyArrayMethodObject *ufuncimpl = promote_and_get_ufuncimpl(ufunc, operands, signature, operand_DTypes, - force_legacy_promotion, allow_legacy_promotion); + force_legacy_promotion, allow_legacy_promotion, NPY_FALSE); if (ufuncimpl == NULL) { goto fail; } diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index 3b9b8ebc3..500904586 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -2132,6 +2132,11 @@ class TestUfunc: assert_array_equal(ufunc(a, c, out=out), expected) out = np.zeros((), dtype=np.int32) assert ufunc.reduce(a, out=out) == True + # Last check, test reduction when out and a match (the complexity here + # is that the "i,i->?" may seem right, but should not match). + a = np.array([3], dtype="i") + out = np.zeros((), dtype=a.dtype) + assert ufunc.reduce(a, out=out) == 1 @pytest.mark.parametrize("ufunc", [np.logical_and, np.logical_or, np.logical_xor]) |
