summaryrefslogtreecommitdiff
path: root/sphinx/ext/autodoc/preserve_defaults.py
blob: a0ceb1ac28099bc84c06c14e4d132adeff63a687 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""Preserve function defaults.

Preserve the default argument values of function signatures as written in the
source code, leaving them unevaluated for readability.
"""

from __future__ import annotations

import ast
import inspect
from typing import Any

import sphinx
from sphinx.application import Sphinx
from sphinx.locale import __
from sphinx.pycode.ast import unparse as ast_unparse
from sphinx.util import logging

logger = logging.getLogger(__name__)


class DefaultValue:
    """Wrapper around a piece of text whose ``repr()`` is that text verbatim.

    Unlike the repr of a plain ``str``, no quotes are added, so an instance
    used as a parameter default is displayed exactly as the wrapped text.
    """

    def __init__(self, name: str) -> None:
        # The text to show in place of the evaluated default value.
        self.name = name

    def __repr__(self) -> str:
        """Return the wrapped text unchanged (no quoting or escaping)."""
        text = self.name
        return text


def get_function_def(obj: Any) -> ast.FunctionDef | None:
    """Get FunctionDef object from living object.

    This tries to parse original code for living object and returns
    AST node for given *obj*.

    When the source starts with indentation (a space or a tab), it is parsed
    behind an extra ``if True:`` line.  Every line number in the returned node
    is then one greater than the matching line of ``inspect.getsource(obj)``.

    :param obj: the object whose source code is looked up
    :return: the first statement found in the source of *obj*, or ``None``
        when the source cannot be loaded or cannot be parsed
    """
    try:
        source = inspect.getsource(obj)
        if source.startswith((' ', '\t')):
            # subject is placed inside class or block.  To make the indented
            # code parsable, this adds if-block before the declaration.
            module = ast.parse('if True:\n' + source)
            return module.body[0].body[0]  # type: ignore
        else:
            module = ast.parse(source)
            return module.body[0]  # type: ignore
    except (OSError, TypeError):  # failed to load source code
        return None
    except SyntaxError:
        # The retrieved text is not a complete statement on its own (for
        # example a fragment of a larger expression); treat it as "no AST".
        return None


def get_default_value(lines: list[str], position: ast.AST) -> str | None:
    """Return the source text covered by the AST node *position*.

    :param lines: source lines; ``lines[0]`` must correspond to line 1 of the
        text that *position* was parsed from
    :param position: a node carrying ``lineno``, ``end_lineno``,
        ``col_offset`` and ``end_col_offset``
    :return: the text of the node when it sits on a single line; ``None`` for
        multi-line nodes, nodes without position info, or positions that do
        not fit *lines*
    """
    try:
        if position.lineno == position.end_lineno:
            # col_offset/end_col_offset are UTF-8 *byte* offsets, so the line
            # has to be sliced as bytes; slicing the str directly goes wrong
            # when a non-ASCII character sits at or before the node.
            encoded = lines[position.lineno - 1].encode('utf-8')
            segment = encoded[position.col_offset:position.end_col_offset]
            return segment.decode('utf-8')
        else:
            # multiline value is not supported now
            return None
    except (AttributeError, IndexError, UnicodeError):
        # no position info, line out of range, or offsets that do not match
        # the given line
        return None


def update_defvalue(app: Sphinx, obj: Any, bound_method: bool) -> None:
    """Replace the defaults in the signature of *obj* with their source text.

    Each default in ``inspect.signature(obj)`` is swapped for a
    :class:`DefaultValue` holding the expression as written in the source.
    When the text cannot be sliced from a single source line, the unparsed
    AST node is used instead.  The result is stored as the ``__signature__``
    of *obj*.  Objects whose source or signature cannot be processed are left
    untouched.

    :param app: the application; nothing is done unless its
        ``autodoc_preserve_defaults`` config value is true
    :param obj: the function or method whose signature is updated in place
    :param bound_method: flag that, together with ``inspect.ismethod(obj)``,
        selects the classmethod handling below
    """
    if not app.config.autodoc_preserve_defaults:
        return

    try:
        lines = inspect.getsource(obj).splitlines()
        if lines and lines[0].startswith((' ', '\t')):
            # get_function_def() parses indented source behind an extra
            # ``if True:`` line; insert a dummy line so line numbers agree.
            lines.insert(0, '')
    except (OSError, TypeError):
        lines = []

    try:
        function = get_function_def(obj)
        if function is None:
            # no source available, or it could not be parsed
            return

        if function.args.defaults or function.args.kw_defaults:
            sig = inspect.signature(obj)
            defaults = list(function.args.defaults)
            kw_defaults = list(function.args.kw_defaults)
            parameters = list(sig.parameters.values())
            for i, param in enumerate(parameters):
                if param.default is param.empty:
                    if param.kind == param.KEYWORD_ONLY:
                        # Consume kw_defaults for kwonly args
                        kw_defaults.pop(0)
                    continue

                # Positional defaults and keyword-only defaults are kept in
                # separate lists on the AST; take the next one from the list
                # matching this parameter's kind.
                if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
                    default = defaults.pop(0)
                else:
                    default = kw_defaults.pop(0)

                value = get_default_value(lines, default)
                if value is None:
                    # multi-line (or otherwise unsliceable) default
                    value = ast_unparse(default)
                parameters[i] = param.replace(default=DefaultValue(value))

            is_classmethod = bound_method and inspect.ismethod(obj)
            if is_classmethod:
                # classmethods
                cls = inspect.Parameter('cls', inspect.Parameter.POSITIONAL_OR_KEYWORD)
                parameters.insert(0, cls)

            sig = sig.replace(parameters=parameters)
            if is_classmethod:
                # classmethods can't be assigned __signature__ attribute.
                obj.__dict__['__signature__'] = sig
            else:
                obj.__signature__ = sig
    except (AttributeError, TypeError):
        # failed to update signature (ex. built-in or extension types)
        pass
    except NotImplementedError as exc:  # failed to ast.unparse()
        logger.warning(__("Failed to parse a default argument value for %r: %s"), obj, exc)


def setup(app: Sphinx) -> dict[str, Any]:
    """Register this extension's config value and event handler on *app*.

    :param app: the application to register with
    :return: extension metadata with the ``version`` and
        ``parallel_read_safe`` keys
    """
    # The feature is opt-in: the config value defaults to False.
    app.add_config_value('autodoc_preserve_defaults', False, True)
    # Rewrite defaults before the signature is processed.
    app.connect('autodoc-before-process-signature', update_defvalue)

    metadata: dict[str, Any] = {}
    metadata['version'] = sphinx.__display_version__
    metadata['parallel_read_safe'] = True
    return metadata