diff options
Diffstat (limited to 'sphinx/pycode/ast.py')
-rw-r--r-- | sphinx/pycode/ast.py | 45 |
1 files changed, 27 insertions, 18 deletions
diff --git a/sphinx/pycode/ast.py b/sphinx/pycode/ast.py index 65534f958..f541ec0a9 100644 --- a/sphinx/pycode/ast.py +++ b/sphinx/pycode/ast.py @@ -9,7 +9,7 @@ """ import sys -from typing import Dict, List, Optional, Type +from typing import Dict, List, Optional, Type, overload if sys.version_info > (3, 8): import ast @@ -21,7 +21,7 @@ else: import ast # type: ignore -OPERATORS = { +OPERATORS: Dict[Type[ast.AST], str] = { ast.Add: "+", ast.And: "and", ast.BitAnd: "&", @@ -41,7 +41,7 @@ OPERATORS = { ast.Sub: "-", ast.UAdd: "+", ast.USub: "-", -} # type: Dict[Type[ast.AST], str] +} def parse(code: str, mode: str = 'exec') -> "ast.AST": @@ -62,6 +62,16 @@ def parse(code: str, mode: str = 'exec') -> "ast.AST": return ast.parse(code, mode=mode) +@overload +def unparse(node: None, code: str = '') -> None: + ... + + +@overload +def unparse(node: ast.AST, code: str = '') -> str: + ... + + def unparse(node: Optional[ast.AST], code: str = '') -> Optional[str]: """Unparse an AST to string.""" if node is None: @@ -98,7 +108,7 @@ class _UnparseVisitor(ast.NodeVisitor): return name def visit_arguments(self, node: ast.arguments) -> str: - defaults = list(node.defaults) + defaults: List[Optional[ast.expr]] = list(node.defaults) positionals = len(node.args) posonlyargs = 0 if hasattr(node, "posonlyargs"): # for py38+ @@ -107,11 +117,11 @@ class _UnparseVisitor(ast.NodeVisitor): for _ in range(len(defaults), positionals): defaults.insert(0, None) - kw_defaults = list(node.kw_defaults) + kw_defaults: List[Optional[ast.expr]] = list(node.kw_defaults) for _ in range(len(kw_defaults), len(node.kwonlyargs)): kw_defaults.insert(0, None) - args = [] # type: List[str] + args: List[str] = [] if hasattr(node, "posonlyargs"): # for py38+ for i, arg in enumerate(node.posonlyargs): # type: ignore args.append(self._visit_arg_with_default(arg, defaults[i])) @@ -150,6 +160,17 @@ class _UnparseVisitor(ast.NodeVisitor): ["%s=%s" % (k.arg, self.visit(k.value)) for k in node.keywords]) return "%s(%s)" % (self.visit(node.func), ", ".join(args)) + def visit_Constant(self, node: ast.Constant) -> str: # type: ignore + if node.value is Ellipsis: + return "..." + elif isinstance(node.value, (int, float, complex)): + if self.code and sys.version_info > (3, 8): + return ast.get_source_segment(self.code, node) # type: ignore + else: + return repr(node.value) + else: + return repr(node.value) + def visit_Dict(self, node: ast.Dict) -> str: keys = (self.visit(k) for k in node.keys) values = (self.visit(v) for v in node.values) @@ -197,18 +218,6 @@ class _UnparseVisitor(ast.NodeVisitor): else: return "()" - if sys.version_info >= (3, 6): - def visit_Constant(self, node: ast.Constant) -> str: - if node.value is Ellipsis: - return "..." - elif isinstance(node.value, (int, float, complex)): - if self.code and sys.version_info > (3, 8): - return ast.get_source_segment(self.code, node) - else: - return repr(node.value) - else: - return repr(node.value) - if sys.version_info < (3, 8): # these ast nodes were deprecated in python 3.8 def visit_Bytes(self, node: ast.Bytes) -> str: |