diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-04-17 23:41:36 +0100 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-11-12 09:14:50 -0800 |
commit | d3b2036949255e48ecbcfcc70ed2ea95c755cf2a (patch) | |
tree | 703569f42f06539eb1c7a95c841d5a271079ea5a /numpy/core | |
parent | 97df928718a46b869d0d6675ffd6e8c539f32773 (diff) | |
download | numpy-d3b2036949255e48ecbcfcc70ed2ea95c755cf2a.tar.gz |
ENH: Allow ufunc.identity to be any python object
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/code_generators/numpy_api.py | 3 | ||||
-rw-r--r-- | numpy/core/include/numpy/ufuncobject.h | 9 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 27 |
3 files changed, 38 insertions, 1 deletions
diff --git a/numpy/core/code_generators/numpy_api.py b/numpy/core/code_generators/numpy_api.py index d8a9ee6b4..fdf97ac00 100644 --- a/numpy/core/code_generators/numpy_api.py +++ b/numpy/core/code_generators/numpy_api.py @@ -402,6 +402,9 @@ ufunc_funcs_api = { # End 1.7 API 'PyUFunc_RegisterLoopForDescr': (41,), # End 1.8 API + 'PyUFunc_FromFuncAndDataAndSignatureAndIdentity': + (42,), + # End 1.16 API } # List of all the dicts which define the C API diff --git a/numpy/core/include/numpy/ufuncobject.h b/numpy/core/include/numpy/ufuncobject.h index 85f8a6c08..90d837a9b 100644 --- a/numpy/core/include/numpy/ufuncobject.h +++ b/numpy/core/include/numpy/ufuncobject.h @@ -223,7 +223,8 @@ typedef struct _tagPyUFuncObject { */ npy_uint32 *core_dim_flags; - + /* Identity for reduction, when identity == PyUFunc_IdentityValue */ + PyObject *identity_value; } PyUFuncObject; @@ -299,6 +300,12 @@ typedef struct _tagPyUFuncObject { * This case allows reduction with multiple axes at once. */ #define PyUFunc_ReorderableNone -2 +/* + * UFunc unit is in identity_value, and the order of operations can be reordered + * This case allows reduction with multiple axes at once. + */ +#define PyUFunc_IdentityValue -3 + #define UFUNC_REDUCE 0 #define UFUNC_ACCUMULATE 1 diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index e60c734ec..1fe8745a0 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -2453,6 +2453,11 @@ _get_identity(PyUFuncObject *ufunc, npy_bool *reorderable) { *reorderable = 0; Py_RETURN_NONE; + case PyUFunc_IdentityValue: + *reorderable = 1; + Py_INCREF(ufunc->identity_value); + return ufunc->identity_value; + default: PyErr_Format(PyExc_ValueError, "ufunc %s has an invalid identity", ufunc_get_name_cstr(ufunc)); @@ -4833,6 +4838,20 @@ PyUFunc_FromFuncAndDataAndSignature(PyUFuncGenericFunction *func, void **data, const char *name, const char *doc, int unused, const char *signature) { + return PyUFunc_FromFuncAndDataAndSignatureAndIdentity( + func, data, types, ntypes, nin, nout, identity, name, doc, + unused, signature, NULL); +} + +/*UFUNC_API*/ +NPY_NO_EXPORT PyObject * +PyUFunc_FromFuncAndDataAndSignatureAndIdentity(PyUFuncGenericFunction *func, void **data, + char *types, int ntypes, + int nin, int nout, int identity, + const char *name, const char *doc, + int unused, const char *signature, + PyObject *identity_value) +{ PyUFuncObject *ufunc; if (nin + nout > NPY_MAXARGS) { PyErr_Format(PyExc_ValueError, @@ -4853,6 +4872,10 @@ PyUFunc_FromFuncAndDataAndSignature(PyUFuncGenericFunction *func, void **data, ufunc->nout = nout; ufunc->nargs = nin+nout; ufunc->identity = identity; + if (ufunc->identity == PyUFunc_IdentityValue) { + Py_INCREF(identity_value); + } + ufunc->identity_value = identity_value; ufunc->functions = func; ufunc->data = data; @@ -4874,6 +4897,7 @@ PyUFunc_FromFuncAndDataAndSignature(PyUFuncGenericFunction *func, void **data, ufunc->op_flags = PyArray_malloc(sizeof(npy_uint32)*ufunc->nargs); if (ufunc->op_flags == NULL) { + Py_DECREF(ufunc); return PyErr_NoMemory(); } memset(ufunc->op_flags, 0, sizeof(npy_uint32)*ufunc->nargs); @@ -5230,6 +5254,9 @@ ufunc_dealloc(PyUFuncObject *ufunc) PyArray_free(ufunc->op_flags); Py_XDECREF(ufunc->userloops); Py_XDECREF(ufunc->obj); + if (ufunc->identity == PyUFunc_IdentityValue) { + Py_DECREF(ufunc->identity_value); + } PyArray_free(ufunc); } |