diff options
author | pdmurray <peynmurray@gmail.com> | 2023-02-07 13:36:06 -0800 |
---|---|---|
committer | pdmurray <peynmurray@gmail.com> | 2023-02-17 14:34:41 -0800 |
commit | 2d0853d4a454b198e91c20eec705e7bc8ae148d5 (patch) | |
tree | 618905a8ccdec2bf0176018527a2465e456db5b8 | |
parent | 51ecf84013a6c8325f9dd0571a48794fc0e448a4 (diff) | |
download | numpy-2d0853d4a454b198e91c20eec705e7bc8ae148d5.tar.gz |
ENH: Add PyArray_ArrFunc compare support for NEP42 dtypes
-rw-r--r-- | numpy/core/include/numpy/_dtype_api.h | 45 | ||||
-rw-r--r-- | numpy/core/src/multiarray/experimental_public_dtype_api.c | 90 |
2 files changed, 125 insertions, 10 deletions
diff --git a/numpy/core/include/numpy/_dtype_api.h b/numpy/core/include/numpy/_dtype_api.h index 8ef879bfa..b76e52381 100644 --- a/numpy/core/include/numpy/_dtype_api.h +++ b/numpy/core/include/numpy/_dtype_api.h @@ -5,7 +5,7 @@ #ifndef NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_ #define NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_ -#define __EXPERIMENTAL_DTYPE_API_VERSION 7 +#define __EXPERIMENTAL_DTYPE_API_VERSION 8 struct PyArrayMethodObject_tag; @@ -98,7 +98,7 @@ typedef enum { typedef struct PyArrayMethod_Context_tag { /* The caller, which is typically the original ufunc. May be NULL */ - PyObject *caller; + PyObject *caller; /* The method "self". Publically currentl an opaque object. */ struct PyArrayMethodObject_tag *method; @@ -125,7 +125,7 @@ typedef struct { /* * ArrayMethod slots * ----------------- - * + * * SLOTS IDs For the ArrayMethod creation, once fully public, IDs are fixed * but can be deprecated and arbitrarily extended. */ @@ -142,6 +142,7 @@ typedef struct { /* other slots are in order, so note last one (internal use!) */ #define _NPY_NUM_DTYPE_SLOTS 8 +#define _NPY_NUM_DTYPE_PYARRAY_ARRFUNC_SLOTS 22 + (1 << 10) /* * The resolve descriptors function, must be able to handle NULL values for @@ -273,6 +274,44 @@ typedef int translate_loop_descrs_func(int nin, int nout, #define NPY_DT_setitem 7 #define NPY_DT_getitem 8 +// These PyArray_ArrFunc slots will be deprecated and replaced eventually +// getitem and setitem can be defined as a performance optimization; +// by default the user dtypes call `legacy_getitem_using_DType` and +// `legacy_setitem_using_DType`, respectively. This functionality is +// only supported for basic NumPy DTypes. + +// Cast is disabled +// #define NPY_DT_PyArray_ArrFuncs_cast 0 + (1 << 10) + +#define NPY_DT_PyArray_ArrFuncs_getitem 1 + (1 << 10) +#define NPY_DT_PyArray_ArrFuncs_setitem 2 + (1 << 10) + +#define NPY_DT_PyArray_ArrFuncs_copyswapn 3 + (1 << 10) +#define NPY_DT_PyArray_ArrFuncs_copyswap 4 + (1 << 10) +#define NPY_DT_PyArray_ArrFuncs_compare 5 + (1 << 10) +#define NPY_DT_PyArray_ArrFuncs_argmax 6 + (1 << 10) +#define NPY_DT_PyArray_ArrFuncs_dotfunc 7 + (1 << 10) +#define NPY_DT_PyArray_ArrFuncs_scanfunc 8 + (1 << 10) +#define NPY_DT_PyArray_ArrFuncs_fromstr 9 + (1 << 10) +#define NPY_DT_PyArray_ArrFuncs_nonzero 10 + (1 << 10) +#define NPY_DT_PyArray_ArrFuncs_fill 11 + (1 << 10) +#define NPY_DT_PyArray_ArrFuncs_fillwithscalar 12 + (1 << 10) +#define NPY_DT_PyArray_ArrFuncs_sort 13 + (1 << 10) +#define NPY_DT_PyArray_ArrFuncs_argsort 14 + (1 << 10) + +// Casting related slots are disabled. See +// https://github.com/numpy/numpy/pull/23173#discussion_r1101098163 +// #define NPY_DT_PyArray_ArrFuncs_castdict 15 + (1 << 10) +// #define NPY_DT_PyArray_ArrFuncs_scalarkind 16 + (1 << 10) +// #define NPY_DT_PyArray_ArrFuncs_cancastscalarkindto 17 + (1 << 10) +// #define NPY_DT_PyArray_ArrFuncs_cancastto 18 + (1 << 10) + +// These are deprecated in NumPy 1.19, so are disabled here. +// #define NPY_DT_PyArray_ArrFuncs_fastclip 19 + (1 << 10) +// #define NPY_DT_PyArray_ArrFuncs_fastputmask 20 + (1 << 10) +// #define NPY_DT_PyArray_ArrFuncs_fasttake 21 + (1 << 10) +#define NPY_DT_PyArray_ArrFuncs_argmin 22 + (1 << 10) + // TODO: These slots probably still need some thought, and/or a way to "grow"? typedef struct { diff --git a/numpy/core/src/multiarray/experimental_public_dtype_api.c b/numpy/core/src/multiarray/experimental_public_dtype_api.c index a974bf390..63a9aa0a8 100644 --- a/numpy/core/src/multiarray/experimental_public_dtype_api.c +++ b/numpy/core/src/multiarray/experimental_public_dtype_api.c @@ -16,7 +16,6 @@ #include "common_dtype.h" - static PyArray_DTypeMeta * dtype_does_not_promote( PyArray_DTypeMeta *NPY_UNUSED(self), PyArray_DTypeMeta *NPY_UNUSED(other)) @@ -162,6 +161,7 @@ PyArrayInitDTypeMeta_FromSpec( NPY_DT_SLOTS(DType)->common_instance = NULL; NPY_DT_SLOTS(DType)->setitem = NULL; NPY_DT_SLOTS(DType)->getitem = NULL; + NPY_DT_SLOTS(DType)->f = default_funcs; PyType_Slot *spec_slot = spec->slots; while (1) { @@ -171,7 +171,7 @@ PyArrayInitDTypeMeta_FromSpec( if (slot == 0) { break; } - if (slot > _NPY_NUM_DTYPE_SLOTS || slot < 0) { + if (slot > _NPY_NUM_DTYPE_PYARRAY_ARRFUNC_SLOTS || slot < 0) { PyErr_Format(PyExc_RuntimeError, "Invalid slot with value %d passed in.", slot); return -1; @@ -180,11 +180,88 @@ PyArrayInitDTypeMeta_FromSpec( * It is up to the user to get this right, and slots are sorted * exactly like they are stored right now: */ - void **current = (void **)(&( - NPY_DT_SLOTS(DType)->discover_descr_from_pyobject)); - current += slot - 1; - *current = pfunc; + if (slot <= _NPY_NUM_DTYPE_SLOTS) { + // slot > 8 are PyArray_ArrFuncs + void **current = (void **)(&( + NPY_DT_SLOTS(DType)->discover_descr_from_pyobject)); + current += slot - 1; + *current = pfunc; + } + else { + // Remove PyArray_ArrFuncs offset + int f_slot = slot - (1 << 10); + if (1 <= f_slot && f_slot <= 22) { + switch (f_slot) { + case 1: + NPY_DT_SLOTS(DType)->f.getitem = pfunc; + break; + case 2: + NPY_DT_SLOTS(DType)->f.setitem = pfunc; + break; + case 3: + NPY_DT_SLOTS(DType)->f.copyswapn = pfunc; + break; + case 4: + NPY_DT_SLOTS(DType)->f.copyswap = pfunc; + break; + case 5: + NPY_DT_SLOTS(DType)->f.compare = pfunc; + break; + case 6: + NPY_DT_SLOTS(DType)->f.argmax = pfunc; + break; + case 7: + NPY_DT_SLOTS(DType)->f.dotfunc = pfunc; + break; + case 8: + NPY_DT_SLOTS(DType)->f.scanfunc = pfunc; + break; + case 9: + NPY_DT_SLOTS(DType)->f.fromstr = pfunc; + break; + case 10: + NPY_DT_SLOTS(DType)->f.nonzero = pfunc; + break; + case 11: + NPY_DT_SLOTS(DType)->f.fill = pfunc; + break; + case 12: + NPY_DT_SLOTS(DType)->f.fillwithscalar = pfunc; + break; + case 13: + *NPY_DT_SLOTS(DType)->f.sort = pfunc; + break; + case 14: + *NPY_DT_SLOTS(DType)->f.argsort = pfunc; + break; + case 15: + case 16: + case 17: + case 18: + case 19: + case 20: + case 21: + PyErr_Format( + PyExc_RuntimeError, + "PyArray_ArrFunc casting slot with value %d is disabled.", + f_slot + ); + return -1; + case 22: + NPY_DT_SLOTS(DType)->f.argmin = pfunc; + break; + } + } else { + PyErr_Format( + PyExc_RuntimeError, + "Invalid PyArray_ArrFunc slot with value %d passed in.", + f_slot + ); + return -1; + } + } } + if (NPY_DT_SLOTS(DType)->setitem == NULL || NPY_DT_SLOTS(DType)->getitem == NULL) { PyErr_SetString(PyExc_RuntimeError, @@ -213,7 +290,6 @@ PyArrayInitDTypeMeta_FromSpec( return -1; } } - NPY_DT_SLOTS(DType)->f = default_funcs; /* invalid type num. Ideally, we get away with it! */ DType->type_num = -1; |