summaryrefslogtreecommitdiff
path: root/numpy/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/src')
-rw-r--r--numpy/core/src/multiarray/arrayfunction_override.c85
1 files changed, 84 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c
index 04768504e..3c55e2164 100644
--- a/numpy/core/src/multiarray/arrayfunction_override.c
+++ b/numpy/core/src/multiarray/arrayfunction_override.c
@@ -419,6 +419,9 @@ typedef struct {
PyObject *dict;
PyObject *relevant_arg_func;
PyObject *default_impl;
+ /* The following fields are used to clean up TypeError messages only: */
+ PyObject *dispatcher_name;
+ PyObject *public_name;
} PyArray_ArrayFunctionDispatcherObject;
@@ -428,10 +431,72 @@ dispatcher_dealloc(PyArray_ArrayFunctionDispatcherObject *self)
Py_CLEAR(self->relevant_arg_func);
Py_CLEAR(self->default_impl);
Py_CLEAR(self->dict);
+ Py_CLEAR(self->dispatcher_name);
+ Py_CLEAR(self->public_name);
PyObject_FREE(self);
}
+static void
+fix_name_if_typeerror(PyArray_ArrayFunctionDispatcherObject *self)
+{
+ if (!PyErr_ExceptionMatches(PyExc_TypeError)) {
+ return;
+ }
+
+ PyObject *exc, *val, *tb, *message;
+ PyErr_Fetch(&exc, &val, &tb);
+
+ if (!PyUnicode_CheckExact(val)) {
+ /*
+ * We expect the error to be unnormalized, but maybe it isn't always
+ * the case, so normalize and fetch args[0] if it isn't a string.
+ */
+ PyErr_NormalizeException(&exc, &val, &tb);
+
+ PyObject *args = PyObject_GetAttrString(val, "args");
+ if (args == NULL || !PyTuple_CheckExact(args)
+ || PyTuple_GET_SIZE(args) != 1) {
+ Py_XDECREF(args);
+ goto restore_error;
+ }
+ message = PyTuple_GET_ITEM(args, 0);
+ Py_INCREF(message);
+ Py_DECREF(args);
+ if (!PyUnicode_CheckExact(message)) {
+ Py_DECREF(message);
+ goto restore_error;
+ }
+ }
+ else {
+ Py_INCREF(val);
+ message = val;
+ }
+
+ Py_ssize_t cmp = PyUnicode_Tailmatch(
+ message, self->dispatcher_name, 0, -1, -1);
+ if (cmp <= 0) {
+ Py_DECREF(message);
+ goto restore_error;
+ }
+ Py_SETREF(message, PyUnicode_Replace(
+ message, self->dispatcher_name, self->public_name, 1));
+ if (message == NULL) {
+ goto restore_error;
+ }
+ PyErr_SetObject(PyExc_TypeError, message);
+ Py_DECREF(exc);
+ Py_XDECREF(val);
+ Py_XDECREF(tb);
+ Py_DECREF(message);
+ return;
+
+ restore_error:
+ /* replacement not successful, so restore original error */
+ PyErr_Restore(exc, val, tb);
+}
+
+
static PyObject *
dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self,
PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
@@ -458,6 +523,7 @@ dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self,
relevant_args = PyObject_Vectorcall(
self->relevant_arg_func, args, len_args, kwnames);
if (relevant_args == NULL) {
+ fix_name_if_typeerror(self);
return NULL;
}
Py_SETREF(relevant_args, PySequence_Fast(relevant_args,
@@ -600,14 +666,31 @@ dispatcher_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwargs)
}
self->vectorcall = (vectorcallfunc)dispatcher_vectorcall;
+ Py_INCREF(self->default_impl);
+ self->dict = NULL;
+ self->dispatcher_name = NULL;
+ self->public_name = NULL;
+
if (self->relevant_arg_func == Py_None) {
/* NULL in the relevant arg function means we use `like=` */
Py_CLEAR(self->relevant_arg_func);
}
else {
+ /* Fetch names to clean up TypeErrors (show actual name) */
Py_INCREF(self->relevant_arg_func);
+ self->dispatcher_name = PyObject_GetAttrString(
+ self->relevant_arg_func, "__qualname__");
+ if (self->dispatcher_name == NULL) {
+ Py_DECREF(self);
+ return NULL;
+ }
+ self->public_name = PyObject_GetAttrString(
+ self->default_impl, "__qualname__");
+ if (self->public_name == NULL) {
+ Py_DECREF(self);
+ return NULL;
+ }
}
- Py_INCREF(self->default_impl);
/* Need to be like a Python function that has arbitrary attributes */
self->dict = PyDict_New();