diff options
Diffstat (limited to 'sphinx/pycode/ast.py')
-rw-r--r-- | sphinx/pycode/ast.py | 213 |
1 files changed, 121 insertions, 92 deletions
diff --git a/sphinx/pycode/ast.py b/sphinx/pycode/ast.py index 3f4717f45..0b403b61e 100644 --- a/sphinx/pycode/ast.py +++ b/sphinx/pycode/ast.py @@ -75,113 +75,142 @@ def unparse(node: Optional[ast.AST]) -> Optional[str]: return None elif isinstance(node, str): return node - elif node.__class__ in OPERATORS: + return _UnparseVisitor().visit(node) + + +# a greatly cut-down version of `ast._Unparser` +class _UnparseVisitor(ast.NodeVisitor): + + def _visit_op(self, node: ast.AST) -> str: return OPERATORS[node.__class__] - elif isinstance(node, ast.arg): + for _op in OPERATORS: + locals()['visit_{}'.format(_op.__name__)] = _visit_op + + def visit_arg(self, node: ast.arg) -> str: if node.annotation: - return "%s: %s" % (node.arg, unparse(node.annotation)) + return "%s: %s" % (node.arg, self.visit(node.annotation)) else: return node.arg - elif isinstance(node, ast.arguments): - return unparse_arguments(node) - elif isinstance(node, ast.Attribute): - return "%s.%s" % (unparse(node.value), node.attr) - elif isinstance(node, ast.BinOp): - return " ".join(unparse(e) for e in [node.left, node.op, node.right]) - elif isinstance(node, ast.BoolOp): - op = " %s " % unparse(node.op) - return op.join(unparse(e) for e in node.values) - elif isinstance(node, ast.Bytes): - return repr(node.s) - elif isinstance(node, ast.Call): - args = ([unparse(e) for e in node.args] + - ["%s=%s" % (k.arg, unparse(k.value)) for k in node.keywords]) - return "%s(%s)" % (unparse(node.func), ", ".join(args)) - elif isinstance(node, ast.Dict): - keys = (unparse(k) for k in node.keys) - values = (unparse(v) for v in node.values) + + def _visit_arg_with_default(self, arg: ast.arg, default: Optional[ast.AST]) -> str: + """Unparse a single argument to a string.""" + name = self.visit(arg) + if default: + if arg.annotation: + name += " = %s" % self.visit(default) + else: + name += "=%s" % self.visit(default) + return name + + def visit_arguments(self, node: ast.arguments) -> str: + defaults = list(node.defaults) + positionals = len(node.args) + posonlyargs = 0 + if hasattr(node, "posonlyargs"): # for py38+ + posonlyargs += len(node.posonlyargs) # type:ignore + positionals += posonlyargs + for _ in range(len(defaults), positionals): + defaults.insert(0, None) + + kw_defaults = list(node.kw_defaults) + for _ in range(len(kw_defaults), len(node.kwonlyargs)): + kw_defaults.insert(0, None) + + args = [] # type: List[str] + if hasattr(node, "posonlyargs"): # for py38+ + for i, arg in enumerate(node.posonlyargs): # type: ignore + args.append(self._visit_arg_with_default(arg, defaults[i])) + + if node.posonlyargs: # type: ignore + args.append('/') + + for i, arg in enumerate(node.args): + args.append(self._visit_arg_with_default(arg, defaults[i + posonlyargs])) + + if node.vararg: + args.append("*" + self.visit(node.vararg)) + + if node.kwonlyargs and not node.vararg: + args.append('*') + for i, arg in enumerate(node.kwonlyargs): + args.append(self._visit_arg_with_default(arg, kw_defaults[i])) + + if node.kwarg: + args.append("**" + self.visit(node.kwarg)) + + return ", ".join(args) + + def visit_Attribute(self, node: ast.Attribute) -> str: + return "%s.%s" % (self.visit(node.value), node.attr) + + def visit_BinOp(self, node: ast.BinOp) -> str: + return " ".join(self.visit(e) for e in [node.left, node.op, node.right]) + + def visit_BoolOp(self, node: ast.BoolOp) -> str: + op = " %s " % self.visit(node.op) + return op.join(self.visit(e) for e in node.values) + + def visit_Call(self, node: ast.Call) -> str: + args = ([self.visit(e) for e in node.args] + + ["%s=%s" % (k.arg, self.visit(k.value)) for k in node.keywords]) + return "%s(%s)" % (self.visit(node.func), ", ".join(args)) + + def visit_Dict(self, node: ast.Dict) -> str: + keys = (self.visit(k) for k in node.keys) + values = (self.visit(v) for v in node.values) items = (k + ": " + v for k, v in zip(keys, values)) return "{" + ", ".join(items) + "}" - elif isinstance(node, ast.Ellipsis): - return "..." - elif isinstance(node, ast.Index): - return unparse(node.value) - elif isinstance(node, ast.Lambda): - return "lambda %s: ..." % unparse(node.args) - elif isinstance(node, ast.List): - return "[" + ", ".join(unparse(e) for e in node.elts) + "]" - elif isinstance(node, ast.Name): - return node.id - elif isinstance(node, ast.NameConstant): - return repr(node.value) - elif isinstance(node, ast.Num): - return repr(node.n) - elif isinstance(node, ast.Set): - return "{" + ", ".join(unparse(e) for e in node.elts) + "}" - elif isinstance(node, ast.Str): - return repr(node.s) - elif isinstance(node, ast.Subscript): - return "%s[%s]" % (unparse(node.value), unparse(node.slice)) - elif isinstance(node, ast.UnaryOp): - return "%s %s" % (unparse(node.op), unparse(node.operand)) - elif isinstance(node, ast.Tuple): - if node.elts: - return ", ".join(unparse(e) for e in node.elts) - else: - return "()" - elif sys.version_info > (3, 6) and isinstance(node, ast.Constant): - # this branch should be placed at last - return repr(node.value) - else: - raise NotImplementedError('Unable to parse %s object' % type(node).__name__) + def visit_Index(self, node: ast.Index) -> str: + return self.visit(node.value) -def _unparse_arg(arg: ast.arg, default: Optional[ast.AST]) -> str: - """Unparse a single argument to a string.""" - name = unparse(arg) - if default: - if arg.annotation: - name += " = %s" % unparse(default) - else: - name += "=%s" % unparse(default) - return name + def visit_Lambda(self, node: ast.Lambda) -> str: + return "lambda %s: ..." % self.visit(node.args) + def visit_List(self, node: ast.List) -> str: + return "[" + ", ".join(self.visit(e) for e in node.elts) + "]" -def unparse_arguments(node: ast.arguments) -> str: - """Unparse an arguments to string.""" - defaults = list(node.defaults) # type: List[Optional[ast.AST]] - positionals = len(node.args) - posonlyargs = 0 - if hasattr(node, "posonlyargs"): # for py38+ - posonlyargs += len(node.posonlyargs) # type:ignore - positionals += posonlyargs - for _ in range(len(defaults), positionals): - defaults.insert(0, None) + def visit_Name(self, node: ast.Name) -> str: + return node.id - kw_defaults = list(node.kw_defaults) # type: List[Optional[ast.AST]] - for _ in range(len(kw_defaults), len(node.kwonlyargs)): - kw_defaults.insert(0, None) + def visit_Set(self, node: ast.Set) -> str: + return "{" + ", ".join(self.visit(e) for e in node.elts) + "}" - args = [] # type: List[str] - if hasattr(node, "posonlyargs"): # for py38+ - for i, arg in enumerate(node.posonlyargs): # type: ignore - args.append(_unparse_arg(arg, defaults[i])) + def visit_Subscript(self, node: ast.Subscript) -> str: + return "%s[%s]" % (self.visit(node.value), self.visit(node.slice)) - if node.posonlyargs: # type: ignore - args.append('/') + def visit_UnaryOp(self, node: ast.UnaryOp) -> str: + return "%s %s" % (self.visit(node.op), self.visit(node.operand)) - for i, arg in enumerate(node.args): - args.append(_unparse_arg(arg, defaults[i + posonlyargs])) + def visit_Tuple(self, node: ast.Tuple) -> str: + if node.elts: + return ", ".join(self.visit(e) for e in node.elts) + else: + return "()" - if node.vararg: - args.append("*" + unparse(node.vararg)) + if sys.version_info >= (3, 6): + def visit_Constant(self, node: ast.Constant) -> str: + if node.value is Ellipsis: + return "..." + else: + return repr(node.value) - if node.kwonlyargs and not node.vararg: - args.append('*') - for i, arg in enumerate(node.kwonlyargs): - args.append(_unparse_arg(arg, kw_defaults[i])) + if sys.version_info < (3, 8): + # these ast nodes were deprecated in python 3.8 + def visit_Bytes(self, node: ast.Bytes) -> str: + return repr(node.s) - if node.kwarg: - args.append("**" + unparse(node.kwarg)) + def visit_Ellipsis(self, node: ast.Ellipsis) -> str: + return "..." - return ", ".join(args) + def visit_NameConstant(self, node: ast.NameConstant) -> str: + return repr(node.value) + + def visit_Num(self, node: ast.Num) -> str: + return repr(node.n) + + def visit_Str(self, node: ast.Str) -> str: + return repr(node.s) + + def generic_visit(self, node): + raise NotImplementedError('Unable to parse %s object' % type(node).__name__) |