summaryrefslogtreecommitdiff
path: root/sphinx/pycode/ast.py
diff options
context:
space:
mode:
Diffstat (limited to 'sphinx/pycode/ast.py')
-rw-r--r--sphinx/pycode/ast.py213
1 files changed, 121 insertions, 92 deletions
diff --git a/sphinx/pycode/ast.py b/sphinx/pycode/ast.py
index 3f4717f45..0b403b61e 100644
--- a/sphinx/pycode/ast.py
+++ b/sphinx/pycode/ast.py
@@ -75,113 +75,142 @@ def unparse(node: Optional[ast.AST]) -> Optional[str]:
return None
elif isinstance(node, str):
return node
- elif node.__class__ in OPERATORS:
+ return _UnparseVisitor().visit(node)
+
+
+# a greatly cut-down version of `ast._Unparser`
+class _UnparseVisitor(ast.NodeVisitor):
+
+ def _visit_op(self, node: ast.AST) -> str:
return OPERATORS[node.__class__]
- elif isinstance(node, ast.arg):
+ for _op in OPERATORS:
+ locals()['visit_{}'.format(_op.__name__)] = _visit_op
+
+ def visit_arg(self, node: ast.arg) -> str:
if node.annotation:
- return "%s: %s" % (node.arg, unparse(node.annotation))
+ return "%s: %s" % (node.arg, self.visit(node.annotation))
else:
return node.arg
- elif isinstance(node, ast.arguments):
- return unparse_arguments(node)
- elif isinstance(node, ast.Attribute):
- return "%s.%s" % (unparse(node.value), node.attr)
- elif isinstance(node, ast.BinOp):
- return " ".join(unparse(e) for e in [node.left, node.op, node.right])
- elif isinstance(node, ast.BoolOp):
- op = " %s " % unparse(node.op)
- return op.join(unparse(e) for e in node.values)
- elif isinstance(node, ast.Bytes):
- return repr(node.s)
- elif isinstance(node, ast.Call):
- args = ([unparse(e) for e in node.args] +
- ["%s=%s" % (k.arg, unparse(k.value)) for k in node.keywords])
- return "%s(%s)" % (unparse(node.func), ", ".join(args))
- elif isinstance(node, ast.Dict):
- keys = (unparse(k) for k in node.keys)
- values = (unparse(v) for v in node.values)
+
+ def _visit_arg_with_default(self, arg: ast.arg, default: Optional[ast.AST]) -> str:
+ """Unparse a single argument to a string."""
+ name = self.visit(arg)
+ if default:
+ if arg.annotation:
+ name += " = %s" % self.visit(default)
+ else:
+ name += "=%s" % self.visit(default)
+ return name
+
+ def visit_arguments(self, node: ast.arguments) -> str:
+ defaults = list(node.defaults)
+ positionals = len(node.args)
+ posonlyargs = 0
+ if hasattr(node, "posonlyargs"): # for py38+
+ posonlyargs += len(node.posonlyargs) # type:ignore
+ positionals += posonlyargs
+ for _ in range(len(defaults), positionals):
+ defaults.insert(0, None)
+
+ kw_defaults = list(node.kw_defaults)
+ for _ in range(len(kw_defaults), len(node.kwonlyargs)):
+ kw_defaults.insert(0, None)
+
+ args = [] # type: List[str]
+ if hasattr(node, "posonlyargs"): # for py38+
+ for i, arg in enumerate(node.posonlyargs): # type: ignore
+ args.append(self._visit_arg_with_default(arg, defaults[i]))
+
+ if node.posonlyargs: # type: ignore
+ args.append('/')
+
+ for i, arg in enumerate(node.args):
+ args.append(self._visit_arg_with_default(arg, defaults[i + posonlyargs]))
+
+ if node.vararg:
+ args.append("*" + self.visit(node.vararg))
+
+ if node.kwonlyargs and not node.vararg:
+ args.append('*')
+ for i, arg in enumerate(node.kwonlyargs):
+ args.append(self._visit_arg_with_default(arg, kw_defaults[i]))
+
+ if node.kwarg:
+ args.append("**" + self.visit(node.kwarg))
+
+ return ", ".join(args)
+
+ def visit_Attribute(self, node: ast.Attribute) -> str:
+ return "%s.%s" % (self.visit(node.value), node.attr)
+
+ def visit_BinOp(self, node: ast.BinOp) -> str:
+ return " ".join(self.visit(e) for e in [node.left, node.op, node.right])
+
+ def visit_BoolOp(self, node: ast.BoolOp) -> str:
+ op = " %s " % self.visit(node.op)
+ return op.join(self.visit(e) for e in node.values)
+
+ def visit_Call(self, node: ast.Call) -> str:
+ args = ([self.visit(e) for e in node.args] +
+ ["%s=%s" % (k.arg, self.visit(k.value)) for k in node.keywords])
+ return "%s(%s)" % (self.visit(node.func), ", ".join(args))
+
+ def visit_Dict(self, node: ast.Dict) -> str:
+ keys = (self.visit(k) for k in node.keys)
+ values = (self.visit(v) for v in node.values)
items = (k + ": " + v for k, v in zip(keys, values))
return "{" + ", ".join(items) + "}"
- elif isinstance(node, ast.Ellipsis):
- return "..."
- elif isinstance(node, ast.Index):
- return unparse(node.value)
- elif isinstance(node, ast.Lambda):
- return "lambda %s: ..." % unparse(node.args)
- elif isinstance(node, ast.List):
- return "[" + ", ".join(unparse(e) for e in node.elts) + "]"
- elif isinstance(node, ast.Name):
- return node.id
- elif isinstance(node, ast.NameConstant):
- return repr(node.value)
- elif isinstance(node, ast.Num):
- return repr(node.n)
- elif isinstance(node, ast.Set):
- return "{" + ", ".join(unparse(e) for e in node.elts) + "}"
- elif isinstance(node, ast.Str):
- return repr(node.s)
- elif isinstance(node, ast.Subscript):
- return "%s[%s]" % (unparse(node.value), unparse(node.slice))
- elif isinstance(node, ast.UnaryOp):
- return "%s %s" % (unparse(node.op), unparse(node.operand))
- elif isinstance(node, ast.Tuple):
- if node.elts:
- return ", ".join(unparse(e) for e in node.elts)
- else:
- return "()"
- elif sys.version_info > (3, 6) and isinstance(node, ast.Constant):
- # this branch should be placed at last
- return repr(node.value)
- else:
- raise NotImplementedError('Unable to parse %s object' % type(node).__name__)
+ def visit_Index(self, node: ast.Index) -> str:
+ return self.visit(node.value)
-def _unparse_arg(arg: ast.arg, default: Optional[ast.AST]) -> str:
- """Unparse a single argument to a string."""
- name = unparse(arg)
- if default:
- if arg.annotation:
- name += " = %s" % unparse(default)
- else:
- name += "=%s" % unparse(default)
- return name
+ def visit_Lambda(self, node: ast.Lambda) -> str:
+ return "lambda %s: ..." % self.visit(node.args)
+ def visit_List(self, node: ast.List) -> str:
+ return "[" + ", ".join(self.visit(e) for e in node.elts) + "]"
-def unparse_arguments(node: ast.arguments) -> str:
- """Unparse an arguments to string."""
- defaults = list(node.defaults) # type: List[Optional[ast.AST]]
- positionals = len(node.args)
- posonlyargs = 0
- if hasattr(node, "posonlyargs"): # for py38+
- posonlyargs += len(node.posonlyargs) # type:ignore
- positionals += posonlyargs
- for _ in range(len(defaults), positionals):
- defaults.insert(0, None)
+ def visit_Name(self, node: ast.Name) -> str:
+ return node.id
- kw_defaults = list(node.kw_defaults) # type: List[Optional[ast.AST]]
- for _ in range(len(kw_defaults), len(node.kwonlyargs)):
- kw_defaults.insert(0, None)
+ def visit_Set(self, node: ast.Set) -> str:
+ return "{" + ", ".join(self.visit(e) for e in node.elts) + "}"
- args = [] # type: List[str]
- if hasattr(node, "posonlyargs"): # for py38+
- for i, arg in enumerate(node.posonlyargs): # type: ignore
- args.append(_unparse_arg(arg, defaults[i]))
+ def visit_Subscript(self, node: ast.Subscript) -> str:
+ return "%s[%s]" % (self.visit(node.value), self.visit(node.slice))
- if node.posonlyargs: # type: ignore
- args.append('/')
+ def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
+ return "%s %s" % (self.visit(node.op), self.visit(node.operand))
- for i, arg in enumerate(node.args):
- args.append(_unparse_arg(arg, defaults[i + posonlyargs]))
+ def visit_Tuple(self, node: ast.Tuple) -> str:
+ if node.elts:
+ return ", ".join(self.visit(e) for e in node.elts)
+ else:
+ return "()"
- if node.vararg:
- args.append("*" + unparse(node.vararg))
+ if sys.version_info >= (3, 6):
+ def visit_Constant(self, node: ast.Constant) -> str:
+ if node.value is Ellipsis:
+ return "..."
+ else:
+ return repr(node.value)
- if node.kwonlyargs and not node.vararg:
- args.append('*')
- for i, arg in enumerate(node.kwonlyargs):
- args.append(_unparse_arg(arg, kw_defaults[i]))
+ if sys.version_info < (3, 8):
+ # these ast nodes were deprecated in python 3.8
+ def visit_Bytes(self, node: ast.Bytes) -> str:
+ return repr(node.s)
- if node.kwarg:
- args.append("**" + unparse(node.kwarg))
+ def visit_Ellipsis(self, node: ast.Ellipsis) -> str:
+ return "..."
- return ", ".join(args)
+ def visit_NameConstant(self, node: ast.NameConstant) -> str:
+ return repr(node.value)
+
+ def visit_Num(self, node: ast.Num) -> str:
+ return repr(node.n)
+
+ def visit_Str(self, node: ast.Str) -> str:
+ return repr(node.s)
+
+ def generic_visit(self, node):
+ raise NotImplementedError('Unable to parse %s object' % type(node).__name__)