summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cmd2/cmd2.py35
-rw-r--r--tests/test_plugin.py28
2 files changed, 55 insertions, 8 deletions
diff --git a/cmd2/cmd2.py b/cmd2/cmd2.py
index 89f14b37..ba774f0b 100644
--- a/cmd2/cmd2.py
+++ b/cmd2/cmd2.py
@@ -34,6 +34,7 @@ import cmd
import collections
from colorama import Fore
import glob
+import inspect
import os
import platform
import re
@@ -3110,15 +3111,38 @@ Script should contain one command per line, just like command would be typed in
self._precmd_hooks = []
self._postcmd_hooks = []
+ @classmethod
+ def _validate_callable_param_count(cls, func, count):
+ signature = inspect.signature(func)
+ # validate that the callable has the right number of parameters
+ nparam = len(signature.parameters)
+ if nparam != count:
+ raise TypeError('{} has {} positional arguments, expected {}'.format(
+ func.__name__,
+ nparam,
+ count,
+ ))
+
+ @classmethod
+ def _validate_prepostloop_callable(cls, func):
+ """Check parameter and return values for preloop and postloop hooks"""
+ cls._validate_callable_param_count(func, 0)
+ # make sure there is no return notation
+ signature = inspect.signature(func)
+ if signature.return_annotation != signature.empty:
+ raise TypeError('{} should not have a declared return type'.format(
+ func.__name__,
+ ))
+
def register_preloop_hook(self, func):
"""Register a function to be called at the beginning of the command loop."""
+ self._validate_prepostloop_callable(func)
self._preloop_hooks.append(func)
- # TODO check signature of registered func and throw error if it's wrong
def register_postloop_hook(self, func):
"""Register a function to be called at the end of the command loop."""
+ self._validate_prepostloop_callable(func)
self._postloop_hooks.append(func)
- # TODO check signature of registered func and throw error if it's wrong
def register_postparsing_hook(self, func):
"""Register a function to be called after parsing user input but before running the command"""
@@ -3127,14 +3151,9 @@ Script should contain one command per line, just like command would be typed in
def register_precmd_hook(self, func):
"""Register a function to be called before the command function."""
- import inspect
signature = inspect.signature(func)
# validate that the callable has the right number of parameters
- nparam = len(signature.parameters)
- if nparam < 1:
- raise TypeError('precommand hooks must have one positional argument')
- if nparam > 1:
- raise TypeError('precommand hooks take one positional argument but {} were given'.format(nparam))
+ self._validate_callable_param_count(func, 1)
# validate the parameter has the right annotation
paramname = list(signature.parameters.keys())[0]
param = signature.parameters[paramname]
diff --git a/tests/test_plugin.py b/tests/test_plugin.py
index 26eb88bb..f3db853b 100644
--- a/tests/test_plugin.py
+++ b/tests/test_plugin.py
@@ -33,6 +33,14 @@ class Plugin:
"Another method used for preloop or postloop hooks"
self.poutput("two")
+ def prepost_hook_too_many_parameters(self, param):
+ "A preloop or postloop hook with too many parameters"
+ pass
+
+ def prepost_hook_with_return_type(self) -> bool:
+ "A preloop or postloop hook with a declared return type"
+ pass
+
def postparse_hook(self, statement: cmd2.Statement) -> Tuple[bool, cmd2.Statement]:
"A postparsing hook"
self.called_postparsing += 1
@@ -116,6 +124,16 @@ class PluggedApp(Plugin, cmd2.Cmd):
# test pre and postloop hooks
#
###
+def test_register_preloop_hook_too_many_parameters():
+ app = PluggedApp()
+ with pytest.raises(TypeError):
+ app.register_preloop_hook(app.prepost_hook_too_many_parameters)
+
+def test_register_preloop_hook_with_return_type():
+ app = PluggedApp()
+ with pytest.raises(TypeError):
+ app.register_preloop_hook(app.prepost_hook_with_return_type)
+
def test_preloop_hook(capsys):
app = PluggedApp()
app.register_preloop_hook(app.prepost_hook_one)
@@ -137,6 +155,16 @@ def test_preloop_hooks(capsys):
assert out == 'one\ntwo\nhello\n'
assert not err
+def test_register_postloop_hook_too_many_parameters():
+ app = PluggedApp()
+ with pytest.raises(TypeError):
+ app.register_postloop_hook(app.prepost_hook_too_many_parameters)
+
+def test_register_postloop_hook_with_return_type():
+ app = PluggedApp()
+ with pytest.raises(TypeError):
+ app.register_postloop_hook(app.prepost_hook_with_return_type)
+
def test_postloop_hook(capsys):
app = PluggedApp()
app.register_postloop_hook(app.prepost_hook_one)