diff options
Diffstat (limited to 'sphinx/ext/inheritance_diagram.py')
-rw-r--r-- | sphinx/ext/inheritance_diagram.py | 41 |
1 files changed, 22 insertions, 19 deletions
diff --git a/sphinx/ext/inheritance_diagram.py b/sphinx/ext/inheritance_diagram.py index 34fe7bea6..8e01a1b28 100644 --- a/sphinx/ext/inheritance_diagram.py +++ b/sphinx/ext/inheritance_diagram.py @@ -39,10 +39,7 @@ r""" import re import sys import inspect -try: - from hashlib import md5 -except ImportError: - from md5 import md5 +from hashlib import md5 from six import text_type from six.moves import builtins @@ -58,7 +55,7 @@ from sphinx.util import force_decode if False: # For type annotation - from typing import Any, Dict, List, Tuple # NOQA + from typing import Any, Dict, List, Tuple, Dict, Optional # NOQA from sphinx.application import Sphinx # NOQA from sphinx.environment import BuildEnvironment # NOQA @@ -133,8 +130,8 @@ class InheritanceGraph(object): graphviz dot graph from them. """ def __init__(self, class_names, currmodule, show_builtins=False, - private_bases=False, parts=0): - # type: (unicode, str, bool, bool, int) -> None + private_bases=False, parts=0, aliases=None): + # type: (unicode, str, bool, bool, int, Optional[Dict[unicode, unicode]]) -> None """*class_names* is a list of child classes to show bases from. If *show_builtins* is True, then Python builtins will be shown @@ -143,7 +140,7 @@ class InheritanceGraph(object): self.class_names = class_names classes = self._import_classes(class_names, currmodule) self.class_info = self._class_info(classes, show_builtins, - private_bases, parts) + private_bases, parts, aliases) if not self.class_info: raise InheritanceException('No classes found for ' 'inheritance diagram') @@ -156,8 +153,8 @@ class InheritanceGraph(object): classes.extend(import_classes(name, currmodule)) return classes - def _class_info(self, classes, show_builtins, private_bases, parts): - # type: (List[Any], bool, bool, int) -> List[Tuple[unicode, unicode, List[unicode], unicode]] # NOQA + def _class_info(self, classes, show_builtins, private_bases, parts, aliases): + # type: (List[Any], bool, bool, int, Optional[Dict[unicode, unicode]]) -> List[Tuple[unicode, unicode, List[unicode], unicode]] # NOQA """Return name and bases for all classes that are ancestors of *classes*. @@ -174,8 +171,8 @@ class InheritanceGraph(object): if not private_bases and cls.__name__.startswith('_'): return - nodename = self.class_name(cls, parts) - fullname = self.class_name(cls, 0) + nodename = self.class_name(cls, parts, aliases) + fullname = self.class_name(cls, 0, aliases) # Use first line of docstring as tooltip, if available tooltip = None @@ -197,7 +194,7 @@ class InheritanceGraph(object): continue if not private_bases and base.__name__.startswith('_'): continue - baselist.append(self.class_name(base, parts)) + baselist.append(self.class_name(base, parts, aliases)) if base not in all_classes: recurse(base) @@ -206,8 +203,8 @@ class InheritanceGraph(object): return list(all_classes.values()) - def class_name(self, cls, parts=0): - # type: (Any, int) -> unicode + def class_name(self, cls, parts=0, aliases=None): + # type: (Any, int, Optional[Dict[unicode, unicode]]) -> unicode """Given a class object, return a fully-qualified name. This works for things I've tested in matplotlib so far, but may not be @@ -219,9 +216,13 @@ class InheritanceGraph(object): else: fullname = '%s.%s' % (module, cls.__name__) if parts == 0: - return fullname - name_parts = fullname.split('.') - return '.'.join(name_parts[-parts:]) + result = fullname + else: + name_parts = fullname.split('.') + result = '.'.join(name_parts[-parts:]) + if aliases is not None and result in aliases: + return aliases[result] + return result def get_all_class_names(self): # type: () -> List[unicode] @@ -339,7 +340,8 @@ class InheritanceDiagram(Directive): graph = InheritanceGraph( class_names, env.ref_context.get('py:module'), parts=node['parts'], - private_bases='private-bases' in self.options) + private_bases='private-bases' in self.options, + aliases=env.config.inheritance_alias) except InheritanceException as err: return [node.document.reporter.warning(err.args[0], line=self.lineno)] @@ -453,4 +455,5 @@ def setup(app): app.add_config_value('inheritance_graph_attrs', {}, False) app.add_config_value('inheritance_node_attrs', {}, False) app.add_config_value('inheritance_edge_attrs', {}, False) + app.add_config_value('inheritance_alias', {}, False) return {'version': sphinx.__display_version__, 'parallel_read_safe': True} |