summaryrefslogtreecommitdiff
path: root/sphinx/ext/autodoc/importer.py
diff options
context:
space:
mode:
Diffstat (limited to 'sphinx/ext/autodoc/importer.py')
-rw-r--r--sphinx/ext/autodoc/importer.py103
1 files changed, 77 insertions, 26 deletions
diff --git a/sphinx/ext/autodoc/importer.py b/sphinx/ext/autodoc/importer.py
index a2280e82b..be971adbb 100644
--- a/sphinx/ext/autodoc/importer.py
+++ b/sphinx/ext/autodoc/importer.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""
sphinx.ext.autodoc.importer
~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -10,25 +9,27 @@
"""
import contextlib
+import os
import sys
import traceback
import warnings
from collections import namedtuple
+from importlib.abc import Loader, MetaPathFinder
+from importlib.machinery import ModuleSpec
from types import FunctionType, MethodType, ModuleType
-from six import PY2, iteritems
-
+from sphinx.deprecation import RemovedInSphinx30Warning
from sphinx.util import logging
from sphinx.util.inspect import isenumclass, safe_getattr
if False:
# For type annotation
- from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple # NOQA
+ from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Union # NOQA
logger = logging.getLogger(__name__)
-class _MockObject(object):
+class _MockObject:
"""Used by autodoc_mock_imports."""
def __new__(cls, *args, **kwargs):
@@ -77,15 +78,18 @@ class _MockObject(object):
class _MockModule(ModuleType):
"""Used by autodoc_mock_imports."""
- __file__ = '/dev/null'
+ __file__ = os.devnull
- def __init__(self, name, loader):
+ def __init__(self, name, loader=None):
# type: (str, _MockImporter) -> None
- self.__name__ = self.__package__ = name
- self.__loader__ = loader
+ super().__init__(name)
self.__all__ = [] # type: List[str]
self.__path__ = [] # type: List[str]
+ if loader is not None:
+ warnings.warn('The loader argument for _MockModule is deprecated.',
+ RemovedInSphinx30Warning)
+
def __getattr__(self, name):
# type: (str) -> _MockObject
o = _MockObject()
@@ -93,7 +97,7 @@ class _MockModule(ModuleType):
return o
-class _MockImporter(object):
+class _MockImporter(MetaPathFinder):
def __init__(self, names):
# type: (List[str]) -> None
self.names = names
@@ -101,6 +105,9 @@ class _MockImporter(object):
# enable hook by adding itself to meta_path
sys.meta_path.insert(0, self)
+ warnings.warn('_MockImporter is now deprecated.',
+ RemovedInSphinx30Warning)
+
def disable(self):
# type: () -> None
# remove `self` from `sys.meta_path` to disable import hook
@@ -112,7 +119,7 @@ class _MockImporter(object):
del sys.modules[m]
def find_module(self, name, path=None):
- # type: (str, str) -> Any
+ # type: (str, Sequence[Union[bytes, str]]) -> Any
# check if name is (or is a descendant of) one of our base_packages
for n in self.names:
if n == name or name.startswith(n + '.'):
@@ -132,14 +139,66 @@ class _MockImporter(object):
return module
+class MockLoader(Loader):
+ """A loader for mocking."""
+ def __init__(self, finder):
+ # type: (MockFinder) -> None
+ super().__init__()
+ self.finder = finder
+
+ def create_module(self, spec):
+ # type: (ModuleSpec) -> ModuleType
+ logger.debug('[autodoc] adding a mock module as %s!', spec.name)
+ self.finder.mocked_modules.append(spec.name)
+ return _MockModule(spec.name)
+
+ def exec_module(self, module):
+ # type: (ModuleType) -> None
+ pass # nothing to do
+
+
+class MockFinder(MetaPathFinder):
+ """A finder for mocking."""
+
+ def __init__(self, modnames):
+ # type: (List[str]) -> None
+ super().__init__()
+ self.modnames = modnames
+ self.loader = MockLoader(self)
+ self.mocked_modules = [] # type: List[str]
+
+ def find_spec(self, fullname, path, target=None):
+ # type: (str, Sequence[Union[bytes, str]], ModuleType) -> ModuleSpec
+ for modname in self.modnames:
+ # check if fullname is (or is a descendant of) one of our targets
+ if modname == fullname or fullname.startswith(modname + '.'):
+ return ModuleSpec(fullname, self.loader)
+
+ return None
+
+ def invalidate_caches(self):
+ # type: () -> None
+ """Invalidate mocked modules on sys.modules."""
+ for modname in self.mocked_modules:
+ sys.modules.pop(modname, None)
+
+
@contextlib.contextmanager
-def mock(names):
- # type: (List[str]) -> Generator
+def mock(modnames):
+ # type: (List[str]) -> Generator[None, None, None]
+ """Insert mock modules during context::
+
+ with mock(['target.module.name']):
+ # mock modules are enabled here
+ ...
+ """
try:
- importer = _MockImporter(names)
+ finder = MockFinder(modnames)
+ sys.meta_path.insert(0, finder)
yield
finally:
- importer.disable()
+ sys.meta_path.remove(finder)
+ finder.invalidate_caches()
def import_module(modname, warningiserror=False):
@@ -160,7 +219,7 @@ def import_module(modname, warningiserror=False):
def import_object(modname, objpath, objtype='', attrgetter=safe_getattr, warningiserror=False):
- # type: (str, List[unicode], str, Callable[[Any, unicode], Any], bool) -> Any
+ # type: (str, List[str], str, Callable[[Any, str], Any], bool) -> Any
if objpath:
logger.debug('[autodoc] from %s import %s', modname, '.'.join(objpath))
else:
@@ -219,8 +278,6 @@ def import_object(modname, objpath, objtype='', attrgetter=safe_getattr, warning
else:
errmsg += '; the following exception was raised:\n%s' % traceback.format_exc()
- if PY2:
- errmsg = errmsg.decode('utf-8') # type: ignore
logger.debug(errmsg)
raise ImportError(errmsg)
@@ -229,17 +286,11 @@ Attribute = namedtuple('Attribute', ['name', 'directly_defined', 'value'])
def get_object_members(subject, objpath, attrgetter, analyzer=None):
- # type: (Any, List[unicode], Callable, Any) -> Dict[str, Attribute] # NOQA
+ # type: (Any, List[str], Callable, Any) -> Dict[str, Attribute] # NOQA
"""Get members and attributes of target object."""
# the members directly defined in the class
obj_dict = attrgetter(subject, '__dict__', {})
- # Py34 doesn't have enum members in __dict__.
- if sys.version_info[:2] == (3, 4) and isenumclass(subject):
- obj_dict = dict(obj_dict)
- for name, value in subject.__members__.items():
- obj_dict[name] = value
-
members = {} # type: Dict[str, Attribute]
# enum members
@@ -249,7 +300,7 @@ def get_object_members(subject, objpath, attrgetter, analyzer=None):
members[name] = Attribute(name, True, value)
superclass = subject.__mro__[1]
- for name, value in iteritems(obj_dict):
+ for name, value in obj_dict.items():
if name not in superclass.__dict__:
members[name] = Attribute(name, True, value)