summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/umath/dispatching.c38
-rw-r--r--numpy/core/src/umath/dispatching.h3
-rw-r--r--numpy/core/src/umath/ufunc_object.c7
-rw-r--r--numpy/core/tests/test_ufunc.py5
4 files changed, 47 insertions, 6 deletions
diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c
index 07292be5a..2504220e7 100644
--- a/numpy/core/src/umath/dispatching.c
+++ b/numpy/core/src/umath/dispatching.c
@@ -779,6 +779,14 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
* only work with DType (classes/types). This is because it has to ensure
* that legacy (value-based promotion) is used when necessary.
*
+ * NOTE: The machinery here currently ignores output arguments unless
+ * they are part of the signature. This slightly limits unsafe loop
+ * specializations, which is important for the `ensure_reduce_compatible`
+ * fallback mode.
+ * To fix this, the caching mechanism (and dispatching) can be extended.
+ * When/if that happens, the `ensure_reduce_compatible` could be
+ * deprecated (it should never kick in because promotion kicks in first).
+ *
* @param ufunc The ufunc object, used mainly for the fallback.
* @param ops The array operands (used only for the fallback).
* @param signature As input, the DType signature fixed explicitly by the user.
@@ -791,6 +799,13 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
* these including clearing the output.
* @param force_legacy_promotion If set, we have to use the old type resolution
* to implement value-based promotion/casting.
+ * @param ensure_reduce_compatible Must be set for reductions, in which case
+ * the found implementation is checked for reduce-like compatibility.
+ * If it is *not* compatible and `signature[2] != NULL`, we assume its
+ * output DType is correct (see NOTE above).
+ * If removed, promotion may require information about whether this
+ * is a reduction, so the more likely case is to always keep fixing this
+ * when necessary, but push down the handling so it can be cached.
*/
NPY_NO_EXPORT PyArrayMethodObject *
promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
@@ -798,7 +813,8 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *op_dtypes[],
npy_bool force_legacy_promotion,
- npy_bool allow_legacy_promotion)
+ npy_bool allow_legacy_promotion,
+ npy_bool ensure_reduce_compatible)
{
int nin = ufunc->nin, nargs = ufunc->nargs;
@@ -852,8 +868,26 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
PyArrayMethodObject *method = (PyArrayMethodObject *)PyTuple_GET_ITEM(info, 1);
- /* Fill `signature` with final DTypes used by the ArrayMethod/inner-loop */
+ /*
+ * In certain cases (only the logical ufuncs really), the loop we found may
+ * not be reduce-compatible. Since the machinery can't distinguish a
+ * reduction with an output from a normal ufunc call, we have to assume
+ * the result DType is correct and force it for the input (if not forced
+ * already).
+ * NOTE: This does assume that all loops are "safe"; see the NOTE in the
+ * documentation comment of this function. That could be relaxed, in
+ * which case we may need to cache whether a call was for a reduction.
+ */
PyObject *all_dtypes = PyTuple_GET_ITEM(info, 0);
+ if (ensure_reduce_compatible && signature[0] == NULL &&
+ PyTuple_GET_ITEM(all_dtypes, 0) != PyTuple_GET_ITEM(all_dtypes, 2)) {
+ signature[0] = (PyArray_DTypeMeta *)PyTuple_GET_ITEM(all_dtypes, 2);
+ Py_INCREF(signature[0]);
+ return promote_and_get_ufuncimpl(ufunc,
+ ops, signature, op_dtypes,
+ force_legacy_promotion, allow_legacy_promotion, NPY_FALSE);
+ }
+
for (int i = 0; i < nargs; i++) {
if (signature[i] == NULL) {
signature[i] = (PyArray_DTypeMeta *)PyTuple_GET_ITEM(all_dtypes, i);
diff --git a/numpy/core/src/umath/dispatching.h b/numpy/core/src/umath/dispatching.h
index 305b0549f..a7e9e88d0 100644
--- a/numpy/core/src/umath/dispatching.h
+++ b/numpy/core/src/umath/dispatching.h
@@ -20,7 +20,8 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *op_dtypes[],
npy_bool force_legacy_promotion,
- npy_bool allow_legacy_promotion);
+ npy_bool allow_legacy_promotion,
+ npy_bool ensure_reduce_compatible);
NPY_NO_EXPORT PyObject *
add_and_return_legacy_wrapping_ufunc_loop(PyUFuncObject *ufunc,
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 3f8f61431..97dc9f00f 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -2741,7 +2741,7 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc,
}
PyArrayMethodObject *ufuncimpl = promote_and_get_ufuncimpl(ufunc,
- ops, signature, operation_DTypes, NPY_FALSE, NPY_FALSE);
+ ops, signature, operation_DTypes, NPY_FALSE, NPY_FALSE, NPY_TRUE);
/* Output can currently get cleared, others XDECREF in case of error */
Py_XDECREF(operation_DTypes[1]);
if (out != NULL) {
@@ -4855,7 +4855,8 @@ ufunc_generic_fastcall(PyUFuncObject *ufunc,
*/
PyArrayMethodObject *ufuncimpl = promote_and_get_ufuncimpl(ufunc,
operands, signature,
- operand_DTypes, force_legacy_promotion, allow_legacy_promotion);
+ operand_DTypes, force_legacy_promotion, allow_legacy_promotion,
+ NPY_FALSE);
if (ufuncimpl == NULL) {
goto fail;
}
@@ -6021,7 +6022,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
PyArrayMethodObject *ufuncimpl = promote_and_get_ufuncimpl(ufunc,
operands, signature, operand_DTypes,
- force_legacy_promotion, allow_legacy_promotion);
+ force_legacy_promotion, allow_legacy_promotion, NPY_FALSE);
if (ufuncimpl == NULL) {
goto fail;
}
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index 3b9b8ebc3..500904586 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -2132,6 +2132,11 @@ class TestUfunc:
assert_array_equal(ufunc(a, c, out=out), expected)
out = np.zeros((), dtype=np.int32)
assert ufunc.reduce(a, out=out) == True
+ # Last check, test reduction when the dtypes of `out` and `a` match (the
+ # complexity here is that the "i,i->?" may seem right, but should not match).
+ a = np.array([3], dtype="i")
+ out = np.zeros((), dtype=a.dtype)
+ assert ufunc.reduce(a, out=out) == 1
@pytest.mark.parametrize("ufunc",
[np.logical_and, np.logical_or, np.logical_xor])