summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2022-02-02 18:26:38 -0600
committerSebastian Berg <sebastianb@nvidia.com>2023-01-20 15:28:48 +0100
commit581bb3796a97c2eec634a390e9c7befb5917859f (patch)
treee0723fa907e90701a9ae0a1bb848ecba65659bf2 /numpy
parentffe18c971d4d122dc36cafcc6651b44118d08d39 (diff)
downloadnumpy-581bb3796a97c2eec634a390e9c7befb5917859f.tar.gz
ENH: Support identity-function in experimental DType API
Also add it to the wrapped array-method (ufunc) implementation so that a Unit dtype can reasonably use an identity for reductions.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/include/numpy/experimental_dtype_api.h32
-rw-r--r--numpy/core/src/multiarray/array_method.c1
-rw-r--r--numpy/core/src/umath/wrapping_array_method.c33
3 files changed, 66 insertions, 0 deletions
diff --git a/numpy/core/include/numpy/experimental_dtype_api.h b/numpy/core/include/numpy/experimental_dtype_api.h
index 5f5f24317..48941baec 100644
--- a/numpy/core/include/numpy/experimental_dtype_api.h
+++ b/numpy/core/include/numpy/experimental_dtype_api.h
@@ -321,6 +321,38 @@ typedef int (PyArrayMethod_StridedLoop)(PyArrayMethod_Context *context,
NpyAuxData *transferdata);
+/*
+ * For reductions, NumPy sometimes requires an identity or default value.
+ * The typical way of relying on a single "default" for ufuncs does not always
+ * work however. This function allows customizing the identity value as well
+ * as whether the operation is "reorderable".
+ */
+#define NPY_METH_get_identity 7
+
+typedef enum {
+ /* The value can be used as a default for empty reductions */
+ NPY_METH_ITEM_IS_DEFAULT = 1 << 0,
+ /* The value represents the identity value */
+ NPY_METH_ITEM_IS_IDENTITY = 1 << 1,
+ /* The operation is fully reorderable (iteration order may be optimized) */
+ NPY_METH_IS_REORDERABLE = 1 << 2,
+} NPY_ARRAYMETHOD_IDENTITY_FLAGS;
+
+/*
+ * If an identity exists, should set the `NPY_METH_ITEM_IS_IDENTITY`, normally
+ * the `NPY_METH_ITEM_IS_DEFAULT` should also be set, but it is distinct.
+ * By default NumPy provides a "default" for `object` dtype, but does not use
+ * it as an identity.
+ * The `NPY_METH_IS_REORDERABLE` flag should be set if the operatio is
+ * reorderable.
+ *
+ * NOTE: `item` can be `NULL` when a user passed a custom initial value, in
+ * this case only the `reorderable` flag is valid.
+ */
+typedef int (get_identity_function)(
+ PyArrayMethod_Context *context, char *item,
+ NPY_ARRAYMETHOD_IDENTITY_FLAGS *flags);
+
/*
* ****************************
diff --git a/numpy/core/src/multiarray/array_method.c b/numpy/core/src/multiarray/array_method.c
index a289e62ab..e39359ea7 100644
--- a/numpy/core/src/multiarray/array_method.c
+++ b/numpy/core/src/multiarray/array_method.c
@@ -337,6 +337,7 @@ fill_arraymethod_from_slots(
continue;
case NPY_METH_get_identity:
meth->get_identity = slot->pfunc;
+ continue;
default:
break;
}
diff --git a/numpy/core/src/umath/wrapping_array_method.c b/numpy/core/src/umath/wrapping_array_method.c
index 9f8f036e8..4d0c4caa4 100644
--- a/numpy/core/src/umath/wrapping_array_method.c
+++ b/numpy/core/src/umath/wrapping_array_method.c
@@ -177,6 +177,38 @@ wrapping_method_get_loop(
}
+/*
+ * Wraps the original identity function, needs to translate the descriptors
+ * back to the original ones and provide an "original" context (identically to
+ * `get_loop`).
+ * We assume again that translating the descriptors is quick.
+ */
+static int
+wrapping_method_get_identity_function(PyArrayMethod_Context *context,
+ char *item, NPY_ARRAYMETHOD_IDENTITY_FLAGS *flags)
+{
+ /* Copy the context, and replace descriptors: */
+ PyArrayMethod_Context orig_context = *context;
+ PyArray_Descr *orig_descrs[NPY_MAXARGS];
+ orig_context.descriptors = orig_descrs;
+ orig_context.method = context->method->wrapped_meth;
+
+ int nin = context->method->nin, nout = context->method->nout;
+ PyArray_DTypeMeta **dtypes = context->method->wrapped_dtypes;
+
+ if (context->method->translate_given_descrs(
+ nin, nout, dtypes, context->descriptors, orig_descrs) < 0) {
+ return -1;
+ }
+ int res = context->method->wrapped_meth->get_identity(&orig_context,
+ item, flags);
+ for (int i = 0; i < nin + nout; i++) {
+ Py_DECREF(orig_descrs);
+ }
+ return res;
+}
+
+
/**
* Allows creating of a fairly lightweight wrapper around an existing ufunc
* loop. The idea is mainly for units, as this is currently slightly limited
@@ -243,6 +275,7 @@ PyUFunc_AddWrappingLoop(PyObject *ufunc_obj,
PyType_Slot slots[] = {
{NPY_METH_resolve_descriptors, &wrapping_method_resolve_descriptors},
{NPY_METH_get_loop, &wrapping_method_get_loop},
+ {NPY_METH_get_identity, &wrapping_method_get_identity_function},
{0, NULL}
};