summaryrefslogtreecommitdiff
path: root/sphinx/pycode/ast.py
diff options
context:
space:
mode:
Diffstat (limited to 'sphinx/pycode/ast.py')
-rw-r--r--sphinx/pycode/ast.py45
1 files changed, 27 insertions, 18 deletions
diff --git a/sphinx/pycode/ast.py b/sphinx/pycode/ast.py
index 65534f958..f541ec0a9 100644
--- a/sphinx/pycode/ast.py
+++ b/sphinx/pycode/ast.py
@@ -9,7 +9,7 @@
"""
import sys
-from typing import Dict, List, Optional, Type
+from typing import Dict, List, Optional, Type, overload
if sys.version_info > (3, 8):
import ast
@@ -21,7 +21,7 @@ else:
import ast # type: ignore
-OPERATORS = {
+OPERATORS: Dict[Type[ast.AST], str] = {
ast.Add: "+",
ast.And: "and",
ast.BitAnd: "&",
@@ -41,7 +41,7 @@ OPERATORS = {
ast.Sub: "-",
ast.UAdd: "+",
ast.USub: "-",
-} # type: Dict[Type[ast.AST], str]
+}
def parse(code: str, mode: str = 'exec') -> "ast.AST":
@@ -62,6 +62,16 @@ def parse(code: str, mode: str = 'exec') -> "ast.AST":
return ast.parse(code, mode=mode)
+@overload
+def unparse(node: None, code: str = '') -> None:
+ ...
+
+
+@overload
+def unparse(node: ast.AST, code: str = '') -> str:
+ ...
+
+
def unparse(node: Optional[ast.AST], code: str = '') -> Optional[str]:
"""Unparse an AST to string."""
if node is None:
@@ -98,7 +108,7 @@ class _UnparseVisitor(ast.NodeVisitor):
return name
def visit_arguments(self, node: ast.arguments) -> str:
- defaults = list(node.defaults)
+ defaults: List[Optional[ast.expr]] = list(node.defaults)
positionals = len(node.args)
posonlyargs = 0
if hasattr(node, "posonlyargs"): # for py38+
@@ -107,11 +117,11 @@ class _UnparseVisitor(ast.NodeVisitor):
for _ in range(len(defaults), positionals):
defaults.insert(0, None)
- kw_defaults = list(node.kw_defaults)
+ kw_defaults: List[Optional[ast.expr]] = list(node.kw_defaults)
for _ in range(len(kw_defaults), len(node.kwonlyargs)):
kw_defaults.insert(0, None)
- args = [] # type: List[str]
+ args: List[str] = []
if hasattr(node, "posonlyargs"): # for py38+
for i, arg in enumerate(node.posonlyargs): # type: ignore
args.append(self._visit_arg_with_default(arg, defaults[i]))
@@ -150,6 +160,17 @@ class _UnparseVisitor(ast.NodeVisitor):
["%s=%s" % (k.arg, self.visit(k.value)) for k in node.keywords])
return "%s(%s)" % (self.visit(node.func), ", ".join(args))
+ def visit_Constant(self, node: ast.Constant) -> str: # type: ignore
+ if node.value is Ellipsis:
+ return "..."
+ elif isinstance(node.value, (int, float, complex)):
+ if self.code and sys.version_info > (3, 8):
+ return ast.get_source_segment(self.code, node) # type: ignore
+ else:
+ return repr(node.value)
+ else:
+ return repr(node.value)
+
def visit_Dict(self, node: ast.Dict) -> str:
keys = (self.visit(k) for k in node.keys)
values = (self.visit(v) for v in node.values)
@@ -197,18 +218,6 @@ class _UnparseVisitor(ast.NodeVisitor):
else:
return "()"
- if sys.version_info >= (3, 6):
- def visit_Constant(self, node: ast.Constant) -> str:
- if node.value is Ellipsis:
- return "..."
- elif isinstance(node.value, (int, float, complex)):
- if self.code and sys.version_info > (3, 8):
- return ast.get_source_segment(self.code, node)
- else:
- return repr(node.value)
- else:
- return repr(node.value)
-
if sys.version_info < (3, 8):
# these ast nodes were deprecated in python 3.8
def visit_Bytes(self, node: ast.Bytes) -> str: