summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorpdmurray <peynmurray@gmail.com>2023-02-07 13:36:06 -0800
committerpdmurray <peynmurray@gmail.com>2023-02-17 14:34:41 -0800
commit2d0853d4a454b198e91c20eec705e7bc8ae148d5 (patch)
tree618905a8ccdec2bf0176018527a2465e456db5b8
parent51ecf84013a6c8325f9dd0571a48794fc0e448a4 (diff)
downloadnumpy-2d0853d4a454b198e91c20eec705e7bc8ae148d5.tar.gz
ENH: Add PyArray_ArrFunc compare support for NEP42 dtypes
-rw-r--r--numpy/core/include/numpy/_dtype_api.h45
-rw-r--r--numpy/core/src/multiarray/experimental_public_dtype_api.c90
2 files changed, 125 insertions, 10 deletions
diff --git a/numpy/core/include/numpy/_dtype_api.h b/numpy/core/include/numpy/_dtype_api.h
index 8ef879bfa..b76e52381 100644
--- a/numpy/core/include/numpy/_dtype_api.h
+++ b/numpy/core/include/numpy/_dtype_api.h
@@ -5,7 +5,7 @@
#ifndef NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_
#define NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_
-#define __EXPERIMENTAL_DTYPE_API_VERSION 7
+#define __EXPERIMENTAL_DTYPE_API_VERSION 8
struct PyArrayMethodObject_tag;
@@ -98,7 +98,7 @@ typedef enum {
typedef struct PyArrayMethod_Context_tag {
/* The caller, which is typically the original ufunc. May be NULL */
- PyObject *caller;
+ PyObject *caller;
/* The method "self". Publically currentl an opaque object. */
struct PyArrayMethodObject_tag *method;
@@ -125,7 +125,7 @@ typedef struct {
/*
* ArrayMethod slots
* -----------------
- *
+ *
* SLOTS IDs For the ArrayMethod creation, once fully public, IDs are fixed
* but can be deprecated and arbitrarily extended.
*/
@@ -142,6 +142,7 @@ typedef struct {
/* other slots are in order, so note last one (internal use!) */
#define _NPY_NUM_DTYPE_SLOTS 8
+#define _NPY_NUM_DTYPE_PYARRAY_ARRFUNC_SLOTS 22 + (1 << 10)
/*
* The resolve descriptors function, must be able to handle NULL values for
@@ -273,6 +274,44 @@ typedef int translate_loop_descrs_func(int nin, int nout,
#define NPY_DT_setitem 7
#define NPY_DT_getitem 8
+// These PyArray_ArrFunc slots will be deprecated and replaced eventually
+// getitem and setitem can be defined as a performance optimization;
+// by default the user dtypes call `legacy_getitem_using_DType` and
+// `legacy_setitem_using_DType`, respectively. This functionality is
+// only supported for basic NumPy DTypes.
+
+// Cast is disabled
+// #define NPY_DT_PyArray_ArrFuncs_cast 0 + (1 << 10)
+
+#define NPY_DT_PyArray_ArrFuncs_getitem 1 + (1 << 10)
+#define NPY_DT_PyArray_ArrFuncs_setitem 2 + (1 << 10)
+
+#define NPY_DT_PyArray_ArrFuncs_copyswapn 3 + (1 << 10)
+#define NPY_DT_PyArray_ArrFuncs_copyswap 4 + (1 << 10)
+#define NPY_DT_PyArray_ArrFuncs_compare 5 + (1 << 10)
+#define NPY_DT_PyArray_ArrFuncs_argmax 6 + (1 << 10)
+#define NPY_DT_PyArray_ArrFuncs_dotfunc 7 + (1 << 10)
+#define NPY_DT_PyArray_ArrFuncs_scanfunc 8 + (1 << 10)
+#define NPY_DT_PyArray_ArrFuncs_fromstr 9 + (1 << 10)
+#define NPY_DT_PyArray_ArrFuncs_nonzero 10 + (1 << 10)
+#define NPY_DT_PyArray_ArrFuncs_fill 11 + (1 << 10)
+#define NPY_DT_PyArray_ArrFuncs_fillwithscalar 12 + (1 << 10)
+#define NPY_DT_PyArray_ArrFuncs_sort 13 + (1 << 10)
+#define NPY_DT_PyArray_ArrFuncs_argsort 14 + (1 << 10)
+
+// Casting related slots are disabled. See
+// https://github.com/numpy/numpy/pull/23173#discussion_r1101098163
+// #define NPY_DT_PyArray_ArrFuncs_castdict 15 + (1 << 10)
+// #define NPY_DT_PyArray_ArrFuncs_scalarkind 16 + (1 << 10)
+// #define NPY_DT_PyArray_ArrFuncs_cancastscalarkindto 17 + (1 << 10)
+// #define NPY_DT_PyArray_ArrFuncs_cancastto 18 + (1 << 10)
+
+// These are deprecated in NumPy 1.19, so are disabled here.
+// #define NPY_DT_PyArray_ArrFuncs_fastclip 19 + (1 << 10)
+// #define NPY_DT_PyArray_ArrFuncs_fastputmask 20 + (1 << 10)
+// #define NPY_DT_PyArray_ArrFuncs_fasttake 21 + (1 << 10)
+#define NPY_DT_PyArray_ArrFuncs_argmin 22 + (1 << 10)
+
// TODO: These slots probably still need some thought, and/or a way to "grow"?
typedef struct {
diff --git a/numpy/core/src/multiarray/experimental_public_dtype_api.c b/numpy/core/src/multiarray/experimental_public_dtype_api.c
index a974bf390..63a9aa0a8 100644
--- a/numpy/core/src/multiarray/experimental_public_dtype_api.c
+++ b/numpy/core/src/multiarray/experimental_public_dtype_api.c
@@ -16,7 +16,6 @@
#include "common_dtype.h"
-
static PyArray_DTypeMeta *
dtype_does_not_promote(
PyArray_DTypeMeta *NPY_UNUSED(self), PyArray_DTypeMeta *NPY_UNUSED(other))
@@ -162,6 +161,7 @@ PyArrayInitDTypeMeta_FromSpec(
NPY_DT_SLOTS(DType)->common_instance = NULL;
NPY_DT_SLOTS(DType)->setitem = NULL;
NPY_DT_SLOTS(DType)->getitem = NULL;
+ NPY_DT_SLOTS(DType)->f = default_funcs;
PyType_Slot *spec_slot = spec->slots;
while (1) {
@@ -171,7 +171,7 @@ PyArrayInitDTypeMeta_FromSpec(
if (slot == 0) {
break;
}
- if (slot > _NPY_NUM_DTYPE_SLOTS || slot < 0) {
+ if (slot > _NPY_NUM_DTYPE_PYARRAY_ARRFUNC_SLOTS || slot < 0) {
PyErr_Format(PyExc_RuntimeError,
"Invalid slot with value %d passed in.", slot);
return -1;
@@ -180,11 +180,88 @@ PyArrayInitDTypeMeta_FromSpec(
* It is up to the user to get this right, and slots are sorted
* exactly like they are stored right now:
*/
- void **current = (void **)(&(
- NPY_DT_SLOTS(DType)->discover_descr_from_pyobject));
- current += slot - 1;
- *current = pfunc;
+ if (slot <= _NPY_NUM_DTYPE_SLOTS) {
+ // slot > 8 are PyArray_ArrFuncs
+ void **current = (void **)(&(
+ NPY_DT_SLOTS(DType)->discover_descr_from_pyobject));
+ current += slot - 1;
+ *current = pfunc;
+ }
+ else {
+ // Remove PyArray_ArrFuncs offset
+ int f_slot = slot - (1 << 10);
+ if (1 <= f_slot && f_slot <= 22) {
+ switch (f_slot) {
+ case 1:
+ NPY_DT_SLOTS(DType)->f.getitem = pfunc;
+ break;
+ case 2:
+ NPY_DT_SLOTS(DType)->f.setitem = pfunc;
+ break;
+ case 3:
+ NPY_DT_SLOTS(DType)->f.copyswapn = pfunc;
+ break;
+ case 4:
+ NPY_DT_SLOTS(DType)->f.copyswap = pfunc;
+ break;
+ case 5:
+ NPY_DT_SLOTS(DType)->f.compare = pfunc;
+ break;
+ case 6:
+ NPY_DT_SLOTS(DType)->f.argmax = pfunc;
+ break;
+ case 7:
+ NPY_DT_SLOTS(DType)->f.dotfunc = pfunc;
+ break;
+ case 8:
+ NPY_DT_SLOTS(DType)->f.scanfunc = pfunc;
+ break;
+ case 9:
+ NPY_DT_SLOTS(DType)->f.fromstr = pfunc;
+ break;
+ case 10:
+ NPY_DT_SLOTS(DType)->f.nonzero = pfunc;
+ break;
+ case 11:
+ NPY_DT_SLOTS(DType)->f.fill = pfunc;
+ break;
+ case 12:
+ NPY_DT_SLOTS(DType)->f.fillwithscalar = pfunc;
+ break;
+ case 13:
+ *NPY_DT_SLOTS(DType)->f.sort = pfunc;
+ break;
+ case 14:
+ *NPY_DT_SLOTS(DType)->f.argsort = pfunc;
+ break;
+ case 15:
+ case 16:
+ case 17:
+ case 18:
+ case 19:
+ case 20:
+ case 21:
+ PyErr_Format(
+ PyExc_RuntimeError,
+ "PyArray_ArrFunc casting slot with value %d is disabled.",
+ f_slot
+ );
+ return -1;
+ case 22:
+ NPY_DT_SLOTS(DType)->f.argmin = pfunc;
+ break;
+ }
+ } else {
+ PyErr_Format(
+ PyExc_RuntimeError,
+ "Invalid PyArray_ArrFunc slot with value %d passed in.",
+ f_slot
+ );
+ return -1;
+ }
+ }
}
+
if (NPY_DT_SLOTS(DType)->setitem == NULL
|| NPY_DT_SLOTS(DType)->getitem == NULL) {
PyErr_SetString(PyExc_RuntimeError,
@@ -213,7 +290,6 @@ PyArrayInitDTypeMeta_FromSpec(
return -1;
}
}
- NPY_DT_SLOTS(DType)->f = default_funcs;
/* invalid type num. Ideally, we get away with it! */
DType->type_num = -1;