summaryrefslogtreecommitdiff
path: root/sphinx/util/requests.py
blob: c64754fa2a11d81c56b2e1e5fed9ccc7cd5a7324 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Simple requests package loader"""

from __future__ import annotations

import sys
import warnings
from contextlib import contextmanager
from typing import Any, Generator
from urllib.parse import urlsplit

import requests
from urllib3.exceptions import InsecureRequestWarning

import sphinx
from sphinx.config import Config

# Default ``User-Agent`` header as a list of (name, value) pairs; get() and
# head() fall back to its value when they are called without a ``config``.
useragent_header = [
    (
        'User-Agent',
        'Mozilla/5.0 (X11; Linux x86_64; rv:25.0) Gecko/20100101 Firefox/25.0',
    ),
]


@contextmanager
def ignore_insecure_warning(**kwargs: Any) -> Generator[None, None, None]:
    """Context manager that silences urllib3's ``InsecureRequestWarning``.

    The accepted keyword arguments are the ones about to be handed to
    ``requests``; only ``verify`` is inspected.  When it is absent or falsy,
    an "ignore" filter for ``InsecureRequestWarning`` is installed for the
    duration of the ``with`` body.  ``warnings.catch_warnings`` restores the
    previous filter state on exit either way.
    """
    verification_requested = kwargs.get('verify')
    with warnings.catch_warnings():
        if not verification_requested:
            # verify is off, so the per-request insecure warning is expected noise
            warnings.filterwarnings("ignore", category=InsecureRequestWarning)
        yield


def _get_tls_cacert(url: str, config: Config) -> str | bool:
    """Get additional CA cert for a specific URL.

    This also returns ``False`` if verification is disabled.
    And returns ``True`` if additional CA cert not found.
    """
    if not config.tls_verify:
        return False

    certs = getattr(config, 'tls_cacerts', None)
    if not certs:
        return True
    elif isinstance(certs, (str, tuple)):
        return certs  # type: ignore
    else:
        hostname = urlsplit(url)[1]
        if '@' in hostname:
            hostname = hostname.split('@')[1]

        return certs.get(hostname, True)


def _get_user_agent(config: Config) -> str:
    if config.user_agent:
        return config.user_agent
    else:
        return ' '.join([
            'Sphinx/%s' % sphinx.__version__,
            'requests/%s' % requests.__version__,
            'python/%s' % '.'.join(map(str, sys.version_info[:3])),
        ])


def _prepare_request_kwargs(url: str, kwargs: dict[str, Any]) -> None:
    """Fill in Sphinx's default User-Agent and TLS settings in *kwargs*.

    *kwargs* is the keyword dictionary about to be passed to ``requests``
    and is updated in place:

    * a missing ``headers`` entry, or ``headers=None``, becomes a new empty
      dict.  Note that a caller-supplied headers mapping is itself updated.
    * a ``config`` entry, if present, is removed.  When it is truthy it
      supplies the default ``verify`` value (via ``_get_tls_cacert``) and the
      default ``User-Agent`` (via ``_get_user_agent``).  Otherwise the
      User-Agent defaults to the value in ``useragent_header``.

    Values the caller set explicitly for ``verify`` or ``User-Agent`` are
    never overridden.
    """
    headers = kwargs.get('headers')
    if headers is None:
        # requests accepts headers=None; treat it like "no extra headers"
        # instead of crashing on None.setdefault below.
        headers = kwargs['headers'] = {}
    config = kwargs.pop('config', None)
    if config:
        kwargs.setdefault('verify', _get_tls_cacert(url, config))
        headers.setdefault('User-Agent', _get_user_agent(config))
    else:
        headers.setdefault('User-Agent', useragent_header[0][1])


def get(url: str, **kwargs: Any) -> requests.Response:
    """Sends a GET request like requests.get().

    This sets up User-Agent header and TLS verification automatically.
    An optional ``config`` keyword (a Sphinx ``Config``) is consumed here
    and is not forwarded to ``requests``.
    """
    _prepare_request_kwargs(url, kwargs)
    with ignore_insecure_warning(**kwargs):
        return requests.get(url, **kwargs)


def head(url: str, **kwargs: Any) -> requests.Response:
    """Sends a HEAD request like requests.head().

    This sets up User-Agent header and TLS verification automatically.
    An optional ``config`` keyword (a Sphinx ``Config``) is consumed here
    and is not forwarded to ``requests``.
    """
    _prepare_request_kwargs(url, kwargs)
    with ignore_insecure_warning(**kwargs):
        return requests.head(url, **kwargs)