diff options
Diffstat (limited to 'sphinx/util/parallel.py')
-rw-r--r-- | sphinx/util/parallel.py | 32 |
1 files changed, 10 insertions, 22 deletions
diff --git a/sphinx/util/parallel.py b/sphinx/util/parallel.py index 013dc3071..2d519a8d3 100644 --- a/sphinx/util/parallel.py +++ b/sphinx/util/parallel.py @@ -12,6 +12,7 @@ import os import time import traceback from math import sqrt +from typing import Any, Callable, Dict, List, Sequence try: import multiprocessing @@ -21,10 +22,6 @@ except ImportError: from sphinx.errors import SphinxParallelError from sphinx.util import logging -if False: - # For type annotation - from typing import Any, Callable, Dict, List, Sequence # NOQA - logger = logging.getLogger(__name__) @@ -35,12 +32,10 @@ parallel_available = multiprocessing and (os.name == 'posix') class SerialTasks: """Has the same interface as ParallelTasks, but executes tasks directly.""" - def __init__(self, nproc=1): - # type: (int) -> None + def __init__(self, nproc: int = 1) -> None: pass - def add_task(self, task_func, arg=None, result_func=None): - # type: (Callable, Any, Callable) -> None + def add_task(self, task_func: Callable, arg: Any = None, result_func: Callable = None) -> None: # NOQA if arg is not None: res = task_func(arg) else: @@ -48,16 +43,14 @@ class SerialTasks: if result_func: result_func(res) - def join(self): - # type: () -> None + def join(self) -> None: pass class ParallelTasks: """Executes *nproc* tasks in parallel after forking.""" - def __init__(self, nproc): - # type: (int) -> None + def __init__(self, nproc: int) -> None: self.nproc = nproc # (optional) function performed by each task on the result of main task self._result_funcs = {} # type: Dict[int, Callable] @@ -74,8 +67,7 @@ class ParallelTasks: # task number of each subprocess self._taskid = 0 - def _process(self, pipe, func, arg): - # type: (Any, Callable, Any) -> None + def _process(self, pipe: Any, func: Callable, arg: Any) -> None: try: collector = logging.LogCollector() with collector.collect(): @@ -91,8 +83,7 @@ class ParallelTasks: logging.convert_serializable(collector.logs) pipe.send((failed, collector.logs, ret)) - def add_task(self, task_func, arg=None, result_func=None): - # type: (Callable, Any, Callable) -> None + def add_task(self, task_func: Callable, arg: Any = None, result_func: Callable = None) -> None: # NOQA tid = self._taskid self._taskid += 1 self._result_funcs[tid] = result_func or (lambda arg, result: None) @@ -104,13 +95,11 @@ class ParallelTasks: self._precvsWaiting[tid] = precv self._join_one() - def join(self): - # type: () -> None + def join(self) -> None: while self._pworking: self._join_one() - def _join_one(self): - # type: () -> None + def _join_one(self) -> None: for tid, pipe in self._precvs.items(): if pipe.poll(): exc, logs, result = pipe.recv() @@ -132,8 +121,7 @@ class ParallelTasks: self._pworking += 1 -def make_chunks(arguments, nproc, maxbatch=10): - # type: (Sequence[str], int, int) -> List[Any] +def make_chunks(arguments: Sequence[str], nproc: int, maxbatch: int = 10) -> List[Any]: # determine how many documents to read in one go nargs = len(arguments) chunksize = nargs // nproc |