summaryrefslogtreecommitdiff
path: root/coverage/misc.py
diff options
context:
space:
mode:
Diffstat (limited to 'coverage/misc.py')
-rw-r--r--coverage/misc.py38
1 files changed, 28 insertions, 10 deletions
diff --git a/coverage/misc.py b/coverage/misc.py
index 9c414d88..29397537 100644
--- a/coverage/misc.py
+++ b/coverage/misc.py
@@ -3,6 +3,7 @@
"""Miscellaneous stuff for coverage.py."""
+import contextlib
import errno
import hashlib
import importlib
@@ -49,6 +50,28 @@ def isolate_module(mod):
os = isolate_module(os)
+class SysModuleSaver:
+ """Saves the contents of sys.modules, and removes new modules later."""
+ def __init__(self):
+ self.old_modules = set(sys.modules)
+
+ def restore(self):
+ """Remove any modules imported since this object started."""
+ new_modules = set(sys.modules) - self.old_modules
+ for m in new_modules:
+ del sys.modules[m]
+
+
+@contextlib.contextmanager
+def sys_modules_saved():
+ """A context manager to remove any modules imported during a block."""
+ saver = SysModuleSaver()
+ try:
+ yield
+ finally:
+ saver.restore()
+
+
def import_third_party(modname):
"""Import a third-party module we need, but might not be installed.
@@ -63,16 +86,11 @@ def import_third_party(modname):
The imported module, or None if the module couldn't be imported.
"""
- try:
- mod = importlib.import_module(modname)
- except ImportError:
- mod = None
-
- imported = [m for m in sys.modules if m.startswith(modname)]
- for name in imported:
- del sys.modules[name]
-
- return mod
+ with sys_modules_saved():
+ try:
+ return importlib.import_module(modname)
+ except ImportError:
+ return None
def dummy_decorator_with_args(*args_unused, **kwargs_unused):