summaryrefslogtreecommitdiff
path: root/cmd2/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'cmd2/utils.py')
-rw-r--r--cmd2/utils.py85
1 files changed, 31 insertions, 54 deletions
diff --git a/cmd2/utils.py b/cmd2/utils.py
index 1008cb86..717d73b4 100644
--- a/cmd2/utils.py
+++ b/cmd2/utils.py
@@ -3,7 +3,6 @@
import argparse
import collections
-import collections.abc as collections_abc
import functools
import glob
import inspect
@@ -19,24 +18,27 @@ from enum import (
)
from typing import (
IO,
+ TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
- NamedTuple,
Optional,
TextIO,
Type,
- TYPE_CHECKING,
+ TypeVar,
Union,
+ cast,
)
from . import (
constants,
)
+
if TYPE_CHECKING: # pragma: no cover
- import cmd2
+ import cmd2 # noqa: F401
+
def is_quoted(arg: str) -> bool:
"""
@@ -100,15 +102,15 @@ class Settable:
def __init__(
self,
name: str,
- val_type: Callable,
+ val_type: Union[Type[Any], Callable[[Any], Any]],
description: str,
*,
settable_object: Optional[object] = None,
settable_attrib_name: Optional[str] = None,
- onchange_cb: Callable[[str, Any, Any], Any] = None,
- choices: Iterable = None,
- choices_provider: Optional[Callable] = None,
- completer: Optional[Callable] = None
+ onchange_cb: Optional[Callable[[str, Any, Any], Any]] = None,
+ choices: Optional[Iterable[Any]] = None,
+ choices_provider: Optional[Callable[[], List[str]]] = None,
+ completer: Optional[Callable[[str, str, int, int], List[str]]] = None
):
"""
Settable Initializer
@@ -174,36 +176,6 @@ class Settable:
return new_value
-def namedtuple_with_defaults(typename: str, field_names: Union[str, List[str]], default_values: collections_abc.Iterable = ()):
- """
- Convenience function for defining a namedtuple with default values
-
- From: https://stackoverflow.com/questions/11351032/namedtuple-and-default-values-for-optional-keyword-arguments
-
- Examples:
- >>> Node = namedtuple_with_defaults('Node', 'val left right')
- >>> Node()
- Node(val=None, left=None, right=None)
- >>> Node = namedtuple_with_defaults('Node', 'val left right', [1, 2, 3])
- >>> Node()
- Node(val=1, left=2, right=3)
- >>> Node = namedtuple_with_defaults('Node', 'val left right', {'right':7})
- >>> Node()
- Node(val=None, left=None, right=7)
- >>> Node(4)
- Node(val=4, left=None, right=7)
- """
- T: NamedTuple = collections.namedtuple(typename, field_names)
- # noinspection PyProtectedMember,PyUnresolvedReferences
- T.__new__.__defaults__ = (None,) * len(T._fields)
- if isinstance(default_values, collections_abc.Mapping):
- prototype = T(**default_values)
- else:
- prototype = T(*default_values)
- T.__new__.__defaults__ = tuple(prototype)
- return T
-
-
def is_text_file(file_path: str) -> bool:
"""Returns if a file contains only ASCII or UTF-8 encoded text.
@@ -241,13 +213,16 @@ def is_text_file(file_path: str) -> bool:
return valid_text_file
-def remove_duplicates(list_to_prune: List) -> List:
+_T = TypeVar('_T')
+
+
+def remove_duplicates(list_to_prune: List[_T]) -> List[_T]:
"""Removes duplicates from a list while preserving order of the items.
:param list_to_prune: the list being pruned of duplicates
:return: The pruned list
"""
- temp_dict = collections.OrderedDict()
+ temp_dict: collections.OrderedDict[_T, Any] = collections.OrderedDict()
for item in list_to_prune:
temp_dict[item] = None
@@ -405,7 +380,7 @@ def find_editor() -> Optional[str]:
return editor
-def files_from_glob_pattern(pattern: str, access=os.F_OK) -> List[str]:
+def files_from_glob_pattern(pattern: str, access: int = os.F_OK) -> List[str]:
"""Return a list of file paths based on a glob pattern.
Only files are returned, not directories, and optionally only files for which the user has a specified access to.
@@ -417,7 +392,7 @@ def files_from_glob_pattern(pattern: str, access=os.F_OK) -> List[str]:
return [f for f in glob.glob(pattern) if os.path.isfile(f) and os.access(f, access)]
-def files_from_glob_patterns(patterns: List[str], access=os.F_OK) -> List[str]:
+def files_from_glob_patterns(patterns: List[str], access: int = os.F_OK) -> List[str]:
"""Return a list of file paths based on a list of glob patterns.
Only files are returned, not directories, and optionally only files for which the user has a specified access to.
@@ -472,7 +447,7 @@ class StdSim:
Stores contents in internal buffer and optionally echos to the inner stream it is simulating.
"""
- def __init__(self, inner_stream, *, echo: bool = False, encoding: str = 'utf-8', errors: str = 'replace') -> None:
+ def __init__(self, inner_stream: TextIO, *, echo: bool = False, encoding: str = 'utf-8', errors: str = 'replace') -> None:
"""
StdSim Initializer
:param inner_stream: the wrapped stream. Should be a TextIO or StdSim instance.
@@ -540,11 +515,11 @@ class StdSim:
when running unit tests because pytest sets stdout to a pytest EncodedFile object.
"""
try:
- return self.inner_stream.line_buffering
+ return bool(self.inner_stream.line_buffering)
except AttributeError:
return False
- def __getattr__(self, item: str):
+ def __getattr__(self, item: str) -> Any:
if item in self.__dict__:
return self.__dict__[item]
else:
@@ -701,7 +676,7 @@ class ContextFlag:
def __enter__(self) -> None:
self.__count += 1
- def __exit__(self, *args) -> None:
+ def __exit__(self, *args: Any) -> None:
self.__count -= 1
if self.__count < 0:
raise ValueError("count has gone below 0")
@@ -1060,7 +1035,7 @@ def get_styles_in_text(text: str) -> Dict[int, str]:
return styles
-def categorize(func: Union[Callable, Iterable[Callable]], category: str) -> None:
+def categorize(func: Union[Callable[..., Any], Iterable[Callable[..., Any]]], category: str) -> None:
"""Categorize a function.
The help command output will group the passed function under the
@@ -1085,13 +1060,13 @@ def categorize(func: Union[Callable, Iterable[Callable]], category: str) -> None
for item in func:
setattr(item, constants.CMD_ATTR_HELP_CATEGORY, category)
else:
- if inspect.ismethod(func):
- setattr(func.__func__, constants.CMD_ATTR_HELP_CATEGORY, category)
+ if inspect.ismethod(func) and hasattr(func, '__func__'):
+ setattr(func.__func__, constants.CMD_ATTR_HELP_CATEGORY, category) # type: ignore[attr-defined]
else:
setattr(func, constants.CMD_ATTR_HELP_CATEGORY, category)
-def get_defining_class(meth: Callable) -> Optional[Type]:
+def get_defining_class(meth: Callable[..., Any]) -> Optional[Type[Any]]:
"""
Attempts to resolve the class that defined a method.
@@ -1104,9 +1079,11 @@ def get_defining_class(meth: Callable) -> Optional[Type]:
if isinstance(meth, functools.partial):
return get_defining_class(meth.func)
if inspect.ismethod(meth) or (
- inspect.isbuiltin(meth) and getattr(meth, '__self__') is not None and getattr(meth.__self__, '__class__')
+ inspect.isbuiltin(meth)
+ and getattr(meth, '__self__') is not None
+ and getattr(meth.__self__, '__class__') # type: ignore[attr-defined]
):
- for cls in inspect.getmro(meth.__self__.__class__):
+ for cls in inspect.getmro(meth.__self__.__class__): # type: ignore[attr-defined]
if meth.__name__ in cls.__dict__:
return cls
meth = getattr(meth, '__func__', meth) # fallback to __qualname__ parsing
@@ -1114,7 +1091,7 @@ def get_defining_class(meth: Callable) -> Optional[Type]:
cls = getattr(inspect.getmodule(meth), meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0])
if isinstance(cls, type):
return cls
- return getattr(meth, '__objclass__', None) # handle special descriptor objects
+ return cast(type, getattr(meth, '__objclass__', None)) # handle special descriptor objects
class CompletionMode(Enum):