summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrian Wellington <bwelling@xbill.org>2020-01-07 13:03:13 -0800
committerBrian Wellington <bwelling@xbill.org>2020-01-07 13:03:13 -0800
commit7ec39e21ab0a6761a34ec405a4a59dc4ebe54924 (patch)
tree61a5e0db5d3e733e95c87d35ed2c04a73e99f3ba
parent0ed99480529ecf3217738fe671a31dddd3360e48 (diff)
downloaddnspython-7ec39e21ab0a6761a34ec405a4a59dc4ebe54924.tar.gz
DoH cleanup.
-rw-r--r--dns/query.py104
-rw-r--r--dns/query.pyi5
-rw-r--r--dns/resolver.py21
-rw-r--r--examples/doh.py4
-rw-r--r--tests/test_doh.py16
5 files changed, 97 insertions, 53 deletions
diff --git a/dns/query.py b/dns/query.py
index c36248e..5876623 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -189,14 +189,16 @@ def _addresses_equal(af, a1, a2):
return n1 == n2 and a1[1:] == a2[1:]
-def _destination_and_source(af, where, port, source, source_port):
+def _destination_and_source(af, where, port, source, source_port,
+ default_to_inet=True):
# Apply defaults and compute destination and source tuples
# suitable for use in connect(), sendto(), or bind().
if af is None:
try:
af = dns.inet.af_for_address(where)
except Exception:
- af = dns.inet.AF_INET
+ if default_to_inet:
+ af = dns.inet.AF_INET
if af == dns.inet.AF_INET:
destination = (where, port)
if source is not None or source_port != 0:
@@ -209,6 +211,9 @@ def _destination_and_source(af, where, port, source, source_port):
if source is None:
source = '::'
source = (source, source_port, 0, 0)
+ else:
+ source = None
+ destination = None
return (af, destination, source)
def send_https(session, what, lifetime=None):
@@ -225,9 +230,10 @@ def send_https(session, what, lifetime=None):
what = what.prepare()
return session.send(what, timeout=lifetime)
-def https(q, where, session, timeout=None, port=443, path='/dns-query', post=True,
- bootstrap_address=None, verify=True, source=None, source_port=0,
- one_rr_per_rrset=False, ignore_trailing=False):
+def https(q, where, timeout=None, port=443, af=None, source=None, source_port=0,
+ one_rr_per_rrset=False, ignore_trailing=False,
+ session=None, path='/dns-query', post=True,
+ bootstrap_address=None, verify=True):
"""Return the response obtained after sending a query via DNS-over-HTTPS.
*q*, a ``dns.message.Message``, the query to send.
@@ -236,21 +242,15 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru
address is given, the URL will be constructed using the following schema:
https://<IP-address>:<port>/<path>.
- *session*, a ``requests.session.Session``, the session to use to send the
- queries. This argument is required to allow for connection reuse.
-
*timeout*, a ``float`` or ``None``, the number of seconds to
wait before the query times out. If ``None``, the default, wait forever.
- *port*, a ``int``, the port to send the query to. Default is 443.
-
- *path*, a ``str``. If *where* is an IP address, then *path* will be used to
- construct the URL to send the DNS query to.
-
- *post*, a ``bool``. If ``True``, the default, POST method will be used.
+ *port*, a ``int``, the port to send the query to. The default is 443.
- *bootstrap_address*, a ``str``, the IP address to use to bypass the system's
- DNS resolver.
+ *af*, an ``int``, the address family to use. The default is ``None``,
+ which causes the address family to use to be inferred from the form of
+ *where*, or uses the system default. Setting this to AF_INET or
+ AF_INET6 currently has no effect.
*source*, a ``str`` containing an IPv4 or IPv6 address, specifying
the source address. The default is the wildcard address.
@@ -264,13 +264,27 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
+ *session*, a ``requests.session.Session``. If provided, the session to use
+ to send the queries.
+
+ *path*, a ``str``. If *where* is an IP address, then *path* will be used to
+ construct the URL to send the DNS query to.
+
+ *post*, a ``bool``. If ``True``, the default, POST method will be used.
+
+ *bootstrap_address*, a ``str``, the IP address to use to bypass the
+ system's DNS resolver.
+
+ *verify*, a ``str`, containing a path to a certificate file or directory.
+
Returns a ``dns.message.Message``.
"""
wire = q.to_wire()
- af = None
(af, destination, source) = _destination_and_source(af, where, port,
- source, source_port)
+ source, source_port,
+ False)
+ transport_adapter = None
headers = {
"accept": "application/dns-message"
}
@@ -282,31 +296,49 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru
split_url = urllib.parse.urlsplit(where)
headers['Host'] = split_url.hostname
url = where.replace(split_url.hostname, bootstrap_address)
- session.mount(url, HostHeaderSSLAdapter())
+ transport_adapter = HostHeaderSSLAdapter()
else:
url = where
if source is not None:
# set source port and source address
- session.mount(url, SourceAddressAdapter(source))
-
- # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples
- if post:
- headers.update({
- "content-type": "application/dns-message",
- "content-length": str(len(wire))
- })
- response = session.post(url, headers=headers, data=wire, stream=True,
- timeout=timeout, verify=verify)
+ transport_adapter = SourceAddressAdapter(source)
+
+ if session:
+ close_session = False
else:
- wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
- url += "?dns={}".format(wire)
- response = session.get(url, headers=headers, stream=True,
- timeout=timeout, verify=verify)
+ session = requests.sessions.Session()
+ close_session = True
+
+ try:
+ if transport_adapter:
+ session.mount(url, transport_adapter)
+
+ # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
+ # GET and POST examples
+ if post:
+ headers.update({
+ "content-type": "application/dns-message",
+ "content-length": str(len(wire))
+ })
+ response = session.post(url, headers=headers, data=wire,
+ stream=True, timeout=timeout,
+ verify=verify)
+ else:
+ wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
+ url += "?dns={}".format(wire)
+ response = session.get(url, headers=headers, stream=True,
+ timeout=timeout, verify=verify)
+ finally:
+ if close_session:
+ session.close()
- # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH status codes
+ # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
+ # status codes
if response.status_code < 200 or response.status_code > 299:
- raise ValueError('{} responded with status code {}\nResponse body: {}'.format(
- where, response.status_code, response.content))
+ raise ValueError('{} responded with status code {}'
+ '\nResponse body: {}'.format(where,
+ response.status_code,
+ response.content))
r = dns.message.from_wire(response.content,
keyring=q.keyring,
request_mac=q.request_mac,
diff --git a/dns/query.pyi b/dns/query.pyi
index e943ba5..9346123 100644
--- a/dns/query.pyi
+++ b/dns/query.pyi
@@ -8,9 +8,8 @@ except ImportError:
class ssl(object):
SSLContext = {}
-def https(q : message.Message, where: str, session: Session, timeout : Optional[float] = None, port : Optional[int] = 443, path : Optional[str] = '/dns-query', post : Optional[bool] = True,
- bootstrap_address : Optional[str] = None, verify : Optional[bool] = True, source : Optional[str] = None, source_port : Optional[int] = 0,
- one_rr_per_rrset : Optional[bool] = False, ignore_trailing : Optional[bool] = False) -> message.Message:
+def https(q : message.Message, where: str, timeout : Optional[float] = None, port : Optional[int] = 443, af : Optional[int] = None, source : Optional[str] = None, source_port : Optional[int] = 0,
+ session: Optional[Session], path : Optional[str] = '/dns-query', post : Optional[bool] = True, bootstrap_address : Optional[str] = None, verify : Optional[bool] = True) -> message.Message:
pass
def tcp(q : message.Message, where : str, timeout : float = None, port=53, af : Optional[int] = None, source : Optional[str] = None, source_port : Optional[int] = 0,
diff --git a/dns/resolver.py b/dns/resolver.py
index 735de9f..3f5e451 100644
--- a/dns/resolver.py
+++ b/dns/resolver.py
@@ -909,30 +909,37 @@ class Resolver(object):
try:
if protocol == 'https':
tcp_attempt = True
- response = dns.query.https(request, nameserver, timeout)
+ response = dns.query.https(request, nameserver,
+ timeout=timeout)
elif protocol:
continue
else:
tcp_attempt = tcp
if tcp:
response = dns.query.tcp(request, nameserver,
- timeout, port,
+ timeout=timeout,
+ port=port,
source=source,
- source_port=source_port)
+ source_port=\
+ source_port)
else:
try:
- response = dns.query.udp(request, nameserver,
- timeout, port,
+ response = dns.query.udp(request,
+ nameserver,
+ timeout=timeout,
+ port=port,
source=source,
source_port=\
source_port)
except dns.message.Truncated:
# Response truncated; retry with TCP.
tcp_attempt = True
- timeout = self._compute_timeout(start, lifetime)
+ timeout = self._compute_timeout(start,
+ lifetime)
response = \
dns.query.tcp(request, nameserver,
- timeout, port,
+ timeout=timeout,
+ port=port,
source=source,
source_port=source_port)
except (socket.error, dns.exception.Timeout) as ex:
diff --git a/examples/doh.py b/examples/doh.py
index 01c562f..eff9ae7 100644
--- a/examples/doh.py
+++ b/examples/doh.py
@@ -18,7 +18,7 @@ def main():
# one method is to use context manager, session will automatically close
with requests.sessions.Session() as session:
q = dns.message.make_query(qname, dns.rdatatype.A)
- r = dns.query.https(q, where, session)
+ r = dns.query.https(q, where, session=session)
for answer in r.answer:
print(answer)
@@ -29,7 +29,7 @@ def main():
# second method, close session manually
session = requests.sessions.Session()
q = dns.message.make_query(qname, dns.rdatatype.A)
- r = dns.query.https(q, where, session)
+ r = dns.query.https(q, where, session=session)
for answer in r.answer:
print(answer)
diff --git a/tests/test_doh.py b/tests/test_doh.py
index 3819b1a..acda5af 100644
--- a/tests/test_doh.py
+++ b/tests/test_doh.py
@@ -40,13 +40,13 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
def test_get_request(self):
nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
q = dns.message.make_query('example.com.', dns.rdatatype.A)
- r = dns.query.https(q, nameserver_url, self.session, post=False)
+ r = dns.query.https(q, nameserver_url, session=self.session, post=False)
self.assertTrue(q.is_response(r))
def test_post_request(self):
nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
q = dns.message.make_query('example.com.', dns.rdatatype.A)
- r = dns.query.https(q, nameserver_url, self.session, post=True)
+ r = dns.query.https(q, nameserver_url, session=self.session, post=True)
self.assertTrue(q.is_response(r))
def test_build_url_from_ip(self):
@@ -54,7 +54,7 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
q = dns.message.make_query('example.com.', dns.rdatatype.A)
# For some reason Google's DNS over HTTPS fails when you POST to https://8.8.8.8/dns-query
# So we're just going to do GET requests here
- r = dns.query.https(q, nameserver_ip, self.session, post=False)
+ r = dns.query.https(q, nameserver_ip, session=self.session, post=False)
self.assertTrue(q.is_response(r))
def test_bootstrap_address(self):
@@ -64,9 +64,9 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
q = dns.message.make_query('example.com.', dns.rdatatype.A)
# make sure CleanBrowsing's IP address will fail TLS certificate check
with self.assertRaises(SSLError):
- dns.query.https(q, invalid_tls_url, self.session)
+ dns.query.https(q, invalid_tls_url, session=self.session)
# use host header
- r = dns.query.https(q, valid_tls_url, self.session, bootstrap_address=ip)
+ r = dns.query.https(q, valid_tls_url, session=self.session, bootstrap_address=ip)
self.assertTrue(q.is_response(r))
def test_send_https(self):
@@ -79,5 +79,11 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
dns_resp = dns.message.from_wire(response.content)
self.assertTrue(q.is_response(dns_resp))
+ def test_new_session(self):
+ nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
+ q = dns.message.make_query('example.com.', dns.rdatatype.A)
+ r = dns.query.https(q, nameserver_url)
+ self.assertTrue(q.is_response(r))
+
if __name__ == '__main__':
unittest.main()