diff options
Diffstat (limited to 'sphinx/ext/inheritance_diagram.py')
-rw-r--r-- | sphinx/ext/inheritance_diagram.py | 114 |
1 files changed, 61 insertions, 53 deletions
diff --git a/sphinx/ext/inheritance_diagram.py b/sphinx/ext/inheritance_diagram.py index 341780473..f37d7cf8b 100644 --- a/sphinx/ext/inheritance_diagram.py +++ b/sphinx/ext/inheritance_diagram.py @@ -64,9 +64,63 @@ if False: from sphinx.environment import BuildEnvironment # NOQA -class_sig_re = re.compile(r'''^([\w.]*\.)? # module names - (\w+) \s* $ # class/final module name - ''', re.VERBOSE) +module_sig_re = re.compile(r'''^(?:([\w.]*)\.)? # module names + (\w+) \s* $ # class/final module name + ''', re.VERBOSE) + + +def try_import(objname): + # type: (unicode) -> Any + """Import a object or module using *name* and *currentmodule*. + *name* should be a relative name from *currentmodule* or + a fully-qualified name. + + Returns imported object or module. If failed, returns None value. + """ + try: + __import__(objname) + return sys.modules.get(objname) + except ImportError: + modname, attrname = module_sig_re.match(objname).groups() + if modname is None: + return None + try: + __import__(modname) + return getattr(sys.modules.get(modname), attrname, None) + except ImportError: + return None + + +def import_classes(name, currmodule): + # type: (unicode, unicode) -> Any + """Import a class using its fully-qualified *name*.""" + target = None + + # import class or module using currmodule + if currmodule: + target = try_import(currmodule + '.' + name) + + # import class or module without currmodule + if target is None: + target = try_import(name) + + if target is None: + raise InheritanceException( + 'Could not import class or module %r specified for ' + 'inheritance diagram' % name) + + if inspect.isclass(target): + # If imported object is a class, just return it + return [target] + elif inspect.ismodule(target): + # If imported object is a module, return classes defined on it + classes = [] + for cls in target.__dict__.values(): + if inspect.isclass(cls) and cls.__module__ == target.__name__: + classes.append(cls) + return classes + raise InheritanceException('%r specified for inheritance diagram is ' + 'not a class or module' % name) class InheritanceException(Exception): @@ -95,58 +149,12 @@ class InheritanceGraph(object): raise InheritanceException('No classes found for ' 'inheritance diagram') - def _import_class_or_module(self, name, currmodule): - # type: (unicode, str) -> Any - """Import a class using its fully-qualified *name*.""" - try: - path, base = class_sig_re.match(name).groups() # type: ignore - except (AttributeError, ValueError): - raise InheritanceException('Invalid class or module %r specified ' - 'for inheritance diagram' % name) - - fullname = (path or '') + base - path = (path and path.rstrip('.') or '') - - # two possibilities: either it is a module, then import it - try: - __import__(fullname) - todoc = sys.modules[fullname] - except ImportError: - # else it is a class, then import the module - if not path: - if currmodule: - # try the current module - path = currmodule - else: - raise InheritanceException( - 'Could not import class %r specified for ' - 'inheritance diagram' % base) - try: - __import__(path) - todoc = getattr(sys.modules[path], base) - except (ImportError, AttributeError): - raise InheritanceException( - 'Could not import class or module %r specified for ' - 'inheritance diagram' % (path + '.' + base)) - - # If a class, just return it - if inspect.isclass(todoc): - return [todoc] - elif inspect.ismodule(todoc): - classes = [] - for cls in todoc.__dict__.values(): # type: ignore - if inspect.isclass(cls) and cls.__module__ == todoc.__name__: - classes.append(cls) - return classes - raise InheritanceException('%r specified for inheritance diagram is ' - 'not a class or module' % name) - def _import_classes(self, class_names, currmodule): # type: (unicode, str) -> List[Any] """Import a list of classes.""" classes = [] # type: List[Any] for name in class_names: - classes.extend(self._import_class_or_module(name, currmodule)) + classes.extend(import_classes(name, currmodule)) return classes def _class_info(self, classes, show_builtins, private_bases, parts): @@ -443,7 +451,7 @@ def setup(app): man=(skip, None), texinfo=(texinfo_visit_inheritance_diagram, None)) app.add_directive('inheritance-diagram', InheritanceDiagram) - app.add_config_value('inheritance_graph_attrs', {}, False), - app.add_config_value('inheritance_node_attrs', {}, False), - app.add_config_value('inheritance_edge_attrs', {}, False), + app.add_config_value('inheritance_graph_attrs', {}, False) + app.add_config_value('inheritance_node_attrs', {}, False) + app.add_config_value('inheritance_edge_attrs', {}, False) return {'version': sphinx.__display_version__, 'parallel_read_safe': True} |