diff options
author | Takeshi KOMIYA <i.tkomiya@gmail.com> | 2019-05-30 01:19:21 +0900 |
---|---|---|
committer | Takeshi KOMIYA <i.tkomiya@gmail.com> | 2019-06-02 17:51:20 +0900 |
commit | afbf6d811dd4fc514d63e9dc3d6bee78d97b8a0c (patch) | |
tree | bc5307c546e47262914ef8c25ab64de87136e1a3 | |
parent | 27dd8367c65c4313d499d945e7a2804865a1754a (diff) | |
download | sphinx-git-afbf6d811dd4fc514d63e9dc3d6bee78d97b8a0c.tar.gz |
Migrate to py3 style type annotation: sphinx.util
-rw-r--r-- | sphinx/util/__init__.py | 169 | ||||
-rw-r--r-- | sphinx/util/typing.py | 3 | ||||
-rw-r--r-- | tests/test_autodoc.py | 7 |
3 files changed, 71 insertions, 108 deletions
diff --git a/sphinx/util/__init__.py b/sphinx/util/__init__.py index 2ebae8768..66c53b37e 100644 --- a/sphinx/util/__init__.py +++ b/sphinx/util/__init__.py @@ -24,6 +24,9 @@ from datetime import datetime from hashlib import md5 from os import path from time import mktime, strptime +from typing import ( + Any, Callable, Dict, IO, Iterable, Iterator, List, Pattern, Set, Tuple, Type +) from urllib.parse import urlsplit, urlunsplit, quote_plus, parse_qsl, urlencode from docutils.utils import relative_path @@ -34,6 +37,7 @@ from sphinx.locale import __ from sphinx.util import logging from sphinx.util.console import strip_colors, colorize, bold, term_width_line # type: ignore from sphinx.util.fileutil import copy_asset_file +from sphinx.util.typing import PathMatcher from sphinx.util import smartypants # noqa # import other utilities; partly for backwards compatibility, so don't @@ -46,10 +50,11 @@ from sphinx.util.nodes import ( # noqa caption_ref_re) from sphinx.util.matching import patfilter # noqa + if False: # For type annotation - from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Pattern, Set, Tuple, Type, Union # NOQA - + from sphinx.application import Sphinx + from sphinx.builders import Builder logger = logging.getLogger(__name__) @@ -60,21 +65,19 @@ url_re = re.compile(r'(?P<schema>.+)://.*') # type: Pattern # High-level utility functions. -def docname_join(basedocname, docname): - # type: (str, str) -> str +def docname_join(basedocname: str, docname: str) -> str: return posixpath.normpath( posixpath.join('/' + basedocname, '..', docname))[1:] -def path_stabilize(filepath): - # type: (str) -> str +def path_stabilize(filepath: str) -> str: "normalize path separater and unicode string" newpath = filepath.replace(os.path.sep, SEP) return unicodedata.normalize('NFC', newpath) -def get_matching_files(dirname, exclude_matchers=()): - # type: (str, Tuple[Callable[[str], bool], ...]) -> Iterable[str] +def get_matching_files(dirname: str, + exclude_matchers: Tuple[PathMatcher, ...] = ()) -> Iterable[str]: # NOQA """Get all file names in a directory, recursively. Exclude files and dirs matching some matcher in *exclude_matchers*. @@ -100,8 +103,8 @@ def get_matching_files(dirname, exclude_matchers=()): yield filename -def get_matching_docs(dirname, suffixes, exclude_matchers=()): - # type: (str, List[str], Tuple[Callable[[str], bool], ...]) -> Iterable[str] # NOQA +def get_matching_docs(dirname: str, suffixes: List[str], + exclude_matchers: Tuple[PathMatcher, ...] = ()) -> Iterable[str]: """Get all file names (without suffixes) matching a suffix in a directory, recursively. @@ -123,12 +126,10 @@ class FilenameUniqDict(dict): interpreted as filenames, and keeps track of a set of docnames they appear in. Used for images and downloadable files in the environment. """ - def __init__(self): - # type: () -> None + def __init__(self) -> None: self._existing = set() # type: Set[str] - def add_file(self, docname, newfile): - # type: (str, str) -> str + def add_file(self, docname: str, newfile: str) -> str: if newfile in self: self[newfile][0].add(docname) return self[newfile][1] @@ -142,26 +143,22 @@ class FilenameUniqDict(dict): self._existing.add(uniquename) return uniquename - def purge_doc(self, docname): - # type: (str) -> None + def purge_doc(self, docname: str) -> None: for filename, (docs, unique) in list(self.items()): docs.discard(docname) if not docs: del self[filename] self._existing.discard(unique) - def merge_other(self, docnames, other): - # type: (Set[str], Dict[str, Tuple[Set[str], Any]]) -> None + def merge_other(self, docnames: Set[str], other: Dict[str, Tuple[Set[str], Any]]) -> None: for filename, (docs, unique) in other.items(): for doc in docs & set(docnames): self.add_file(doc, filename) - def __getstate__(self): - # type: () -> Set[str] + def __getstate__(self) -> Set[str]: return self._existing - def __setstate__(self, state): - # type: (Set[str]) -> None + def __setstate__(self, state: Set[str]) -> None: self._existing = state @@ -172,8 +169,7 @@ class DownloadFiles(dict): Hence don't hack this directly. """ - def add_file(self, docname, filename): - # type: (str, str) -> None + def add_file(self, docname: str, filename: str) -> None: if filename not in self: digest = md5(filename.encode()).hexdigest() dest = '%s/%s' % (digest, os.path.basename(filename)) @@ -182,23 +178,20 @@ class DownloadFiles(dict): self[filename][0].add(docname) return self[filename][1] - def purge_doc(self, docname): - # type: (str) -> None + def purge_doc(self, docname: str) -> None: for filename, (docs, dest) in list(self.items()): docs.discard(docname) if not docs: del self[filename] - def merge_other(self, docnames, other): - # type: (Set[str], Dict[str, Tuple[Set[str], Any]]) -> None + def merge_other(self, docnames: Set[str], other: Dict[str, Tuple[Set[str], Any]]) -> None: for filename, (docs, dest) in other.items(): for docname in docs & set(docnames): self.add_file(docname, filename) -def copy_static_entry(source, targetdir, builder, context={}, - exclude_matchers=(), level=0): - # type: (str, str, Any, Dict, Tuple[Callable, ...], int) -> None +def copy_static_entry(source: str, targetdir: str, builder: "Builder", context: Dict = {}, + exclude_matchers: Tuple[PathMatcher, ...] = (), level: int = 0) -> None: """[DEPRECATED] Copy a HTML builder static_path entry from source to targetdir. Handles all possible cases of files, directories and subdirectories. @@ -237,8 +230,7 @@ _DEBUG_HEADER = '''\ ''' -def save_traceback(app): - # type: (Any) -> str +def save_traceback(app: "Sphinx") -> str: """Save the current exception's traceback in a temporary file.""" import sphinx import jinja2 @@ -273,8 +265,7 @@ def save_traceback(app): return path -def get_module_source(modname): - # type: (str) -> Tuple[str, str] +def get_module_source(modname: str) -> Tuple[str, str]: """Try to find the source code for a module. Can return ('file', 'filename') in which case the source is in the given @@ -321,8 +312,7 @@ def get_module_source(modname): return 'file', filename -def get_full_modname(modname, attribute): - # type: (str, str) -> str +def get_full_modname(modname: str, attribute: str) -> str: if modname is None: # Prevents a TypeError: if the last getattr() call will return None # then it's better to return it directly @@ -344,8 +334,7 @@ def get_full_modname(modname, attribute): _coding_re = re.compile(r'coding[:=]\s*([-\w.]+)') -def detect_encoding(readline): - # type: (Callable[[], bytes]) -> str +def detect_encoding(readline: Callable[[], bytes]) -> str: """Like tokenize.detect_encoding() from Py3k, but a bit simplified.""" def read_or_stop(): @@ -401,12 +390,10 @@ def detect_encoding(readline): class UnicodeDecodeErrorHandler: """Custom error handler for open() that warns and replaces.""" - def __init__(self, docname): - # type: (str) -> None + def __init__(self, docname: str) -> None: self.docname = docname - def __call__(self, error): - # type: (UnicodeDecodeError) -> Tuple[Union[str, str], int] + def __call__(self, error: UnicodeDecodeError) -> Tuple[str, int]: linestart = error.object.rfind(b'\n', 0, error.start) lineend = error.object.find(b'\n', error.start) if lineend == -1: @@ -426,26 +413,22 @@ class Tee: """ File-like object writing to two streams. """ - def __init__(self, stream1, stream2): - # type: (IO, IO) -> None + def __init__(self, stream1: IO, stream2: IO) -> None: self.stream1 = stream1 self.stream2 = stream2 - def write(self, text): - # type: (str) -> None + def write(self, text: str) -> None: self.stream1.write(text) self.stream2.write(text) - def flush(self): - # type: () -> None + def flush(self) -> None: if hasattr(self.stream1, 'flush'): self.stream1.flush() if hasattr(self.stream2, 'flush'): self.stream2.flush() -def parselinenos(spec, total): - # type: (str, int) -> List[int] +def parselinenos(spec: str, total: int) -> List[int]: """Parse a line number spec (such as "1,2,4-6") and return a list of wanted line numbers. """ @@ -472,8 +455,7 @@ def parselinenos(spec, total): return items -def force_decode(string, encoding): - # type: (str, str) -> str +def force_decode(string: str, encoding: str) -> str: """Forcibly get a unicode string out of a bytestring.""" warnings.warn('force_decode() is deprecated.', RemovedInSphinx40Warning, stacklevel=2) @@ -491,26 +473,22 @@ def force_decode(string, encoding): class attrdict(dict): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) warnings.warn('The attrdict class is deprecated.', RemovedInSphinx40Warning, stacklevel=2) - def __getattr__(self, key): - # type: (str) -> str + def __getattr__(self, key: str) -> str: return self[key] - def __setattr__(self, key, val): - # type: (str, str) -> None + def __setattr__(self, key: str, val: str) -> None: self[key] = val - def __delattr__(self, key): - # type: (str) -> None + def __delattr__(self, key: str) -> None: del self[key] -def rpartition(s, t): - # type: (str, str) -> Tuple[str, str] +def rpartition(s: str, t: str) -> Tuple[str, str]: """Similar to str.rpartition from 2.5, but doesn't return the separator.""" i = s.rfind(t) if i != -1: @@ -518,8 +496,7 @@ def rpartition(s, t): return '', s -def split_into(n, type, value): - # type: (int, str, str) -> List[str] +def split_into(n: int, type: str, value: str) -> List[str]: """Split an index entry into a given number of parts at semicolons.""" parts = [x.strip() for x in value.split(';', n - 1)] if sum(1 for part in parts if part) < n: @@ -527,8 +504,7 @@ def split_into(n, type, value): return parts -def split_index_msg(type, value): - # type: (str, str) -> List[str] +def split_index_msg(type: str, value: str) -> List[str]: # new entry types must be listed in directives/other.py! if type == 'single': try: @@ -549,8 +525,7 @@ def split_index_msg(type, value): return result -def format_exception_cut_frames(x=1): - # type: (int) -> str +def format_exception_cut_frames(x: int = 1) -> str: """Format an exception with traceback, but only the last x frames.""" typ, val, tb = sys.exc_info() # res = ['Traceback (most recent call last):\n'] @@ -566,19 +541,16 @@ class PeekableIterator: An iterator which wraps any iterable and makes it possible to peek to see what's the next item. """ - def __init__(self, iterable): - # type: (Iterable) -> None + def __init__(self, iterable: Iterable) -> None: self.remaining = deque() # type: deque self._iterator = iter(iterable) warnings.warn('PeekableIterator is deprecated.', RemovedInSphinx40Warning, stacklevel=2) - def __iter__(self): - # type: () -> PeekableIterator + def __iter__(self) -> "PeekableIterator": return self - def __next__(self): - # type: () -> Any + def __next__(self) -> Any: """Return the next item from the iterator.""" if self.remaining: return self.remaining.popleft() @@ -586,23 +558,20 @@ class PeekableIterator: next = __next__ # Python 2 compatibility - def push(self, item): - # type: (Any) -> None + def push(self, item: Any) -> None: """Push the `item` on the internal stack, it will be returned on the next :meth:`next` call. """ self.remaining.append(item) - def peek(self): - # type: () -> Any + def peek(self) -> Any: """Return the next item without changing the state of the iterator.""" item = next(self) self.push(item) return item -def import_object(objname, source=None): - # type: (str, str) -> Any +def import_object(objname: str, source: str = None) -> Any: """Import python object by qualname.""" try: objpath = objname.split('.') @@ -625,8 +594,7 @@ def import_object(objname, source=None): raise ExtensionError('Could not import %s' % objname, exc) -def encode_uri(uri): - # type: (str) -> str +def encode_uri(uri: str) -> str: split = list(urlsplit(uri)) split[1] = split[1].encode('idna').decode('ascii') split[2] = quote_plus(split[2].encode(), '/') @@ -635,8 +603,7 @@ def encode_uri(uri): return urlunsplit(split) -def display_chunk(chunk): - # type: (Any) -> str +def display_chunk(chunk: Any) -> str: if isinstance(chunk, (list, tuple)): if len(chunk) == 1: return str(chunk[0]) @@ -644,8 +611,8 @@ def display_chunk(chunk): return str(chunk) -def old_status_iterator(iterable, summary, color="darkgreen", stringify_func=display_chunk): - # type: (Iterable, str, str, Callable[[Any], str]) -> Iterator +def old_status_iterator(iterable: Iterable, summary: str, color: str = "darkgreen", + stringify_func: Callable[[Any], str] = display_chunk) -> Iterator: l = 0 for item in iterable: if l == 0: @@ -659,9 +626,9 @@ def old_status_iterator(iterable, summary, color="darkgreen", stringify_func=dis # new version with progress info -def status_iterator(iterable, summary, color="darkgreen", length=0, verbosity=0, - stringify_func=display_chunk): - # type: (Iterable, str, str, int, int, Callable[[Any], str]) -> Iterable +def status_iterator(iterable: Iterable, summary: str, color: str = "darkgreen", + length: int = 0, verbosity: int = 0, + stringify_func: Callable[[Any], str] = display_chunk) -> Iterable: if length == 0: yield from old_status_iterator(iterable, summary, color, stringify_func) return @@ -685,16 +652,13 @@ class SkipProgressMessage(Exception): class progress_message: - def __init__(self, message): - # type: (str) -> None + def __init__(self, message: str) -> None: self.message = message - def __enter__(self): - # type: () -> None + def __enter__(self) -> None: logger.info(bold(self.message + '... '), nonl=True) - def __exit__(self, exc_type, exc_value, traceback): - # type: (Type[Exception], Exception, Any) -> bool + def __exit__(self, exc_type: Type[Exception], exc_value: Exception, traceback: Any) -> bool: # NOQA if isinstance(exc_value, SkipProgressMessage): logger.info(__('skipped')) if exc_value.args: @@ -707,8 +671,7 @@ class progress_message: return False - def __call__(self, f): - # type: (Callable) -> Callable + def __call__(self, f: Callable) -> Callable: @functools.wraps(f) def wrapper(*args, **kwargs): with self: @@ -717,8 +680,7 @@ class progress_message: return wrapper -def epoch_to_rfc1123(epoch): - # type: (float) -> str +def epoch_to_rfc1123(epoch: float) -> str: """Convert datetime format epoch to RFC1123.""" from babel.dates import format_datetime @@ -727,13 +689,11 @@ def epoch_to_rfc1123(epoch): return format_datetime(dt, fmt, locale='en') + ' GMT' -def rfc1123_to_epoch(rfc1123): - # type: (str) -> float +def rfc1123_to_epoch(rfc1123: str) -> float: return mktime(strptime(rfc1123, '%a, %d %b %Y %H:%M:%S %Z')) -def xmlname_checker(): - # type: () -> Pattern +def xmlname_checker() -> Pattern: # https://www.w3.org/TR/REC-xml/#NT-Name name_start_chars = [ ':', ['A', 'Z'], '_', ['a', 'z'], ['\u00C0', '\u00D6'], @@ -747,8 +707,7 @@ def xmlname_checker(): ['\u203F', '\u2040'] ] - def convert(entries, splitter='|'): - # type: (Any, str) -> str + def convert(entries: Any, splitter: str = '|') -> str: results = [] for entry in entries: if isinstance(entry, list): diff --git a/sphinx/util/typing.py b/sphinx/util/typing.py index 78e8fe61f..77724d38b 100644 --- a/sphinx/util/typing.py +++ b/sphinx/util/typing.py @@ -23,6 +23,9 @@ TextlikeNode = Union[nodes.Text, nodes.TextElement] # type of None NoneType = type(None) +# path matcher +PathMatcher = Callable[[str], bool] + # common role functions RoleFunction = Callable[[str, str, str, int, Inliner, Dict, List[str]], Tuple[List[nodes.Node], List[nodes.system_message]]] diff --git a/tests/test_autodoc.py b/tests/test_autodoc.py index 518d23e8c..ad6a56d33 100644 --- a/tests/test_autodoc.py +++ b/tests/test_autodoc.py @@ -793,7 +793,7 @@ def test_autodoc_imported_members(app): "imported-members": None, "ignore-module-all": None} actual = do_autodoc(app, 'module', 'target', options) - assert '.. py:function:: save_traceback(app)' in actual + assert '.. py:function:: save_traceback(app: Sphinx) -> str' in actual @pytest.mark.sphinx('html', testroot='ext-autodoc') @@ -1795,7 +1795,7 @@ def test_autodoc_default_options(app): actual = do_autodoc(app, 'class', 'target.CustomIter') assert ' .. py:method:: target.CustomIter' not in actual actual = do_autodoc(app, 'module', 'target') - assert '.. py:function:: save_traceback(app)' not in actual + assert '.. py:function:: save_traceback(app: Sphinx) -> str' not in actual # with :members: app.config.autodoc_default_options = {'members': None} @@ -1866,7 +1866,8 @@ def test_autodoc_default_options(app): 'ignore-module-all': None, } actual = do_autodoc(app, 'module', 'target') - assert '.. py:function:: save_traceback(app)' in actual + print('\n'.join(actual)) + assert '.. py:function:: save_traceback(app: Sphinx) -> str' in actual @pytest.mark.sphinx('html', testroot='ext-autodoc') |