summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2023-02-25 11:43:26 -0800
committerGitHub <noreply@github.com>2023-02-25 11:43:26 -0800
commitf7daeb87eac0a2727d5366cdff02fe08843678dd (patch)
treec38eab2e322ec02cde716a1aa9b70271a4594174
parentc76de3e1f2de416694353aa158545689f72cedca (diff)
downloaddnspython-f7daeb87eac0a2727d5366cdff02fe08843678dd.tar.gz
Resolver "nameserver" object support. (#897)
* Resolver "nameserver" object support. This turns the list of nameserver strings in the resolver into a tuple of nameserver objects, which abstract away making queries to a nameserver of a given type. The resolver's legacy nameserver list is "enriched" into a tuple of nameserver objects whenever it is set. Note that you cannot mutate the object other than by setting, e.g. res.nameservers.append("1.2.3.4") will not work. Error message accumulation has been updated to refer to the nameservers using a descriptive text form. * doco fix * more doco fixes * do enrichment at Resolution time * require a later mypy, fix type issues * add nameserver doc
-rw-r--r--dns/asyncresolver.py36
-rw-r--r--dns/nameserver.py315
-rw-r--r--dns/resolver.py181
-rw-r--r--doc/resolver-class.rst9
-rw-r--r--doc/resolver-nameserver.rst46
-rw-r--r--doc/resolver.rst1
-rw-r--r--doc/whatsnew.rst6
-rw-r--r--pyproject.toml2
-rw-r--r--tests/test_resolution.py91
-rw-r--r--tests/test_resolver.py28
10 files changed, 566 insertions, 149 deletions
diff --git a/dns/asyncresolver.py b/dns/asyncresolver.py
index 506530e..9ba84de 100644
--- a/dns/asyncresolver.py
+++ b/dns/asyncresolver.py
@@ -83,37 +83,19 @@ class Resolver(dns.resolver.BaseResolver):
assert request is not None # needed for type checking
done = False
while not done:
- (nameserver, port, tcp, backoff) = resolution.next_nameserver()
+ (nameserver, tcp, backoff) = resolution.next_nameserver()
if backoff:
await backend.sleep(backoff)
timeout = self._compute_timeout(start, lifetime, resolution.errors)
try:
- if dns.inet.is_address(nameserver):
- if tcp:
- response = await _tcp(
- request,
- nameserver,
- timeout,
- port,
- source,
- source_port,
- backend=backend,
- )
- else:
- response = await _udp(
- request,
- nameserver,
- timeout,
- port,
- source,
- source_port,
- raise_on_truncation=True,
- backend=backend,
- )
- else:
- response = await dns.asyncquery.https(
- request, nameserver, timeout=timeout
- )
+ response = await nameserver.async_query(
+ request,
+ timeout=timeout,
+ source=source,
+ source_port=source_port,
+ max_size=tcp,
+ backend=backend,
+ )
except Exception as ex:
(_, done) = resolution.query_result(None, ex)
continue
diff --git a/dns/nameserver.py b/dns/nameserver.py
new file mode 100644
index 0000000..7de0abb
--- /dev/null
+++ b/dns/nameserver.py
@@ -0,0 +1,315 @@
+from urllib.parse import urlparse
+
+from typing import Optional, Union
+
+import dns.asyncbackend
+import dns.asyncquery
+import dns.inet
+import dns.message
+import dns.query
+
+
+class Nameserver:
+ def __init__(self):
+ pass
+
+ def __str__(self):
+ raise NotImplementedError
+
+ def is_always_max_size(self) -> bool:
+ raise NotImplementedError
+
+ def answer_nameserver(self) -> str:
+ raise NotImplementedError
+
+ def answer_port(self) -> int:
+ raise NotImplementedError
+
+ def query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ raise NotImplementedError
+
+ async def async_query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ backend: dns.asyncbackend.Backend,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ raise NotImplementedError
+
+
+class AddressAndPortNameserver(Nameserver):
+ def __init__(self, address: str, port: int):
+ super().__init__()
+ self.address = address
+ self.port = port
+
+ def kind(self) -> str:
+ raise NotImplementedError
+
+ def is_always_max_size(self) -> bool:
+ return False
+
+ def __str__(self):
+ ns_kind = self.kind()
+ return f"{ns_kind}:{self.address}@{self.port}"
+
+ def answer_nameserver(self) -> str:
+ return self.address
+
+ def answer_port(self) -> int:
+ return self.port
+
+
+class Do53Nameserver(AddressAndPortNameserver):
+ def __init__(self, address: str, port: int = 53):
+ super().__init__(address, port)
+
+ def kind(self):
+ return "Do53"
+
+ def query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ if max_size:
+ response = dns.query.tcp(
+ request,
+ self.address,
+ timeout=timeout,
+ port=self.port,
+ source=source,
+ source_port=source_port,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ else:
+ response = dns.query.udp(
+ request,
+ self.address,
+ timeout=timeout,
+ port=self.port,
+ source=source,
+ source_port=source_port,
+ raise_on_truncation=True,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ return response
+
+ async def async_query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ backend: dns.asyncbackend.Backend,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ if max_size:
+ response = await dns.asyncquery.tcp(
+ request,
+ self.address,
+ timeout=timeout,
+ port=self.port,
+ source=source,
+ source_port=source_port,
+ backend=backend,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ else:
+ response = await dns.asyncquery.udp(
+ request,
+ self.address,
+ timeout=timeout,
+ port=self.port,
+ source=source,
+ source_port=source_port,
+ raise_on_truncation=True,
+ backend=backend,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ return response
+
+
+class DoHNameserver(Nameserver):
+ def __init__(self, url: str, bootstrap_address: Optional[str] = None):
+ super().__init__()
+ self.url = url
+ self.bootstrap_address = bootstrap_address
+
+ def is_always_max_size(self) -> bool:
+ return True
+
+ def __str__(self):
+ return self.url
+
+ def answer_nameserver(self) -> str:
+ return self.url
+
+ def answer_port(self) -> int:
+ port = urlparse(self.url).port
+ if port is None:
+ port = 443
+ return port
+
+ def query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool = False,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ return dns.query.https(
+ request,
+ self.url,
+ timeout=timeout,
+ bootstrap_address=self.bootstrap_address,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+
+ async def async_query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ backend: dns.asyncbackend.Backend,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ return await dns.asyncquery.https(
+ request,
+ self.url,
+ timeout=timeout,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+
+
+class DoTNameserver(AddressAndPortNameserver):
+ def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None):
+ super().__init__(address, port)
+ self.hostname = hostname
+
+ def kind(self):
+ return "DoT"
+
+ def query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool = False,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ return dns.query.tls(
+ request,
+ self.address,
+ port=self.port,
+ timeout=timeout,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ server_hostname=self.hostname,
+ )
+
+ async def async_query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ backend: dns.asyncbackend.Backend,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ return await dns.asyncquery.tls(
+ request,
+ self.address,
+ port=self.port,
+ timeout=timeout,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ server_hostname=self.hostname,
+ )
+
+
+class DoQNameserver(AddressAndPortNameserver):
+ def __init__(self, address: str, port: int = 853, verify: Union[bool, str] = True):
+ super().__init__(address, port)
+ self.verify = verify
+
+ def kind(self):
+ return "DoQ"
+
+ def query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool = False,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ return dns.query.quic(
+ request,
+ self.address,
+ port=self.port,
+ timeout=timeout,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ verify=self.verify,
+ )
+
+ async def async_query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ backend: dns.asyncbackend.Backend,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ return await dns.asyncquery.quic(
+ request,
+ self.address,
+ port=self.port,
+ timeout=timeout,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ verify=self.verify,
+ )
diff --git a/dns/resolver.py b/dns/resolver.py
index 5ba8601..4fc5bfd 100644
--- a/dns/resolver.py
+++ b/dns/resolver.py
@@ -36,6 +36,7 @@ import dns.ipv4
import dns.ipv6
import dns.message
import dns.name
+import dns.nameserver
import dns.query
import dns.rcode
import dns.rdataclass
@@ -140,7 +141,11 @@ class YXDOMAIN(dns.exception.DNSException):
ErrorTuple = Tuple[
- Optional[str], bool, int, Union[Exception, str], Optional[dns.message.Message]
+ Optional[str],
+ bool,
+ int,
+ Union[Exception, str],
+ Optional[dns.message.Message],
]
@@ -148,11 +153,7 @@ def _errors_to_text(errors: List[ErrorTuple]) -> List[str]:
"""Turn a resolution errors trace into a list of text."""
texts = []
for err in errors:
- texts.append(
- "Server {} {} port {} answered {}".format(
- err[0], "TCP" if err[1] else "UDP", err[2], err[3]
- )
- )
+ texts.append("Server {} answered {}".format(err[0], err[3]))
return texts
@@ -377,7 +378,7 @@ class Cache(CacheBase):
now = time.time()
if self.next_cleaning <= now:
keys_to_delete = []
- for (k, v) in self.data.items():
+ for k, v in self.data.items():
if v.expiration <= now:
keys_to_delete.append(k)
for k in keys_to_delete:
@@ -609,11 +610,10 @@ class _Resolution:
self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {}
# Initialize other things to help analysis tools
self.qname = dns.name.empty
- self.nameservers: List[str] = []
- self.current_nameservers: List[str] = []
+ self.nameservers: List[dns.nameserver.Nameserver] = []
+ self.current_nameservers: List[dns.nameserver.Nameserver] = []
self.errors: List[ErrorTuple] = []
- self.nameserver: Optional[str] = None
- self.port = 0
+ self.nameserver: Optional[dns.nameserver.Nameserver] = None
self.tcp_attempt = False
self.retry_with_tcp = False
self.request: Optional[dns.message.QueryMessage] = None
@@ -670,7 +670,9 @@ class _Resolution:
if self.resolver.flags is not None:
request.flags = self.resolver.flags
- self.nameservers = self.resolver.nameservers[:]
+ self.nameservers = self.resolver._enrich_nameservers(
+ self.resolver._nameservers
+ )
if self.resolver.rotate:
random.shuffle(self.nameservers)
self.current_nameservers = self.nameservers[:]
@@ -690,12 +692,13 @@ class _Resolution:
#
raise NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses)
- def next_nameserver(self) -> Tuple[str, int, bool, float]:
+ def next_nameserver(self) -> Tuple[dns.nameserver.Nameserver, bool, float]:
if self.retry_with_tcp:
assert self.nameserver is not None
+ assert not self.nameserver.is_always_max_size()
self.tcp_attempt = True
self.retry_with_tcp = False
- return (self.nameserver, self.port, True, 0)
+ return (self.nameserver, True, 0)
backoff = 0.0
if not self.current_nameservers:
@@ -707,11 +710,8 @@ class _Resolution:
self.backoff = min(self.backoff * 2, 2)
self.nameserver = self.current_nameservers.pop(0)
- self.port = self.resolver.nameserver_ports.get(
- self.nameserver, self.resolver.port
- )
- self.tcp_attempt = self.tcp
- return (self.nameserver, self.port, self.tcp_attempt, backoff)
+ self.tcp_attempt = self.tcp or self.nameserver.is_always_max_size()
+ return (self.nameserver, self.tcp_attempt, backoff)
def query_result(
self, response: Optional[dns.message.Message], ex: Optional[Exception]
@@ -724,7 +724,13 @@ class _Resolution:
# Exception during I/O or from_wire()
assert response is None
self.errors.append(
- (self.nameserver, self.tcp_attempt, self.port, ex, response)
+ (
+ str(self.nameserver),
+ self.tcp_attempt,
+ self.nameserver.answer_port(),
+ ex,
+ response,
+ )
)
if (
isinstance(ex, dns.exception.FormError)
@@ -752,12 +758,18 @@ class _Resolution:
self.rdtype,
self.rdclass,
response,
- self.nameserver,
- self.port,
+ self.nameserver.answer_nameserver(),
+ self.nameserver.answer_port(),
)
except Exception as e:
self.errors.append(
- (self.nameserver, self.tcp_attempt, self.port, e, response)
+ (
+ str(self.nameserver),
+ self.tcp_attempt,
+ self.nameserver.answer_port(),
+ e,
+ response,
+ )
)
# The nameserver is no good, take it out of the mix.
self.nameservers.remove(self.nameserver)
@@ -776,7 +788,13 @@ class _Resolution:
)
except Exception as e:
self.errors.append(
- (self.nameserver, self.tcp_attempt, self.port, e, response)
+ (
+ str(self.nameserver),
+ self.tcp_attempt,
+ self.nameserver.answer_port(),
+ e,
+ response,
+ )
)
# The nameserver is no good, take it out of the mix.
self.nameservers.remove(self.nameserver)
@@ -792,7 +810,13 @@ class _Resolution:
elif rcode == dns.rcode.YXDOMAIN:
yex = YXDOMAIN()
self.errors.append(
- (self.nameserver, self.tcp_attempt, self.port, yex, response)
+ (
+ str(self.nameserver),
+ self.tcp_attempt,
+ self.nameserver.answer_port(),
+ yex,
+ response,
+ )
)
raise yex
else:
@@ -804,9 +828,9 @@ class _Resolution:
self.nameservers.remove(self.nameserver)
self.errors.append(
(
- self.nameserver,
+ str(self.nameserver),
self.tcp_attempt,
- self.port,
+ self.nameserver.answer_port(),
dns.rcode.to_text(rcode),
response,
)
@@ -840,6 +864,7 @@ class BaseResolver:
retry_servfail: bool
rotate: bool
ndots: Optional[int]
+ _nameservers: List[Union[str, dns.nameserver.Nameserver]]
def __init__(
self, filename: str = "/etc/resolv.conf", configure: bool = True
@@ -868,7 +893,7 @@ class BaseResolver:
self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:])
if len(self.domain) == 0:
self.domain = dns.name.root
- self.nameservers = []
+ self._nameservers = []
self.nameserver_ports = {}
self.port = 53
self.search = []
@@ -905,6 +930,7 @@ class BaseResolver:
"""
+ nameservers = []
if isinstance(f, str):
try:
cm: contextlib.AbstractContextManager = open(f)
@@ -924,7 +950,7 @@ class BaseResolver:
continue
if tokens[0] == "nameserver":
- self.nameservers.append(tokens[1])
+ nameservers.append(tokens[1])
elif tokens[0] == "domain":
self.domain = dns.name.from_text(tokens[1])
# domain and search are exclusive
@@ -952,8 +978,11 @@ class BaseResolver:
self.ndots = int(opt.split(":")[1])
except (ValueError, IndexError):
pass
- if len(self.nameservers) == 0:
+ if len(nameservers) == 0:
raise NoResolverConfiguration("no nameservers")
+ # Assigning directly instead of appending means we invoke the
+ # setter logic, with additonal checking and enrichment.
+ self.nameservers = nameservers
def read_registry(self) -> None:
"""Extract resolver configuration from the Windows registry."""
@@ -1088,34 +1117,60 @@ class BaseResolver:
self.flags = flags
- @property
- def nameservers(self) -> List[str]:
- return self._nameservers
-
- @nameservers.setter
- def nameservers(self, nameservers: List[str]) -> None:
- """
- *nameservers*, a ``list`` of nameservers.
-
- Raises ``ValueError`` if *nameservers* is anything other than a
- ``list``.
- """
+ def _enrich_nameservers(
+ self, nameservers: List[Union[str, dns.nameserver.Nameserver]]
+ ) -> List[dns.nameserver.Nameserver]:
+ enriched_nameservers = []
if isinstance(nameservers, list):
for nameserver in nameservers:
- if not dns.inet.is_address(nameserver):
+ enriched_nameserver: dns.nameserver.Nameserver
+ if isinstance(nameserver, dns.nameserver.Nameserver):
+ enriched_nameserver = nameserver
+ elif dns.inet.is_address(nameserver):
+ port = self.nameserver_ports.get(nameserver, self.port)
+ enriched_nameserver = dns.nameserver.Do53Nameserver(
+ nameserver, port
+ )
+ else:
try:
if urlparse(nameserver).scheme != "https":
raise NotImplementedError
except Exception:
raise ValueError(
- f"nameserver {nameserver} is not an "
- "IP address or valid https URL"
+ f"nameserver {nameserver} is not a "
+ "dns.nameserver.Nameserver instance or text form, "
+ "IP address, nor a valid https URL"
)
- self._nameservers = nameservers
+ enriched_nameserver = dns.nameserver.DoHNameserver(nameserver)
+ enriched_nameservers.append(enriched_nameserver)
else:
raise ValueError(
- "nameservers must be a list (not a {})".format(type(nameservers))
+ "nameservers must be a list or tuple (not a {})".format(
+ type(nameservers)
+ )
)
+ return enriched_nameservers
+
+ @property
+ def nameservers(
+ self,
+ ) -> List[Union[str, dns.nameserver.Nameserver]]:
+ return self._nameservers
+
+ @nameservers.setter
+ def nameservers(
+ self, nameservers: List[Union[str, dns.nameserver.Nameserver]]
+ ) -> None:
+ """
+ *nameservers*, a ``list`` of nameservers, where a nameserver is either
+ a string interpretable as a nameserver, or a ``dns.nameserver.Nameserver``
+ instance.
+
+ Raises ``ValueError`` if *nameservers* is not a list of nameservers.
+ """
+ # We just call _enrich_nameservers() for checking
+ self._enrich_nameservers(nameservers)
+ self._nameservers = nameservers
class Resolver(BaseResolver):
@@ -1200,33 +1255,18 @@ class Resolver(BaseResolver):
assert request is not None # needed for type checking
done = False
while not done:
- (nameserver, port, tcp, backoff) = resolution.next_nameserver()
+ (nameserver, tcp, backoff) = resolution.next_nameserver()
if backoff:
time.sleep(backoff)
timeout = self._compute_timeout(start, lifetime, resolution.errors)
try:
- if dns.inet.is_address(nameserver):
- if tcp:
- response = dns.query.tcp(
- request,
- nameserver,
- timeout=timeout,
- port=port,
- source=source,
- source_port=source_port,
- )
- else:
- response = dns.query.udp(
- request,
- nameserver,
- timeout=timeout,
- port=port,
- source=source,
- source_port=source_port,
- raise_on_truncation=True,
- )
- else:
- response = dns.query.https(request, nameserver, timeout=timeout)
+ response = nameserver.query(
+ request,
+ timeout=timeout,
+ source=source,
+ source_port=source_port,
+ max_size=tcp,
+ )
except Exception as ex:
(_, done) = resolution.query_result(None, ex)
continue
@@ -1357,7 +1397,6 @@ def resolve(
lifetime: Optional[float] = None,
search: Optional[bool] = None,
) -> Answer: # pragma: no cover
-
"""Query nameservers to find the answer to the question.
This is a convenience function that uses the default resolver
diff --git a/doc/resolver-class.rst b/doc/resolver-class.rst
index 5bf01e3..21c6c46 100644
--- a/doc/resolver-class.rst
+++ b/doc/resolver-class.rst
@@ -12,11 +12,12 @@ The dns.resolver.Resolver and dns.resolver.Answer Classes
.. attribute:: nameservers
- A ``list`` of ``str``, each item containing an IPv4 or IPv6 address.
+ A ``list`` of ``str`` or ``dns.nameserver.Nameserver``. A string may be
+ an IPv4 or IPv6 address, or an https URL.
- This field is planned to become a property in dnspython 2.4. Writing to this
- field other than by direct assignment is deprecated, and so is depending on the
- mutability and form of the iterable returned when it is read.
+ This field is actually a property, and returns a tuple as of dnspython 2.4.
+ Assigning this this field converts any strings into
+ ``dns.nameserver.Nameserver`` instances.
.. attribute:: search
diff --git a/doc/resolver-nameserver.rst b/doc/resolver-nameserver.rst
new file mode 100644
index 0000000..06f4a1b
--- /dev/null
+++ b/doc/resolver-nameserver.rst
@@ -0,0 +1,46 @@
+.. _resolver-nameserver:
+
+The dns.nameserver.Nameserver Classes
+-------------------------------------
+
+The ``dns.nameserver.Nameserver`` abstract class represents a remote recursive resolver,
+and is used by the stub resolver to answer queries.
+
+.. autoclass:: dns.nameserver.Nameserver
+ :members:
+
+The dns.nameserver.Do53Nameserver Class
+---------------------------------------
+
+The ``dns.nameserver.Do53Nameserver`` class is a ``dns.nameserver.Nameserver`` class use
+to make regular port 53 (Do53) DNS queries to a recursive server.
+
+.. autoclass:: dns.nameserver.Do53Nameserver
+ :members:
+
+The dns.nameserver.DoTNameserver Class
+---------------------------------------
+
+The ``dns.nameserver.DoTNameserver`` class is a ``dns.nameserver.Nameserver`` class use
+to make DNS-over-TLS (DoT) queries to a recursive server.
+
+.. autoclass:: dns.nameserver.DoTNameserver
+ :members:
+
+The dns.nameserver.DoHNameserver Class
+---------------------------------------
+
+The ``dns.nameserver.DoHNameserver`` class is a ``dns.nameserver.Nameserver`` class use
+to make DNS-over-HTTPS (DoH) queries to a recursive server.
+
+.. autoclass:: dns.nameserver.DoHNameserver
+ :members:
+
+The dns.nameserver.DoQNameserver Class
+---------------------------------------
+
+The ``dns.nameserver.DoQNameserver`` class is a ``dns.nameserver.Nameserver`` class use
+to make DNS-over-QUIC (DoQ) queries to a recursive server.
+
+.. autoclass:: dns.nameserver.DoQNameserver
+ :members:
diff --git a/doc/resolver.rst b/doc/resolver.rst
index e9cf7b2..138ac3e 100644
--- a/doc/resolver.rst
+++ b/doc/resolver.rst
@@ -13,6 +13,7 @@ be used simply by setting the *nameservers* attribute.
.. toctree::
resolver-class
+ resolver-nameserver
resolver-functions
resolver-caching
resolver-override
diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst
index 54d3847..95fa691 100644
--- a/doc/whatsnew.rst
+++ b/doc/whatsnew.rst
@@ -6,6 +6,12 @@ What's New in dnspython
2.4.0 (in development)
----------------------
+* The stub resolver now uses instances of ``dns.nameserver.Nameserver`` to represent
+ remote recursive resolvers, and can communicate using
+ DNS over port 53, HTTPS, TLS, and QUIC. In additional to being able to specify
+ an IPv4, IPv6, or HTTPS URL as a nameserver, instances of ``dns.nameserver.Nameserver``
+ are now permitted.
+
2.3.0
-----
diff --git a/pyproject.toml b/pyproject.toml
index 1703a7f..deb52f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -60,7 +60,7 @@ coverage = "^7.0"
twine = "^4.0.0"
wheel = "^0.38.1"
pylint = "^2.7.4"
-mypy = ">=0.940"
+mypy = ">=1.0.1"
black = "^23.1.0"
[tool.poetry.extras]
diff --git a/tests/test_resolution.py b/tests/test_resolution.py
index d2819a1..d8bdb2c 100644
--- a/tests/test_resolution.py
+++ b/tests/test_resolution.py
@@ -222,8 +222,8 @@ class ResolutionTestCase(unittest.TestCase):
def test_next_request_rotate(self):
self.resolver.rotate = True
- order1 = ["10.0.0.1", "10.0.0.2"]
- order2 = ["10.0.0.2", "10.0.0.1"]
+ order1 = ["Do53:10.0.0.1@53", "Do53:10.0.0.2@53"]
+ order2 = ["Do53:10.0.0.2@53", "Do53:10.0.0.1@53"]
seen1 = False
seen2 = False
# We're not interested in testing the randomness, but we'd
@@ -235,9 +235,11 @@ class ResolutionTestCase(unittest.TestCase):
self.resolver, self.qname, "A", "IN", False, True, False
)
self.resn.next_request()
- if self.resn.nameservers == order1:
+ text_form = [str(n) for n in self.resn.nameservers]
+ print(text_form)
+ if text_form == order1:
seen1 = True
- elif self.resn.nameservers == order2:
+ elif text_form == order2:
seen2 = True
else:
raise ValueError # should not happen!
@@ -264,68 +266,71 @@ class ResolutionTestCase(unittest.TestCase):
def test_next_nameserver_udp(self):
(request, answer) = self.resn.next_request()
- (nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
- self.assertTrue(nameserver1 in self.resolver.nameservers)
- self.assertEqual(port, 53)
+ (nameserver1, tcp, backoff) = self.resn.next_nameserver()
+ self.assertEqual(nameserver1.port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
- (nameserver2, port, tcp, backoff) = self.resn.next_nameserver()
- self.assertTrue(nameserver2 in self.resolver.nameservers)
+ (nameserver2, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver2 != nameserver1)
- self.assertEqual(port, 53)
+ self.assertEqual(nameserver2.port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
- (nameserver3, port, tcp, backoff) = self.resn.next_nameserver()
+ (nameserver3, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver3 is nameserver1)
- self.assertEqual(port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.1)
- (nameserver4, port, tcp, backoff) = self.resn.next_nameserver()
+ (nameserver4, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver4 is nameserver2)
- self.assertEqual(port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
- (nameserver5, port, tcp, backoff) = self.resn.next_nameserver()
+ (nameserver5, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver5 is nameserver1)
- self.assertEqual(port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.2)
def test_next_nameserver_retry_with_tcp(self):
(request, answer) = self.resn.next_request()
- (nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
- self.assertTrue(nameserver1 in self.resolver.nameservers)
- self.assertEqual(port, 53)
+ (nameserver1, tcp, backoff) = self.resn.next_nameserver()
+ self.assertEqual(nameserver1.port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
self.resn.retry_with_tcp = True
- (nameserver2, port, tcp, backoff) = self.resn.next_nameserver()
+ (nameserver2, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver2 is nameserver1)
- self.assertEqual(port, 53)
self.assertTrue(tcp)
self.assertEqual(backoff, 0.0)
- (nameserver3, port, tcp, backoff) = self.resn.next_nameserver()
- self.assertTrue(nameserver3 in self.resolver.nameservers)
+ (nameserver3, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver3 != nameserver1)
- self.assertEqual(port, 53)
+ self.assertEqual(nameserver3.port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
def test_next_nameserver_no_nameservers(self):
(request, answer) = self.resn.next_request()
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
self.resn.nameservers.remove(nameserver)
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
self.resn.nameservers.remove(nameserver)
def bad():
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
self.assertRaises(dns.resolver.NoNameservers, bad)
+ def test_next_nameserver_max_size_nameserver(self):
+ # A query to a nameserver that always supports a maximum size query
+ # always counts as a "tcp attempt" for the state machine
+ self.resolver.nameservers = ["https://127.0.0.1:443/bogus"]
+ (_, _) = self.resn.next_request()
+ (nameserver, tcp_attempt, _) = self.resn.next_nameserver()
+ print(nameserver)
+ assert tcp_attempt
+
def test_query_result_nameserver_removing_exceptions(self):
# add some nameservers so we have enough to remove :)
- self.resolver.nameservers.extend(["10.0.0.3", "10.0.0.4"])
+ new_nameservers = list(self.resolver.nameservers[:])
+ new_nameservers.extend(["10.0.0.3", "10.0.0.4"])
+ self.resolver.nameservers = new_nameservers
(request, _) = self.resn.next_request()
exceptions = [
dns.exception.FormError(),
@@ -334,7 +339,7 @@ class ResolutionTestCase(unittest.TestCase):
dns.message.Truncated(),
]
for i in range(4):
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
if i == 3:
# Truncated is only bad if we're doing TCP, make it look
# like that's the case
@@ -351,7 +356,7 @@ class ResolutionTestCase(unittest.TestCase):
# test_query_result_nameserver_removing_exceptions(), we should
# not remove any nameservers and just continue resolving.
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
nameservers = self.resn.nameservers[:]
(answer, done) = self.resn.query_result(None, dns.exception.Timeout())
self.assertTrue(answer is None)
@@ -360,7 +365,7 @@ class ResolutionTestCase(unittest.TestCase):
def test_query_result_retry_with_tcp(self):
(request, _) = self.resn.next_request()
- (nameserver, _, tcp, _) = self.resn.next_nameserver()
+ (nameserver, tcp, _) = self.resn.next_nameserver()
self.assertFalse(tcp)
(answer, done) = self.resn.query_result(None, dns.message.Truncated())
self.assertTrue(answer is None)
@@ -374,7 +379,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_address_response(q)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertFalse(answer is None)
self.assertTrue(done)
@@ -386,7 +391,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_address_response(q)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertFalse(answer is None)
cache_answer = self.resolver.cache.get(
@@ -398,7 +403,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_negative_response(q)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
def bad():
(answer, done) = self.resn.query_result(r, None)
@@ -409,7 +414,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_negative_response(q, True)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
self.assertTrue(done)
@@ -419,7 +424,7 @@ class ResolutionTestCase(unittest.TestCase):
r = self.make_address_response(q)
r.set_rcode(dns.rcode.NXDOMAIN)
(_, _) = self.resn.next_request()
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertIsNone(answer)
self.assertFalse(done)
@@ -429,7 +434,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_long_chain_response(q, 15)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertIsNotNone(answer)
self.assertTrue(done)
@@ -438,7 +443,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_long_chain_response(q, 16)
(_, _) = self.resn.next_request()
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertIsNone(answer)
self.assertFalse(done)
@@ -449,7 +454,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_negative_response(q, True)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
self.assertTrue(done)
@@ -463,7 +468,7 @@ class ResolutionTestCase(unittest.TestCase):
r = self.make_address_response(q)
r.set_rcode(dns.rcode.YXDOMAIN)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
def bad():
(answer, done) = self.resn.query_result(r, None)
@@ -475,7 +480,7 @@ class ResolutionTestCase(unittest.TestCase):
r = self.make_address_response(q)
r.set_rcode(dns.rcode.SERVFAIL)
(_, _) = self.resn.next_request()
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
self.assertFalse(done)
@@ -487,7 +492,7 @@ class ResolutionTestCase(unittest.TestCase):
r = self.make_address_response(q)
r.set_rcode(dns.rcode.SERVFAIL)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
nameservers = self.resn.nameservers[:]
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
@@ -499,7 +504,7 @@ class ResolutionTestCase(unittest.TestCase):
r = self.make_address_response(q)
r.set_rcode(dns.rcode.REFUSED)
(_, _) = self.resn.next_request()
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
self.assertFalse(done)
diff --git a/tests/test_resolver.py b/tests/test_resolver.py
index d21127d..c1a97bf 100644
--- a/tests/test_resolver.py
+++ b/tests/test_resolver.py
@@ -27,6 +27,7 @@ from unittest.mock import patch
import dns.e164
import dns.message
import dns.name
+import dns.quic
import dns.rdataclass
import dns.rdatatype
import dns.resolver
@@ -717,6 +718,27 @@ class LiveResolverTests(unittest.TestCase):
answer2 = res.resolve("dns.google.", "A")
self.assertIs(answer2, answer1)
+ @unittest.skipIf(not tests.util.have_ipv4(), "IPv4 not reachable")
+ def testTLSNameserver(self):
+ res = dns.resolver.Resolver(configure=False)
+ res.nameservers = [dns.nameserver.DoTNameserver("8.8.8.8", 853)]
+ answer = res.resolve("dns.google.", "A")
+ seen = set([rdata.address for rdata in answer])
+ self.assertIn("8.8.8.8", seen)
+ self.assertIn("8.8.4.4", seen)
+
+ @unittest.skipIf(
+ not (tests.util.have_ipv4() and dns.quic.have_quic),
+ "IPv4 not reachable or QUIC not available",
+ )
+ def testQuicNameserver(self):
+ res = dns.resolver.Resolver(configure=False)
+ res.nameservers = [dns.nameserver.DoQNameserver("94.140.14.14", 784)]
+ answer = res.resolve("dns.adguard.com.", "A")
+ seen = set([rdata.address for rdata in answer])
+ self.assertIn("94.140.14.14", seen)
+ self.assertIn("94.140.15.15", seen)
+
def testCanonicalNameNoCNAME(self):
cname = dns.name.from_text("www.google.com")
self.assertEqual(dns.resolver.canonical_name("www.google.com"), cname)
@@ -772,7 +794,6 @@ if hasattr(selectors, "PollSelector"):
class NXDOMAINExceptionTestCase(unittest.TestCase):
-
# pylint: disable=broad-except
def test_nxdomain_compatible(self):
@@ -951,6 +972,7 @@ class ResolverNameserverValidTypeTestCase(unittest.TestCase):
"1.2.3.4",
1234,
(1, 2, 3, 4),
+ (),
{"invalid": "nameserver"},
]
for invalid_nameserver in invalid_nameservers:
@@ -1123,7 +1145,7 @@ def testResolverTimeout():
errors = e.kwargs["errors"]
assert len(errors) > 1
for error in errors:
- assert error[0] == na.udp_address[0] # address
+ assert str(error[0]) == f"Do53:{na.udp_address[0]}@{na.udp_address[1]}"
assert not error[1] # not TCP
assert error[2] == na.udp_address[1] # port
assert isinstance(error[3], dns.exception.Timeout) # exception
@@ -1145,7 +1167,7 @@ def testResolverNoNameservers():
errors = e.kwargs["errors"]
assert len(errors) == 1
for error in errors:
- assert error[0] == na.udp_address[0] # address
+ assert error[0] == f"Do53:{na.udp_address[0]}@{na.udp_address[1]}"
assert not error[1] # not TCP
assert error[2] == na.udp_address[1] # port
assert error[3] == "FORMERR"