summaryrefslogtreecommitdiff
path: root/sphinx/util/nodes.py
diff options
context:
space:
mode:
Diffstat (limited to 'sphinx/util/nodes.py')
-rw-r--r--sphinx/util/nodes.py58
1 files changed, 53 insertions, 5 deletions
diff --git a/sphinx/util/nodes.py b/sphinx/util/nodes.py
index 7e4dba01e..9d500de76 100644
--- a/sphinx/util/nodes.py
+++ b/sphinx/util/nodes.py
@@ -11,6 +11,7 @@
from __future__ import absolute_import
import re
+from typing import Any
from docutils import nodes
from six import text_type
@@ -33,6 +34,57 @@ explicit_title_re = re.compile(r'^(.+?)\s*(?<!\x00)<(.*?)>$', re.DOTALL)
caption_ref_re = explicit_title_re # b/w compat alias
+class NodeMatcher(object):
+ """A helper class for Node.traverse().
+
+ It checks that given node is an instance of specified node-classes and it has
+ specified node-attributes.
+
+ For example, following example searches ``reference`` node having ``refdomain``
+ and ``reftype`` attributes::
+
+ matcher = NodeMatcher(nodes.reference, refdomain='std', reftype='citation')
+ doctree.traverse(matcher)
+ # => [<reference ...>, <reference ...>, ...]
+
+ A special value ``typing.Any`` matches any kind of node-attributes. For example,
+ following example searches ``reference`` node having ``refdomain`` attributes::
+
+ from typing import Any
+ matcher = NodeMatcher(nodes.reference, refdomain=Any)
+ doctree.traverse(matcher)
+ # => [<reference ...>, <reference ...>, ...]
+ """
+
+ def __init__(self, *classes, **attrs):
+ # type: (nodes.Node, Any) -> None
+ self.classes = classes
+ self.attrs = attrs
+
+ def match(self, node):
+ # type: (nodes.Node) -> bool
+ try:
+ if self.classes and not isinstance(node, self.classes):
+ return False
+
+ for key, value in self.attrs.items():
+ if key not in node:
+ return False
+ elif value is Any:
+ continue
+ elif node.get(key) != value:
+ return False
+ else:
+ return True
+ except Exception:
+ # for non-Element nodes
+ return False
+
+ def __call__(self, node):
+ # type: (nodes.Node) -> bool
+ return self.match(node)
+
+
def get_full_module_name(node):
# type: (nodes.Node) -> str
"""
@@ -241,11 +293,7 @@ def traverse_parent(node, cls=None):
def traverse_translatable_index(doctree):
# type: (nodes.Node) -> Iterable[Tuple[nodes.Node, List[unicode]]]
"""Traverse translatable index node from a document tree."""
- def is_block_index(node):
- # type: (nodes.Node) -> bool
- return isinstance(node, addnodes.index) and \
- node.get('inline') is False
- for node in doctree.traverse(is_block_index):
+ for node in doctree.traverse(NodeMatcher(addnodes.index, inline=False)):
if 'raw_entries' in node:
entries = node['raw_entries']
else: