summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/usertypes.c67
1 files changed, 63 insertions, 4 deletions
diff --git a/numpy/core/src/multiarray/usertypes.c b/numpy/core/src/multiarray/usertypes.c
index d0cf53576..a338d712d 100644
--- a/numpy/core/src/multiarray/usertypes.c
+++ b/numpy/core/src/multiarray/usertypes.c
@@ -268,6 +268,56 @@ PyArray_RegisterDataType(PyArray_Descr *descr)
return typenum;
}
+
+/*
+ * Checks that there is no cast already cached using the new casting-impl
+ * mechanism.
+ * In that case, we do not clear out the cache (but otherwise silently
+ * continue). Users should not modify casts after they have been used,
+ * but this may also happen accidentally during setup (and may never have
+ * mattered). See https://github.com/numpy/numpy/issues/20009
+ */
+static int _warn_if_cast_exists_already(
+ PyArray_Descr *descr, int totype, char *funcname)
+{
+ PyArray_DTypeMeta *to_DType = PyArray_DTypeFromTypeNum(totype);
+ if (to_DType == NULL) {
+ return -1;
+ }
+ PyObject *cast_impl = PyDict_GetItemWithError(
+ NPY_DT_SLOTS(NPY_DTYPE(descr))->castingimpls, (PyObject *)to_DType);
+ Py_DECREF(to_DType);
+ if (cast_impl == NULL) {
+ if (PyErr_Occurred()) {
+ return -1;
+ }
+ }
+ else {
+ char *extra_msg;
+ if (cast_impl == Py_None) {
+ extra_msg = "the cast will continue to be considered impossible.";
+ }
+ else {
+ extra_msg = "the previous definition will continue to be used.";
+ }
+ Py_DECREF(cast_impl);
+ PyArray_Descr *to_descr = PyArray_DescrFromType(totype);
+ int ret = PyErr_WarnFormat(PyExc_RuntimeWarning, 1,
+ "A cast from %R to %R was registered/modified using `%s` "
+ "after the cast had been used. "
+ "This registration will have (mostly) no effect: %s\n"
+ "The most likely fix is to ensure that casts are the first "
+ "thing initialized after dtype registration. "
+ "Please contact the NumPy developers with any questions!",
+ descr, to_descr, funcname, extra_msg);
+ Py_DECREF(to_descr);
+ if (ret < 0) {
+ return -1;
+ }
+ }
+ return 0;
+}
+
/*NUMPY_API
Register Casting Function
Replaces any function currently stored.
@@ -279,14 +329,19 @@ PyArray_RegisterCastFunc(PyArray_Descr *descr, int totype,
PyObject *cobj, *key;
int ret;
- if (totype < NPY_NTYPES_ABI_COMPATIBLE) {
- descr->f->cast[totype] = castfunc;
- return 0;
- }
if (totype >= NPY_NTYPES && !PyTypeNum_ISUSERDEF(totype)) {
PyErr_SetString(PyExc_TypeError, "invalid type number.");
return -1;
}
+ if (_warn_if_cast_exists_already(
+ descr, totype, "PyArray_RegisterCastFunc") < 0) {
+ return -1;
+ }
+
+ if (totype < NPY_NTYPES_ABI_COMPATIBLE) {
+ descr->f->cast[totype] = castfunc;
+ return 0;
+ }
if (descr->f->castdict == NULL) {
descr->f->castdict = PyDict_New();
if (descr->f->castdict == NULL) {
@@ -328,6 +383,10 @@ PyArray_RegisterCanCast(PyArray_Descr *descr, int totype,
"RegisterCanCast must be user-defined.");
return -1;
}
+ if (_warn_if_cast_exists_already(
+ descr, totype, "PyArray_RegisterCanCast") < 0) {
+ return -1;
+ }
if (scalar == NPY_NOSCALAR) {
/*