diff options
Diffstat (limited to 'sphinx/util/nodes.py')
-rw-r--r-- | sphinx/util/nodes.py | 58 |
1 files changed, 53 insertions, 5 deletions
diff --git a/sphinx/util/nodes.py b/sphinx/util/nodes.py index 7e4dba01e..9d500de76 100644 --- a/sphinx/util/nodes.py +++ b/sphinx/util/nodes.py @@ -11,6 +11,7 @@ from __future__ import absolute_import import re +from typing import Any from docutils import nodes from six import text_type @@ -33,6 +34,57 @@ explicit_title_re = re.compile(r'^(.+?)\s*(?<!\x00)<(.*?)>$', re.DOTALL) caption_ref_re = explicit_title_re # b/w compat alias +class NodeMatcher(object): + """A helper class for Node.traverse(). + + It checks that given node is an instance of specified node-classes and it has + specified node-attributes. + + For example, following example searches ``reference`` node having ``refdomain`` + and ``reftype`` attributes:: + + matcher = NodeMatcher(nodes.reference, refdomain='std', reftype='citation') + doctree.traverse(matcher) + # => [<reference ...>, <reference ...>, ...] + + A special value ``typing.Any`` matches any kind of node-attributes. For example, + following example searches ``reference`` node having ``refdomain`` attributes:: + + from typing import Any + matcher = NodeMatcher(nodes.reference, refdomain=Any) + doctree.traverse(matcher) + # => [<reference ...>, <reference ...>, ...] + """ + + def __init__(self, *classes, **attrs): + # type: (nodes.Node, Any) -> None + self.classes = classes + self.attrs = attrs + + def match(self, node): + # type: (nodes.Node) -> bool + try: + if self.classes and not isinstance(node, self.classes): + return False + + for key, value in self.attrs.items(): + if key not in node: + return False + elif value is Any: + continue + elif node.get(key) != value: + return False + else: + return True + except Exception: + # for non-Element nodes + return False + + def __call__(self, node): + # type: (nodes.Node) -> bool + return self.match(node) + + def get_full_module_name(node): # type: (nodes.Node) -> str """ @@ -241,11 +293,7 @@ def traverse_parent(node, cls=None): def traverse_translatable_index(doctree): # type: (nodes.Node) -> Iterable[Tuple[nodes.Node, List[unicode]]] """Traverse translatable index node from a document tree.""" - def is_block_index(node): - # type: (nodes.Node) -> bool - return isinstance(node, addnodes.index) and \ - node.get('inline') is False - for node in doctree.traverse(is_block_index): + for node in doctree.traverse(NodeMatcher(addnodes.index, inline=False)): if 'raw_entries' in node: entries = node['raw_entries'] else: |