summaryrefslogtreecommitdiff
path: root/sphinx/transforms/post_transforms/code.py
blob: 0707b85a84d02ce1b1aa530790988ccc006bd918 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
    sphinx.transforms.post_transforms.code
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    transforms for code-blocks.

    :copyright: Copyright 2007-2019 by the Sphinx team, see AUTHORS.
    :license: BSD, see LICENSE for details.
"""

import sys
from typing import Any, Dict, List, NamedTuple, Union

from docutils import nodes
from docutils.nodes import Node
from pygments.lexers import PythonConsoleLexer, guess_lexer

from sphinx import addnodes
from sphinx.application import Sphinx
from sphinx.ext import doctest
from sphinx.transforms import SphinxTransform


# Immutable record of the highlighting options in effect for a region of a
# document: the lexer name, the ``force`` flag, and the line-number threshold
# that is compared against a literal block's line count.
HighlightSetting = NamedTuple(
    'HighlightSetting',
    [
        ('language', str),
        ('force', bool),
        ('lineno_threshold', int),
    ],
)


class HighlightLanguageTransform(SphinxTransform):
    """
    Apply highlight_language to all literal_block nodes.

    This refers to both the :confval:`highlight_language` setting and the
    :rst:dir:`highlightlang` directive.  After processing, this transform
    removes every ``highlightlang`` node from the doctree.
    """
    default_priority = 400

    def apply(self, **kwargs) -> None:
        """Resolve the highlight settings, then drop the ``highlightlang`` nodes."""
        visitor = HighlightLanguageVisitor(self.document,
                                           self.config.highlight_language)
        self.document.walkabout(visitor)

        # Take a snapshot of the matching nodes before detaching any of them:
        # the loop body mutates the tree being traversed, so it must not rely
        # on ``traverse()`` handing back an already materialised list.
        for node in list(self.document.traverse(addnodes.highlightlang)):
            node.parent.remove(node)


class HighlightLanguageVisitor(nodes.NodeVisitor):
    """
    Node visitor that fills in highlighting attributes on literal blocks.

    A stack of :class:`HighlightSetting` is maintained while walking the tree;
    the entry on top is the setting applied to literal blocks that do not
    carry their own ``language`` / ``linenos`` attributes.
    """

    def __init__(self, document: nodes.document, default_language: str) -> None:
        # With ``sys.maxsize`` as threshold, line numbers stay off unless a
        # ``highlightlang`` node supplies a smaller value.
        self.default_setting = HighlightSetting(language=default_language,
                                                force=False,
                                                lineno_threshold=sys.maxsize)
        self.settings = []  # type: List[HighlightSetting]
        super().__init__(document)

    def unknown_visit(self, node: Node) -> None:
        """Do nothing for node types without a dedicated ``visit_*`` method."""
        pass

    def unknown_departure(self, node: Node) -> None:
        """Do nothing for node types without a dedicated ``depart_*`` method."""
        pass

    def visit_document(self, node: Node) -> None:
        """Push the default setting when the walk enters the document."""
        self.settings.append(self.default_setting)

    def depart_document(self, node: Node) -> None:
        """Pop the setting that was pushed for the document."""
        self.settings.pop()

    def visit_start_of_file(self, node: Node) -> None:
        """Push a fresh default setting for the ``start_of_file`` subtree."""
        self.settings.append(self.default_setting)

    def depart_start_of_file(self, node: Node) -> None:
        """Pop the setting that was pushed for the ``start_of_file`` subtree."""
        self.settings.pop()

    def visit_highlightlang(self, node: addnodes.highlightlang) -> None:
        """Replace the setting on top of the stack with the directive's values."""
        self.settings[-1] = HighlightSetting(language=node['lang'],
                                             force=node['force'],
                                             lineno_threshold=node['linenothreshold'])

    def visit_literal_block(self, node: nodes.literal_block) -> None:
        """Fill in ``language``/``force`` and ``linenos`` where the block lacks them."""
        current = self.settings[-1]

        # ``force`` is only inherited together with the language; a block with
        # an explicit language keeps whatever ``force`` it already has.
        if 'language' not in node:
            node['language'] = current.language
            node['force'] = current.force

        if 'linenos' not in node:
            newline_count = node.astext().count('\n')
            threshold = current.lineno_threshold - 1
            node['linenos'] = (newline_count >= threshold)


class TrimDoctestFlagsTransform(SphinxTransform):
    """
    Trim doctest flags like ``# doctest: +FLAG`` from python code-blocks.

    see :confval:`trim_doctest_flags` for more information.
    """
    # One above HighlightLanguageTransform's priority; is_pyconsole() reads the
    # ``language`` attribute which that transform fills in on literal blocks.
    default_priority = HighlightLanguageTransform.default_priority + 1

    def apply(self, **kwargs) -> None:
        """Strip doctest flags from console-style literal blocks and doctest blocks."""
        enabled = self.config.trim_doctest_flags
        if not enabled:
            return

        # Literal blocks are only touched when they look like a Python console
        # session.
        for lbnode in self.document.traverse(nodes.literal_block):  # type: nodes.literal_block
            if not self.is_pyconsole(lbnode):
                continue
            self.strip_doctest_flags(lbnode)

        # Doctest blocks are always processed.
        for dbnode in self.document.traverse(nodes.doctest_block):  # type: nodes.doctest_block
            self.strip_doctest_flags(dbnode)

    @staticmethod
    def strip_doctest_flags(node: Union[nodes.literal_block, nodes.doctest_block]) -> None:
        """Remove blankline markers and doctest options from *node* in place.

        Both ``node.rawsource`` and the node's children are replaced with the
        cleaned text.
        """
        without_blanklines = doctest.blankline_re.sub('', node.rawsource)
        cleaned = doctest.doctestopt_re.sub('', without_blanklines)
        node.rawsource = cleaned
        node[:] = [nodes.Text(cleaned)]

    @staticmethod
    def is_pyconsole(node: nodes.literal_block) -> bool:
        """Return True if *node* appears to hold a Python console session."""
        raw = node.rawsource
        if raw != node.astext():
            return False  # skip parsed-literal node

        lang = node.get('language')
        if lang in ('pycon', 'pycon3'):
            return True
        if lang in ('py', 'py3', 'python', 'python3', 'default'):
            # Plain python blocks count only when they open with a prompt.
            return raw.startswith('>>>')
        if lang != 'guess':
            return False

        # Best effort: any failure while guessing means "not a console block".
        try:
            return isinstance(guess_lexer(raw), PythonConsoleLexer)
        except Exception:
            return False


def setup(app: Sphinx) -> Dict[str, Any]:
    """Register this module's post-transforms on *app* and return its metadata."""
    for transform in (HighlightLanguageTransform, TrimDoctestFlagsTransform):
        app.add_post_transform(transform)

    metadata = {
        'version': 'builtin',
        'parallel_read_safe': True,
        'parallel_write_safe': True,
    }  # type: Dict[str, Any]
    return metadata