diff options
-rw-r--r-- | cmd2/cmd2.py | 35 | ||||
-rw-r--r-- | tests/test_plugin.py | 28 |
2 files changed, 55 insertions, 8 deletions
diff --git a/cmd2/cmd2.py b/cmd2/cmd2.py index 89f14b37..ba774f0b 100644 --- a/cmd2/cmd2.py +++ b/cmd2/cmd2.py @@ -34,6 +34,7 @@ import cmd import collections from colorama import Fore import glob +import inspect import os import platform import re @@ -3110,15 +3111,38 @@ Script should contain one command per line, just like command would be typed in self._precmd_hooks = [] self._postcmd_hooks = [] + @classmethod + def _validate_callable_param_count(cls, func, count): + signature = inspect.signature(func) + # validate that the callable has the right number of parameters + nparam = len(signature.parameters) + if nparam != count: + raise TypeError('{} has {} positional arguments, expected {}'.format( + func.__name__, + nparam, + count, + )) + + @classmethod + def _validate_prepostloop_callable(cls, func): + """Check parameter and return values for preloop and postloop hooks""" + cls._validate_callable_param_count(func, 0) + # make sure there is no return notation + signature = inspect.signature(func) + if signature.return_annotation != signature.empty: + raise TypeError('{} should not have a declared return type'.format( + func.__name__, + )) + def register_preloop_hook(self, func): """Register a function to be called at the beginning of the command loop.""" + self._validate_prepostloop_callable(func) self._preloop_hooks.append(func) - # TODO check signature of registered func and throw error if it's wrong def register_postloop_hook(self, func): """Register a function to be called at the end of the command loop.""" + self._validate_prepostloop_callable(func) self._postloop_hooks.append(func) - # TODO check signature of registered func and throw error if it's wrong def register_postparsing_hook(self, func): """Register a function to be called after parsing user input but before running the command""" @@ -3127,14 +3151,9 @@ Script should contain one command per line, just like command would be typed in def register_precmd_hook(self, func): """Register a function to be called before the command function.""" - import inspect signature = inspect.signature(func) # validate that the callable has the right number of parameters - nparam = len(signature.parameters) - if nparam < 1: - raise TypeError('precommand hooks must have one positional argument') - if nparam > 1: - raise TypeError('precommand hooks take one positional argument but {} were given'.format(nparam)) + self._validate_callable_param_count(func, 1) # validate the parameter has the right annotation paramname = list(signature.parameters.keys())[0] param = signature.parameters[paramname] diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 26eb88bb..f3db853b 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -33,6 +33,14 @@ class Plugin: "Another method used for preloop or postloop hooks" self.poutput("two") + def prepost_hook_too_many_parameters(self, param): + "A preloop or postloop hook with too many parameters" + pass + + def prepost_hook_with_return_type(self) -> bool: + "A preloop or postloop hook with a declared return type" + pass + def postparse_hook(self, statement: cmd2.Statement) -> Tuple[bool, cmd2.Statement]: "A postparsing hook" self.called_postparsing += 1 @@ -116,6 +124,16 @@ class PluggedApp(Plugin, cmd2.Cmd): # test pre and postloop hooks # ### +def test_register_preloop_hook_too_many_parameters(): + app = PluggedApp() + with pytest.raises(TypeError): + app.register_preloop_hook(app.prepost_hook_too_many_parameters) + +def test_register_preloop_hook_with_return_type(): + app = PluggedApp() + with pytest.raises(TypeError): + app.register_preloop_hook(app.prepost_hook_with_return_type) + def test_preloop_hook(capsys): app = PluggedApp() app.register_preloop_hook(app.prepost_hook_one) @@ -137,6 +155,16 @@ def test_preloop_hooks(capsys): assert out == 'one\ntwo\nhello\n' assert not err +def test_register_postloop_hook_too_many_parameters(): + app = PluggedApp() + with pytest.raises(TypeError): + app.register_postloop_hook(app.prepost_hook_too_many_parameters) + +def test_register_postloop_hook_with_return_type(): + app = PluggedApp() + with pytest.raises(TypeError): + app.register_postloop_hook(app.prepost_hook_with_return_type) + def test_postloop_hook(capsys): app = PluggedApp() app.register_postloop_hook(app.prepost_hook_one) |