summaryrefslogtreecommitdiff
path: root/sphinx/pycode/ast.py
diff options
context:
space:
mode:
Diffstat (limited to 'sphinx/pycode/ast.py')
-rw-r--r--sphinx/pycode/ast.py35
1 files changed, 22 insertions, 13 deletions
diff --git a/sphinx/pycode/ast.py b/sphinx/pycode/ast.py
index d131ff4c1..db9bfefb3 100644
--- a/sphinx/pycode/ast.py
+++ b/sphinx/pycode/ast.py
@@ -9,7 +9,7 @@
"""
import sys
-from typing import Dict, List, Optional, Type
+from typing import Dict, List, Optional, Type, overload
if sys.version_info > (3, 8):
import ast
@@ -58,6 +58,16 @@ def parse(code: str, mode: str = 'exec') -> "ast.AST":
return ast.parse(code, mode=mode)
+@overload
+def unparse(node: None, code: str = '') -> None:
+ ...
+
+
+@overload
+def unparse(node: ast.AST, code: str = '') -> str:
+ ...
+
+
def unparse(node: Optional[ast.AST], code: str = '') -> Optional[str]:
"""Unparse an AST to string."""
if node is None:
@@ -146,6 +156,17 @@ class _UnparseVisitor(ast.NodeVisitor):
["%s=%s" % (k.arg, self.visit(k.value)) for k in node.keywords])
return "%s(%s)" % (self.visit(node.func), ", ".join(args))
+ def visit_Constant(self, node: ast.Constant) -> str: # type: ignore
+ if node.value is Ellipsis:
+ return "..."
+ elif isinstance(node.value, (int, float, complex)):
+ if self.code and sys.version_info > (3, 8):
+ return ast.get_source_segment(self.code, node) # type: ignore
+ else:
+ return repr(node.value)
+ else:
+ return repr(node.value)
+
def visit_Dict(self, node: ast.Dict) -> str:
keys = (self.visit(k) for k in node.keys)
values = (self.visit(v) for v in node.values)
@@ -193,18 +214,6 @@ class _UnparseVisitor(ast.NodeVisitor):
else:
return "()"
- if sys.version_info >= (3, 6):
- def visit_Constant(self, node: ast.Constant) -> str:
- if node.value is Ellipsis:
- return "..."
- elif isinstance(node.value, (int, float, complex)):
- if self.code and sys.version_info > (3, 8):
- return ast.get_source_segment(self.code, node)
- else:
- return repr(node.value)
- else:
- return repr(node.value)
-
if sys.version_info < (3, 8):
# these ast nodes were deprecated in python 3.8
def visit_Bytes(self, node: ast.Bytes) -> str: