summaryrefslogtreecommitdiff
path: root/coverage/misc.py
diff options
context:
space:
mode:
Diffstat (limited to 'coverage/misc.py')
-rw-r--r--coverage/misc.py76
1 files changed, 46 insertions, 30 deletions
diff --git a/coverage/misc.py b/coverage/misc.py
index bd1767ca..a2ac2fed 100644
--- a/coverage/misc.py
+++ b/coverage/misc.py
@@ -3,7 +3,10 @@
"""Miscellaneous stuff for coverage.py."""
+from __future__ import annotations
+
import contextlib
+import datetime
import errno
import hashlib
import importlib
@@ -16,20 +19,25 @@ import re
import sys
import types
-from typing import Iterable
+from types import ModuleType
+from typing import (
+ Any, Callable, Dict, Generator, IO, Iterable, List, Mapping, Optional,
+ Tuple, TypeVar, Union,
+)
from coverage import env
from coverage.exceptions import CoverageException
+from coverage.types import TArc
# In 6.0, the exceptions moved from misc.py to exceptions.py. But a number of
# other packages were importing the exceptions from misc, so import them here.
# pylint: disable=unused-wildcard-import
from coverage.exceptions import * # pylint: disable=wildcard-import
-ISOLATED_MODULES = {}
+ISOLATED_MODULES: Dict[ModuleType, ModuleType] = {}
-def isolate_module(mod):
+def isolate_module(mod: ModuleType) -> ModuleType:
"""Copy a module so that we are isolated from aggressive mocking.
If a test suite mocks os.path.exists (for example), and then we need to use
@@ -52,10 +60,10 @@ os = isolate_module(os)
class SysModuleSaver:
"""Saves the contents of sys.modules, and removes new modules later."""
- def __init__(self):
+ def __init__(self) -> None:
self.old_modules = set(sys.modules)
- def restore(self):
+ def restore(self) -> None:
"""Remove any modules imported since this object started."""
new_modules = set(sys.modules) - self.old_modules
for m in new_modules:
@@ -63,7 +71,7 @@ class SysModuleSaver:
@contextlib.contextmanager
-def sys_modules_saved():
+def sys_modules_saved() -> Generator[None, None, None]:
"""A context manager to remove any modules imported during a block."""
saver = SysModuleSaver()
try:
@@ -72,7 +80,7 @@ def sys_modules_saved():
saver.restore()
-def import_third_party(modname):
+def import_third_party(modname: str) -> Tuple[ModuleType, bool]:
"""Import a third-party module we need, but might not be installed.
This also cleans out the module after the import, so that coverage won't
@@ -95,7 +103,7 @@ def import_third_party(modname):
return sys, False
-def nice_pair(pair):
+def nice_pair(pair: TArc) -> str:
"""Make a nice string representation of a pair of numbers.
If the numbers are equal, just return the number, otherwise return the pair
@@ -109,7 +117,10 @@ def nice_pair(pair):
return "%d-%d" % (start, end)
-def expensive(fn):
+TSelf = TypeVar("TSelf")
+TRetVal = TypeVar("TRetVal")
+
+def expensive(fn: Callable[[TSelf], TRetVal]) -> Callable[[TSelf], TRetVal]:
"""A decorator to indicate that a method shouldn't be called more than once.
Normally, this does nothing. During testing, this raises an exception if
@@ -119,7 +130,7 @@ def expensive(fn):
if env.TESTING:
attr = "_once_" + fn.__name__
- def _wrapper(self):
+ def _wrapper(self: TSelf) -> TRetVal:
if hasattr(self, attr):
raise AssertionError(f"Shouldn't have called {fn.__name__} more than once")
setattr(self, attr, True)
@@ -129,7 +140,7 @@ def expensive(fn):
return fn # pragma: not testing
-def bool_or_none(b):
+def bool_or_none(b: Any) -> Optional[bool]:
"""Return bool(b), but preserve None."""
if b is None:
return None
@@ -146,7 +157,7 @@ def join_regex(regexes: Iterable[str]) -> str:
return "|".join(f"(?:{r})" for r in regexes)
-def file_be_gone(path):
+def file_be_gone(path: str) -> None:
"""Remove a file, and don't get annoyed if it doesn't exist."""
try:
os.remove(path)
@@ -155,7 +166,7 @@ def file_be_gone(path):
raise
-def ensure_dir(directory):
+def ensure_dir(directory: str) -> None:
"""Make sure the directory exists.
If `directory` is None or empty, do nothing.
@@ -164,12 +175,12 @@ def ensure_dir(directory):
os.makedirs(directory, exist_ok=True)
-def ensure_dir_for_file(path):
+def ensure_dir_for_file(path: str) -> None:
"""Make sure the directory for the path exists."""
ensure_dir(os.path.dirname(path))
-def output_encoding(outfile=None):
+def output_encoding(outfile: Optional[IO[str]]=None) -> str:
"""Determine the encoding to use for output written to `outfile` or stdout."""
if outfile is None:
outfile = sys.stdout
@@ -183,10 +194,10 @@ def output_encoding(outfile=None):
class Hasher:
"""Hashes Python data for fingerprinting."""
- def __init__(self):
+ def __init__(self) -> None:
self.hash = hashlib.new("sha3_256")
- def update(self, v):
+ def update(self, v: Any) -> None:
"""Add `v` to the hash, recursively if needed."""
self.hash.update(str(type(v)).encode("utf-8"))
if isinstance(v, str):
@@ -216,12 +227,12 @@ class Hasher:
self.update(a)
self.hash.update(b'.')
- def hexdigest(self):
+ def hexdigest(self) -> str:
"""Retrieve the hex digest of the hash."""
return self.hash.hexdigest()[:32]
-def _needs_to_implement(that, func_name):
+def _needs_to_implement(that: Any, func_name: str) -> None:
"""Helper to raise NotImplementedError in interface stubs."""
if hasattr(that, "_coverage_plugin_name"):
thing = "Plugin"
@@ -243,14 +254,14 @@ class DefaultValue:
and Sphinx output.
"""
- def __init__(self, display_as):
+ def __init__(self, display_as: str) -> None:
self.display_as = display_as
- def __repr__(self):
+ def __repr__(self) -> str:
return self.display_as
-def substitute_variables(text, variables):
+def substitute_variables(text: str, variables: Mapping[str, str]) -> str:
"""Substitute ``${VAR}`` variables in `text` with their values.
Variables in the text can take a number of shell-inspired forms::
@@ -283,7 +294,7 @@ def substitute_variables(text, variables):
dollar_groups = ('dollar', 'word1', 'word2')
- def dollar_replace(match):
+ def dollar_replace(match: re.Match[str]) -> str:
"""Called for each $replacement."""
# Only one of the dollar_groups will have matched, just get its text.
word = next(g for g in match.group(*dollar_groups) if g) # pragma: always breaks
@@ -301,13 +312,13 @@ def substitute_variables(text, variables):
return text
-def format_local_datetime(dt):
+def format_local_datetime(dt: datetime.datetime) -> str:
"""Return a string with local timezone representing the date.
"""
return dt.astimezone().strftime('%Y-%m-%d %H:%M %z')
-def import_local_file(modname, modfile=None):
+def import_local_file(modname: str, modfile: Optional[str]=None) -> ModuleType:
"""Import a local file as a module.
Opens a file in the current directory named `modname`.py, imports it
@@ -318,18 +329,20 @@ def import_local_file(modname, modfile=None):
if modfile is None:
modfile = modname + '.py'
spec = importlib.util.spec_from_file_location(modname, modfile)
+ assert spec is not None
mod = importlib.util.module_from_spec(spec)
sys.modules[modname] = mod
+ assert spec.loader is not None
spec.loader.exec_module(mod)
return mod
-def _human_key(s):
+def _human_key(s: str) -> List[Union[str, int]]:
"""Turn a string into a list of string and number chunks.
"z23a" -> ["z", 23, "a"]
"""
- def tryint(s):
+ def tryint(s: str) -> Union[str, int]:
"""If `s` is a number, return an int, else `s` unchanged."""
try:
return int(s)
@@ -338,7 +351,7 @@ def _human_key(s):
return [tryint(c) for c in re.split(r"(\d+)", s)]
-def human_sorted(strings):
+def human_sorted(strings: Iterable[str]) -> List[str]:
"""Sort the given iterable of strings the way that humans expect.
Numeric components in the strings are sorted as numbers.
@@ -348,7 +361,10 @@ def human_sorted(strings):
"""
return sorted(strings, key=_human_key)
-def human_sorted_items(items, reverse=False):
+def human_sorted_items(
+ items: Iterable[Tuple[str, Any]],
+ reverse: bool=False,
+) -> List[Tuple[str, Any]]:
"""Sort (string, ...) items the way humans expect.
The elements of `items` can be any tuple/list. They'll be sorted by the
@@ -359,7 +375,7 @@ def human_sorted_items(items, reverse=False):
return sorted(items, key=lambda item: (_human_key(item[0]), *item[1:]), reverse=reverse)
-def plural(n, thing="", things=""):
+def plural(n: int, thing: str="", things: str="") -> str:
"""Pluralize a word.
If n is 1, return thing. Otherwise return things, or thing+s.