summaryrefslogtreecommitdiff
path: root/setuptools/tests/contexts.py
blob: 112cdf4b288fe102c2d4432ce1f6cf2b34468fd8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import tempfile
import os
import shutil
import sys
import contextlib
import site
import io

from filelock import FileLock


@contextlib.contextmanager
def tempdir(cd=lambda dir: None, **kwargs):
    """
    Create a temporary directory for the duration of the context.

    ``cd`` is called with the new directory on entry and with the directory
    that was current at entry on exit; the default does nothing. Pass
    ``os.chdir`` to actually switch into the directory. Remaining keyword
    arguments are forwarded to ``tempfile.mkdtemp``.

    Yields the path of the temporary directory, which is removed on exit.
    """
    scratch = tempfile.mkdtemp(**kwargs)
    starting_cwd = os.getcwd()
    try:
        cd(scratch)
        yield scratch
    finally:
        # Step back out before deleting the tree.
        cd(starting_cwd)
        shutil.rmtree(scratch)


@contextlib.contextmanager
def environment(**replacements):
    """
    In a context, patch the environment with replacements. Pass None values
    to clear the values.

    Yields a dict mapping each named variable that was set on entry to its
    previous value. On exit, every variable that was assigned is removed and
    the previous values are put back. The restore also runs when applying
    the patch itself fails part-way (for example because a value is not a
    string), so a failed patch leaves no changes behind.
    """
    # Remember the current value of every variable about to be touched.
    saved = {key: os.environ[key] for key in replacements if key in os.environ}

    # Split the request: None means "unset", anything else is assigned.
    to_clear = [key for key, value in replacements.items() if value is None]
    to_set = {
        key: value for key, value in replacements.items() if value is not None
    }

    try:
        # Mutate inside the try block so that a failure half-way through
        # still triggers the restore below.
        for key in to_clear:
            os.environ.pop(key, None)
        os.environ.update(to_set)
        yield saved
    finally:
        for key in to_set:
            os.environ.pop(key, None)
        os.environ.update(saved)


@contextlib.contextmanager
def quiet():
    """
    Swap ``sys.stdout`` and ``sys.stderr`` for in-memory ``io.StringIO``
    buffers so that output (e.g. from distutils commands) does not reach
    the console.

    Yields the ``(stdout, stderr)`` buffer pair. On exit both buffers are
    rewound to position 0 and the previous streams are reinstated.
    """
    saved_streams = sys.stdout, sys.stderr
    captured_out, captured_err = io.StringIO(), io.StringIO()
    sys.stdout, sys.stderr = captured_out, captured_err
    try:
        yield captured_out, captured_err
    finally:
        # Rewind so a following read() returns everything captured.
        for buffer in (captured_out, captured_err):
            buffer.seek(0)
        sys.stdout, sys.stderr = saved_streams


@contextlib.contextmanager
def save_user_site_setting():
    """
    Record ``site.ENABLE_USER_SITE`` on entry and write the recorded value
    back on exit.

    Yields the recorded value.
    """
    original_setting = site.ENABLE_USER_SITE
    try:
        yield original_setting
    finally:
        site.ENABLE_USER_SITE = original_setting


@contextlib.contextmanager
def save_pkg_resources_state():
    """
    Snapshot the ``pkg_resources`` module state together with ``sys.path``
    and put both back on exit.

    Yields the pair ``(pkg_resources state, copy of sys.path)``.
    """
    # Imported lazily, only when the context manager is actually used.
    import pkg_resources

    module_state = pkg_resources.__getstate__()
    # sys.path is captured as well
    path_snapshot = list(sys.path)
    try:
        yield module_state, path_snapshot
    finally:
        # In-place assignment keeps the same sys.path list object.
        sys.path[:] = path_snapshot
        pkg_resources.__setstate__(module_state)


@contextlib.contextmanager
def suppress_exceptions(*excs):
    """
    Swallow any exception raised in the context that matches one of
    ``excs``; anything else propagates unchanged.
    """
    try:
        yield
    except excs:
        # Matching exception: end the generator normally so it is suppressed.
        return


def multiproc(request):
    """
    Tell whether the run is distributed over multiple xdist workers.

    Looks up the ``worker_id`` fixture through ``request``. Returns False
    when the lookup raises or when the id is ``'master'``, and True for
    any other id.
    """
    try:
        current_worker = request.getfixturevalue('worker_id')
    except Exception:
        # Lookup failed; treated as "not running under multiple workers".
        is_multi = False
    else:
        is_multi = current_worker != 'master'
    return is_multi


@contextlib.contextmanager
def session_locked_tmp_dir(request, tmp_path_factory, name):
    """Uses a file lock to guarantee only one worker can access a temp dir

    The directory is ``name`` under the base temp directory, or under the
    parent of the base temp directory when ``multiproc(request)`` is true.
    It is created if missing and yielded while the lock is held.
    """
    # locate the temp directory shared by all workers
    base_tmp = tmp_path_factory.getbasetemp()
    if multiproc(request):
        shared_root = base_tmp.parent
    else:
        shared_root = base_tmp

    target = shared_root / name
    lock_path = target.with_suffix(".lock")
    # The lock stays held for the whole body of the caller's ``with`` block,
    # so workers cannot use the directory at the same time.
    with FileLock(lock_path):
        target.mkdir(parents=True, exist_ok=True)
        yield target


@contextlib.contextmanager
def save_paths():
    """Make sure ``sys.path``, ``sys.meta_path`` and ``sys.path_hooks`` are preserved

    On exit the three attributes refer to the same list objects as on entry,
    with the same contents as on entry. This holds whether the code in the
    context mutated the lists in place or rebound the attributes.
    """
    # Keep the list objects themselves plus a copy of their contents.
    originals = sys.path, sys.meta_path, sys.path_hooks
    snapshots = [entries[:] for entries in originals]

    try:
        yield
    finally:
        # Restore contents in place so existing references to the lists
        # stay valid, then undo any rebinding of the attributes.
        for entries, snapshot in zip(originals, snapshots):
            entries[:] = snapshot
        sys.path, sys.meta_path, sys.path_hooks = originals


@contextlib.contextmanager
def save_sys_modules():
    """Make sure initial ``sys.modules`` is preserved

    While the context is active ``sys.modules`` is bound to a shallow copy
    of the mapping; on exit the original mapping object is bound again.
    """
    initial_modules = sys.modules
    # Work on a copy so the original mapping is not modified through
    # ``sys.modules`` inside the context.
    sys.modules = initial_modules.copy()
    try:
        yield
    finally:
        sys.modules = initial_modules