Diffstat (limited to 'sphinx/util/parallel.py')
 sphinx/util/parallel.py | 32 ++++++++++----------------------
 1 file changed, 10 insertions(+), 22 deletions(-)
diff --git a/sphinx/util/parallel.py b/sphinx/util/parallel.py
index 013dc3071..2d519a8d3 100644
--- a/sphinx/util/parallel.py
+++ b/sphinx/util/parallel.py
@@ -12,6 +12,7 @@ import os
 import time
 import traceback
 from math import sqrt
+from typing import Any, Callable, Dict, List, Sequence
 
 try:
     import multiprocessing
@@ -21,10 +22,6 @@ except ImportError:
 from sphinx.errors import SphinxParallelError
 from sphinx.util import logging
 
-if False:
-    # For type annotation
-    from typing import Any, Callable, Dict, List, Sequence  # NOQA
-
 logger = logging.getLogger(__name__)
 
 
@@ -35,12 +32,10 @@ parallel_available = multiprocessing and (os.name == 'posix')
 class SerialTasks:
     """Has the same interface as ParallelTasks, but executes tasks directly."""
 
-    def __init__(self, nproc=1):
-        # type: (int) -> None
+    def __init__(self, nproc: int = 1) -> None:
         pass
 
-    def add_task(self, task_func, arg=None, result_func=None):
-        # type: (Callable, Any, Callable) -> None
+    def add_task(self, task_func: Callable, arg: Any = None, result_func: Callable = None) -> None:  # NOQA
         if arg is not None:
             res = task_func(arg)
         else:
@@ -48,16 +43,14 @@ class SerialTasks:
         if result_func:
             result_func(res)
 
-    def join(self):
-        # type: () -> None
+    def join(self) -> None:
         pass
 
 
 class ParallelTasks:
     """Executes *nproc* tasks in parallel after forking."""
 
-    def __init__(self, nproc):
-        # type: (int) -> None
+    def __init__(self, nproc: int) -> None:
         self.nproc = nproc
         # (optional) function performed by each task on the result of main task
         self._result_funcs = {}  # type: Dict[int, Callable]
@@ -74,8 +67,7 @@ class ParallelTasks:
         # task number of each subprocess
         self._taskid = 0
 
-    def _process(self, pipe, func, arg):
-        # type: (Any, Callable, Any) -> None
+    def _process(self, pipe: Any, func: Callable, arg: Any) -> None:
         try:
             collector = logging.LogCollector()
             with collector.collect():
@@ -91,8 +83,7 @@ class ParallelTasks:
         logging.convert_serializable(collector.logs)
         pipe.send((failed, collector.logs, ret))
 
-    def add_task(self, task_func, arg=None, result_func=None):
-        # type: (Callable, Any, Callable) -> None
+    def add_task(self, task_func: Callable, arg: Any = None, result_func: Callable = None) -> None:  # NOQA
         tid = self._taskid
         self._taskid += 1
         self._result_funcs[tid] = result_func or (lambda arg, result: None)
@@ -104,13 +95,11 @@ class ParallelTasks:
         self._precvsWaiting[tid] = precv
         self._join_one()
 
-    def join(self):
-        # type: () -> None
+    def join(self) -> None:
         while self._pworking:
             self._join_one()
 
-    def _join_one(self):
-        # type: () -> None
+    def _join_one(self) -> None:
         for tid, pipe in self._precvs.items():
             if pipe.poll():
                 exc, logs, result = pipe.recv()
@@ -132,8 +121,7 @@
             self._pworking += 1
 
 
-def make_chunks(arguments, nproc, maxbatch=10):
-    # type: (Sequence[str], int, int) -> List[Any]
+def make_chunks(arguments: Sequence[str], nproc: int, maxbatch: int = 10) -> List[Any]:
     # determine how many documents to read in one go
     nargs = len(arguments)
     chunksize = nargs // nproc
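
For context, a rough usage sketch of the API touched by this diff (not part of the change itself): it drives ParallelTasks with make_chunks the way callers typically do. The process_chunk and collect helpers, the document list, and the nproc value are hypothetical; only the signatures visible above are relied on, and execution is guarded by parallel_available since forking only works on POSIX platforms.

    # usage sketch; process_chunk and collect are hypothetical helpers,
    # inferred only from the signatures shown in the diff above
    from sphinx.util.parallel import ParallelTasks, make_chunks, parallel_available


    def process_chunk(docnames):
        # runs in a forked subprocess; the return value is piped back to the parent
        return [name.upper() for name in docnames]


    def collect(docnames, result):
        # runs in the parent process once a task finishes: receives the original
        # argument and the worker's return value
        print(docnames, '->', result)


    if parallel_available:  # forking requires a POSIX platform
        documents = ['index', 'usage', 'api', 'changes']
        nproc = 2
        tasks = ParallelTasks(nproc)
        for chunk in make_chunks(documents, nproc):
            tasks.add_task(process_chunk, chunk, collect)
        tasks.join()  # block until every forked task has completed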