summaryrefslogtreecommitdiff
path: root/sphinx/ext/inheritance_diagram.py
diff options
context:
space:
mode:
Diffstat (limited to 'sphinx/ext/inheritance_diagram.py')
-rw-r--r--sphinx/ext/inheritance_diagram.py114
1 files changed, 61 insertions, 53 deletions
diff --git a/sphinx/ext/inheritance_diagram.py b/sphinx/ext/inheritance_diagram.py
index 341780473..f37d7cf8b 100644
--- a/sphinx/ext/inheritance_diagram.py
+++ b/sphinx/ext/inheritance_diagram.py
@@ -64,9 +64,63 @@ if False:
from sphinx.environment import BuildEnvironment # NOQA
-class_sig_re = re.compile(r'''^([\w.]*\.)? # module names
- (\w+) \s* $ # class/final module name
- ''', re.VERBOSE)
+module_sig_re = re.compile(r'''^(?:([\w.]*)\.)? # module names
+ (\w+) \s* $ # class/final module name
+ ''', re.VERBOSE)
+
+
+def try_import(objname):
+ # type: (unicode) -> Any
+ """Import a object or module using *name* and *currentmodule*.
+ *name* should be a relative name from *currentmodule* or
+ a fully-qualified name.
+
+ Returns imported object or module. If failed, returns None value.
+ """
+ try:
+ __import__(objname)
+ return sys.modules.get(objname)
+ except ImportError:
+ modname, attrname = module_sig_re.match(objname).groups()
+ if modname is None:
+ return None
+ try:
+ __import__(modname)
+ return getattr(sys.modules.get(modname), attrname, None)
+ except ImportError:
+ return None
+
+
+def import_classes(name, currmodule):
+ # type: (unicode, unicode) -> Any
+ """Import a class using its fully-qualified *name*."""
+ target = None
+
+ # import class or module using currmodule
+ if currmodule:
+ target = try_import(currmodule + '.' + name)
+
+ # import class or module without currmodule
+ if target is None:
+ target = try_import(name)
+
+ if target is None:
+ raise InheritanceException(
+ 'Could not import class or module %r specified for '
+ 'inheritance diagram' % name)
+
+ if inspect.isclass(target):
+ # If imported object is a class, just return it
+ return [target]
+ elif inspect.ismodule(target):
+ # If imported object is a module, return classes defined on it
+ classes = []
+ for cls in target.__dict__.values():
+ if inspect.isclass(cls) and cls.__module__ == target.__name__:
+ classes.append(cls)
+ return classes
+ raise InheritanceException('%r specified for inheritance diagram is '
+ 'not a class or module' % name)
class InheritanceException(Exception):
@@ -95,58 +149,12 @@ class InheritanceGraph(object):
raise InheritanceException('No classes found for '
'inheritance diagram')
- def _import_class_or_module(self, name, currmodule):
- # type: (unicode, str) -> Any
- """Import a class using its fully-qualified *name*."""
- try:
- path, base = class_sig_re.match(name).groups() # type: ignore
- except (AttributeError, ValueError):
- raise InheritanceException('Invalid class or module %r specified '
- 'for inheritance diagram' % name)
-
- fullname = (path or '') + base
- path = (path and path.rstrip('.') or '')
-
- # two possibilities: either it is a module, then import it
- try:
- __import__(fullname)
- todoc = sys.modules[fullname]
- except ImportError:
- # else it is a class, then import the module
- if not path:
- if currmodule:
- # try the current module
- path = currmodule
- else:
- raise InheritanceException(
- 'Could not import class %r specified for '
- 'inheritance diagram' % base)
- try:
- __import__(path)
- todoc = getattr(sys.modules[path], base)
- except (ImportError, AttributeError):
- raise InheritanceException(
- 'Could not import class or module %r specified for '
- 'inheritance diagram' % (path + '.' + base))
-
- # If a class, just return it
- if inspect.isclass(todoc):
- return [todoc]
- elif inspect.ismodule(todoc):
- classes = []
- for cls in todoc.__dict__.values(): # type: ignore
- if inspect.isclass(cls) and cls.__module__ == todoc.__name__:
- classes.append(cls)
- return classes
- raise InheritanceException('%r specified for inheritance diagram is '
- 'not a class or module' % name)
-
def _import_classes(self, class_names, currmodule):
# type: (unicode, str) -> List[Any]
"""Import a list of classes."""
classes = [] # type: List[Any]
for name in class_names:
- classes.extend(self._import_class_or_module(name, currmodule))
+ classes.extend(import_classes(name, currmodule))
return classes
def _class_info(self, classes, show_builtins, private_bases, parts):
@@ -443,7 +451,7 @@ def setup(app):
man=(skip, None),
texinfo=(texinfo_visit_inheritance_diagram, None))
app.add_directive('inheritance-diagram', InheritanceDiagram)
- app.add_config_value('inheritance_graph_attrs', {}, False),
- app.add_config_value('inheritance_node_attrs', {}, False),
- app.add_config_value('inheritance_edge_attrs', {}, False),
+ app.add_config_value('inheritance_graph_attrs', {}, False)
+ app.add_config_value('inheritance_node_attrs', {}, False)
+ app.add_config_value('inheritance_edge_attrs', {}, False)
return {'version': sphinx.__display_version__, 'parallel_read_safe': True}