diff options
Diffstat (limited to 'cmd2/utils.py')
-rw-r--r-- | cmd2/utils.py | 85 |
1 files changed, 31 insertions, 54 deletions
diff --git a/cmd2/utils.py b/cmd2/utils.py index 1008cb86..717d73b4 100644 --- a/cmd2/utils.py +++ b/cmd2/utils.py @@ -3,7 +3,6 @@ import argparse import collections -import collections.abc as collections_abc import functools import glob import inspect @@ -19,24 +18,27 @@ from enum import ( ) from typing import ( IO, + TYPE_CHECKING, Any, Callable, Dict, Iterable, List, - NamedTuple, Optional, TextIO, Type, - TYPE_CHECKING, + TypeVar, Union, + cast, ) from . import ( constants, ) + if TYPE_CHECKING: # pragma: no cover - import cmd2 + import cmd2 # noqa: F401 + def is_quoted(arg: str) -> bool: """ @@ -100,15 +102,15 @@ class Settable: def __init__( self, name: str, - val_type: Callable, + val_type: Union[Type[Any], Callable[[Any], Any]], description: str, *, settable_object: Optional[object] = None, settable_attrib_name: Optional[str] = None, - onchange_cb: Callable[[str, Any, Any], Any] = None, - choices: Iterable = None, - choices_provider: Optional[Callable] = None, - completer: Optional[Callable] = None + onchange_cb: Optional[Callable[[str, Any, Any], Any]] = None, + choices: Optional[Iterable[Any]] = None, + choices_provider: Optional[Callable[[], List[str]]] = None, + completer: Optional[Callable[[str, str, int, int], List[str]]] = None ): """ Settable Initializer @@ -174,36 +176,6 @@ class Settable: return new_value -def namedtuple_with_defaults(typename: str, field_names: Union[str, List[str]], default_values: collections_abc.Iterable = ()): - """ - Convenience function for defining a namedtuple with default values - - From: https://stackoverflow.com/questions/11351032/namedtuple-and-default-values-for-optional-keyword-arguments - - Examples: - >>> Node = namedtuple_with_defaults('Node', 'val left right') - >>> Node() - Node(val=None, left=None, right=None) - >>> Node = namedtuple_with_defaults('Node', 'val left right', [1, 2, 3]) - >>> Node() - Node(val=1, left=2, right=3) - >>> Node = namedtuple_with_defaults('Node', 'val left right', {'right':7}) - >>> Node() - Node(val=None, left=None, right=7) - >>> Node(4) - Node(val=4, left=None, right=7) - """ - T: NamedTuple = collections.namedtuple(typename, field_names) - # noinspection PyProtectedMember,PyUnresolvedReferences - T.__new__.__defaults__ = (None,) * len(T._fields) - if isinstance(default_values, collections_abc.Mapping): - prototype = T(**default_values) - else: - prototype = T(*default_values) - T.__new__.__defaults__ = tuple(prototype) - return T - - def is_text_file(file_path: str) -> bool: """Returns if a file contains only ASCII or UTF-8 encoded text. @@ -241,13 +213,16 @@ def is_text_file(file_path: str) -> bool: return valid_text_file -def remove_duplicates(list_to_prune: List) -> List: +_T = TypeVar('_T') + + +def remove_duplicates(list_to_prune: List[_T]) -> List[_T]: """Removes duplicates from a list while preserving order of the items. :param list_to_prune: the list being pruned of duplicates :return: The pruned list """ - temp_dict = collections.OrderedDict() + temp_dict: collections.OrderedDict[_T, Any] = collections.OrderedDict() for item in list_to_prune: temp_dict[item] = None @@ -405,7 +380,7 @@ def find_editor() -> Optional[str]: return editor -def files_from_glob_pattern(pattern: str, access=os.F_OK) -> List[str]: +def files_from_glob_pattern(pattern: str, access: int = os.F_OK) -> List[str]: """Return a list of file paths based on a glob pattern. Only files are returned, not directories, and optionally only files for which the user has a specified access to. @@ -417,7 +392,7 @@ def files_from_glob_pattern(pattern: str, access=os.F_OK) -> List[str]: return [f for f in glob.glob(pattern) if os.path.isfile(f) and os.access(f, access)] -def files_from_glob_patterns(patterns: List[str], access=os.F_OK) -> List[str]: +def files_from_glob_patterns(patterns: List[str], access: int = os.F_OK) -> List[str]: """Return a list of file paths based on a list of glob patterns. Only files are returned, not directories, and optionally only files for which the user has a specified access to. @@ -472,7 +447,7 @@ class StdSim: Stores contents in internal buffer and optionally echos to the inner stream it is simulating. """ - def __init__(self, inner_stream, *, echo: bool = False, encoding: str = 'utf-8', errors: str = 'replace') -> None: + def __init__(self, inner_stream: TextIO, *, echo: bool = False, encoding: str = 'utf-8', errors: str = 'replace') -> None: """ StdSim Initializer :param inner_stream: the wrapped stream. Should be a TextIO or StdSim instance. @@ -540,11 +515,11 @@ class StdSim: when running unit tests because pytest sets stdout to a pytest EncodedFile object. """ try: - return self.inner_stream.line_buffering + return bool(self.inner_stream.line_buffering) except AttributeError: return False - def __getattr__(self, item: str): + def __getattr__(self, item: str) -> Any: if item in self.__dict__: return self.__dict__[item] else: @@ -701,7 +676,7 @@ class ContextFlag: def __enter__(self) -> None: self.__count += 1 - def __exit__(self, *args) -> None: + def __exit__(self, *args: Any) -> None: self.__count -= 1 if self.__count < 0: raise ValueError("count has gone below 0") @@ -1060,7 +1035,7 @@ def get_styles_in_text(text: str) -> Dict[int, str]: return styles -def categorize(func: Union[Callable, Iterable[Callable]], category: str) -> None: +def categorize(func: Union[Callable[..., Any], Iterable[Callable[..., Any]]], category: str) -> None: """Categorize a function. The help command output will group the passed function under the @@ -1085,13 +1060,13 @@ def categorize(func: Union[Callable, Iterable[Callable]], category: str) -> None for item in func: setattr(item, constants.CMD_ATTR_HELP_CATEGORY, category) else: - if inspect.ismethod(func): - setattr(func.__func__, constants.CMD_ATTR_HELP_CATEGORY, category) + if inspect.ismethod(func) and hasattr(func, '__func__'): + setattr(func.__func__, constants.CMD_ATTR_HELP_CATEGORY, category) # type: ignore[attr-defined] else: setattr(func, constants.CMD_ATTR_HELP_CATEGORY, category) -def get_defining_class(meth: Callable) -> Optional[Type]: +def get_defining_class(meth: Callable[..., Any]) -> Optional[Type[Any]]: """ Attempts to resolve the class that defined a method. @@ -1104,9 +1079,11 @@ def get_defining_class(meth: Callable) -> Optional[Type]: if isinstance(meth, functools.partial): return get_defining_class(meth.func) if inspect.ismethod(meth) or ( - inspect.isbuiltin(meth) and getattr(meth, '__self__') is not None and getattr(meth.__self__, '__class__') + inspect.isbuiltin(meth) + and getattr(meth, '__self__') is not None + and getattr(meth.__self__, '__class__') # type: ignore[attr-defined] ): - for cls in inspect.getmro(meth.__self__.__class__): + for cls in inspect.getmro(meth.__self__.__class__): # type: ignore[attr-defined] if meth.__name__ in cls.__dict__: return cls meth = getattr(meth, '__func__', meth) # fallback to __qualname__ parsing @@ -1114,7 +1091,7 @@ def get_defining_class(meth: Callable) -> Optional[Type]: cls = getattr(inspect.getmodule(meth), meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0]) if isinstance(cls, type): return cls - return getattr(meth, '__objclass__', None) # handle special descriptor objects + return cast(type, getattr(meth, '__objclass__', None)) # handle special descriptor objects class CompletionMode(Enum): |