diff options
| author | Sebastian Berg <sebastian@sipsolutions.net> | 2022-02-02 18:26:38 -0600 |
|---|---|---|
| committer | Sebastian Berg <sebastianb@nvidia.com> | 2023-01-20 15:28:48 +0100 |
| commit | 581bb3796a97c2eec634a390e9c7befb5917859f (patch) | |
| tree | e0723fa907e90701a9ae0a1bb848ecba65659bf2 /numpy | |
| parent | ffe18c971d4d122dc36cafcc6651b44118d08d39 (diff) | |
| download | numpy-581bb3796a97c2eec634a390e9c7befb5917859f.tar.gz | |
ENH: Support identity-function in experimental DType API
Also add it to the wrapped array-method (ufunc) implementation
so that a Unit dtype can reasonably use an identity for reductions.
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/include/numpy/experimental_dtype_api.h | 32 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/array_method.c | 1 | ||||
| -rw-r--r-- | numpy/core/src/umath/wrapping_array_method.c | 33 |
3 files changed, 66 insertions, 0 deletions
diff --git a/numpy/core/include/numpy/experimental_dtype_api.h b/numpy/core/include/numpy/experimental_dtype_api.h index 5f5f24317..48941baec 100644 --- a/numpy/core/include/numpy/experimental_dtype_api.h +++ b/numpy/core/include/numpy/experimental_dtype_api.h @@ -321,6 +321,38 @@ typedef int (PyArrayMethod_StridedLoop)(PyArrayMethod_Context *context, NpyAuxData *transferdata); +/* + * For reductions, NumPy sometimes requires an identity or default value. + * The typical way of relying on a single "default" for ufuncs does not always + * work however. This function allows customizing the identity value as well + * as whether the operation is "reorderable". + */ +#define NPY_METH_get_identity 7 + +typedef enum { + /* The value can be used as a default for empty reductions */ + NPY_METH_ITEM_IS_DEFAULT = 1 << 0, + /* The value represents the identity value */ + NPY_METH_ITEM_IS_IDENTITY = 1 << 1, + /* The operation is fully reorderable (iteration order may be optimized) */ + NPY_METH_IS_REORDERABLE = 1 << 2, +} NPY_ARRAYMETHOD_IDENTITY_FLAGS; + +/* + * If an identity exists, should set the `NPY_METH_ITEM_IS_IDENTITY`, normally + * the `NPY_METH_ITEM_IS_DEFAULT` should also be set, but it is distinct. + * By default NumPy provides a "default" for `object` dtype, but does not use + * it as an identity. + * The `NPY_METH_IS_REORDERABLE` flag should be set if the operatio is + * reorderable. + * + * NOTE: `item` can be `NULL` when a user passed a custom initial value, in + * this case only the `reorderable` flag is valid. + */ +typedef int (get_identity_function)( + PyArrayMethod_Context *context, char *item, + NPY_ARRAYMETHOD_IDENTITY_FLAGS *flags); + /* * **************************** diff --git a/numpy/core/src/multiarray/array_method.c b/numpy/core/src/multiarray/array_method.c index a289e62ab..e39359ea7 100644 --- a/numpy/core/src/multiarray/array_method.c +++ b/numpy/core/src/multiarray/array_method.c @@ -337,6 +337,7 @@ fill_arraymethod_from_slots( continue; case NPY_METH_get_identity: meth->get_identity = slot->pfunc; + continue; default: break; } diff --git a/numpy/core/src/umath/wrapping_array_method.c b/numpy/core/src/umath/wrapping_array_method.c index 9f8f036e8..4d0c4caa4 100644 --- a/numpy/core/src/umath/wrapping_array_method.c +++ b/numpy/core/src/umath/wrapping_array_method.c @@ -177,6 +177,38 @@ wrapping_method_get_loop( } +/* + * Wraps the original identity function, needs to translate the descriptors + * back to the original ones and provide an "original" context (identically to + * `get_loop`). + * We assume again that translating the descriptors is quick. + */ +static int +wrapping_method_get_identity_function(PyArrayMethod_Context *context, + char *item, NPY_ARRAYMETHOD_IDENTITY_FLAGS *flags) +{ + /* Copy the context, and replace descriptors: */ + PyArrayMethod_Context orig_context = *context; + PyArray_Descr *orig_descrs[NPY_MAXARGS]; + orig_context.descriptors = orig_descrs; + orig_context.method = context->method->wrapped_meth; + + int nin = context->method->nin, nout = context->method->nout; + PyArray_DTypeMeta **dtypes = context->method->wrapped_dtypes; + + if (context->method->translate_given_descrs( + nin, nout, dtypes, context->descriptors, orig_descrs) < 0) { + return -1; + } + int res = context->method->wrapped_meth->get_identity(&orig_context, + item, flags); + for (int i = 0; i < nin + nout; i++) { + Py_DECREF(orig_descrs); + } + return res; +} + + /** * Allows creating of a fairly lightweight wrapper around an existing ufunc * loop. The idea is mainly for units, as this is currently slightly limited @@ -243,6 +275,7 @@ PyUFunc_AddWrappingLoop(PyObject *ufunc_obj, PyType_Slot slots[] = { {NPY_METH_resolve_descriptors, &wrapping_method_resolve_descriptors}, {NPY_METH_get_loop, &wrapping_method_get_loop}, + {NPY_METH_get_identity, &wrapping_method_get_identity_function}, {0, NULL} }; |
