summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/overrides.py23
-rw-r--r--numpy/core/tests/test_overrides.py39
2 files changed, 61 insertions, 1 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index cb550152e..663436a4c 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -2,6 +2,7 @@
import collections
import functools
import os
+import sys
from numpy.core._multiarray_umath import (
add_docstring, implement_array_function, _get_implementing_args)
@@ -176,7 +177,27 @@ def array_function_dispatch(dispatcher, module=None, verify=True,
@functools.wraps(implementation)
def public_api(*args, **kwargs):
- relevant_args = dispatcher(*args, **kwargs)
+ try:
+ relevant_args = dispatcher(*args, **kwargs)
+ except TypeError as exc:
+ # Try to clean up a signature related TypeError. Such an
+ # error will be something like:
+ # dispatcher.__name__() got an unexpected keyword argument
+ #
+ # So replace the dispatcher name in this case. In principle
+ # TypeErrors may be raised from _within_ the dispatcher, so
+ # we check that the traceback contains a string that starts
+ # with the name. (In principle we could also check the
+ # traceback length, as it would be deeper.)
+ msg = exc.args[0]
+ disp_name = dispatcher.__name__
+ if not isinstance(msg, str) or not msg.startswith(disp_name):
+ raise
+
+ # Replace with the correct name and re-raise:
+ new_msg = msg.replace(disp_name, public_api.__name__)
+ raise TypeError(new_msg) from None
+
return implement_array_function(
implementation, public_api, relevant_args, args, kwargs)
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index 36970dbc0..e68406ebd 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -355,6 +355,45 @@ class TestArrayFunctionImplementation:
TypeError, "no implementation found for 'my.func'"):
func(MyArray())
+ def test_signature_error_message(self):
+ # The lambda function will be named "<lambda>", but the TypeError
+ # should show the name as "func"
+ def _dispatcher():
+ return ()
+
+ @array_function_dispatch(_dispatcher)
+ def func():
+ pass
+
+ try:
+ func(bad_arg=3)
+ except TypeError as e:
+ expected_exception = e
+
+ try:
+ func(bad_arg=3)
+ raise AssertionError("must fail")
+ except TypeError as exc:
+ assert exc.args == expected_exception.args
+
+ @pytest.mark.parametrize("value", [234, "this func is not replaced"])
+ def test_dispatcher_error(self, value):
+ # If the dispatcher raises an error, we must not attempt to mutate it
+ error = TypeError(value)
+
+ def dispatcher():
+ raise error
+
+ @array_function_dispatch(dispatcher)
+ def func():
+ return 3
+
+ try:
+ func()
+ raise AssertionError("must fail")
+ except TypeError as exc:
+ assert exc is error # unmodified exception
+
class TestNDArrayMethods: