summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2023-02-25 11:43:26 -0800
committerGitHub <noreply@github.com>2023-02-25 11:43:26 -0800
commitf7daeb87eac0a2727d5366cdff02fe08843678dd (patch)
treec38eab2e322ec02cde716a1aa9b70271a4594174 /tests
parentc76de3e1f2de416694353aa158545689f72cedca (diff)
downloaddnspython-f7daeb87eac0a2727d5366cdff02fe08843678dd.tar.gz
Resolver "nameserver" object support. (#897)
* Resolver "nameserver" object support. This turns the list of nameserver strings in the resolver into a tuple of nameserver objects, which abstract away making queries to a nameserver of a given type. The resolver's legacy nameserver list is "enriched" into a tuple of nameserver objects whenever it is set. Note that you cannot mutate the object other than by setting, e.g. res.nameservers.append("1.2.3.4") will not work. Error message accumulation has been updated to refer to the nameservers using a descriptive text form. * doco fix * more doco fixes * do enrichment at Resolution time * require a later mypy, fix type issues * add nameserver doc
Diffstat (limited to 'tests')
-rw-r--r--tests/test_resolution.py91
-rw-r--r--tests/test_resolver.py28
2 files changed, 73 insertions, 46 deletions
diff --git a/tests/test_resolution.py b/tests/test_resolution.py
index d2819a1..d8bdb2c 100644
--- a/tests/test_resolution.py
+++ b/tests/test_resolution.py
@@ -222,8 +222,8 @@ class ResolutionTestCase(unittest.TestCase):
def test_next_request_rotate(self):
self.resolver.rotate = True
- order1 = ["10.0.0.1", "10.0.0.2"]
- order2 = ["10.0.0.2", "10.0.0.1"]
+ order1 = ["Do53:10.0.0.1@53", "Do53:10.0.0.2@53"]
+ order2 = ["Do53:10.0.0.2@53", "Do53:10.0.0.1@53"]
seen1 = False
seen2 = False
# We're not interested in testing the randomness, but we'd
@@ -235,9 +235,11 @@ class ResolutionTestCase(unittest.TestCase):
self.resolver, self.qname, "A", "IN", False, True, False
)
self.resn.next_request()
- if self.resn.nameservers == order1:
+ text_form = [str(n) for n in self.resn.nameservers]
+ print(text_form)
+ if text_form == order1:
seen1 = True
- elif self.resn.nameservers == order2:
+ elif text_form == order2:
seen2 = True
else:
raise ValueError # should not happen!
@@ -264,68 +266,71 @@ class ResolutionTestCase(unittest.TestCase):
def test_next_nameserver_udp(self):
(request, answer) = self.resn.next_request()
- (nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
- self.assertTrue(nameserver1 in self.resolver.nameservers)
- self.assertEqual(port, 53)
+ (nameserver1, tcp, backoff) = self.resn.next_nameserver()
+ self.assertEqual(nameserver1.port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
- (nameserver2, port, tcp, backoff) = self.resn.next_nameserver()
- self.assertTrue(nameserver2 in self.resolver.nameservers)
+ (nameserver2, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver2 != nameserver1)
- self.assertEqual(port, 53)
+ self.assertEqual(nameserver2.port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
- (nameserver3, port, tcp, backoff) = self.resn.next_nameserver()
+ (nameserver3, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver3 is nameserver1)
- self.assertEqual(port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.1)
- (nameserver4, port, tcp, backoff) = self.resn.next_nameserver()
+ (nameserver4, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver4 is nameserver2)
- self.assertEqual(port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
- (nameserver5, port, tcp, backoff) = self.resn.next_nameserver()
+ (nameserver5, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver5 is nameserver1)
- self.assertEqual(port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.2)
def test_next_nameserver_retry_with_tcp(self):
(request, answer) = self.resn.next_request()
- (nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
- self.assertTrue(nameserver1 in self.resolver.nameservers)
- self.assertEqual(port, 53)
+ (nameserver1, tcp, backoff) = self.resn.next_nameserver()
+ self.assertEqual(nameserver1.port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
self.resn.retry_with_tcp = True
- (nameserver2, port, tcp, backoff) = self.resn.next_nameserver()
+ (nameserver2, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver2 is nameserver1)
- self.assertEqual(port, 53)
self.assertTrue(tcp)
self.assertEqual(backoff, 0.0)
- (nameserver3, port, tcp, backoff) = self.resn.next_nameserver()
- self.assertTrue(nameserver3 in self.resolver.nameservers)
+ (nameserver3, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver3 != nameserver1)
- self.assertEqual(port, 53)
+ self.assertEqual(nameserver3.port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
def test_next_nameserver_no_nameservers(self):
(request, answer) = self.resn.next_request()
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
self.resn.nameservers.remove(nameserver)
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
self.resn.nameservers.remove(nameserver)
def bad():
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
self.assertRaises(dns.resolver.NoNameservers, bad)
+ def test_next_nameserver_max_size_nameserver(self):
+ # A query to a nameserver that always supports a maximum size query
+ # always counts as a "tcp attempt" for the state machine
+ self.resolver.nameservers = ["https://127.0.0.1:443/bogus"]
+ (_, _) = self.resn.next_request()
+ (nameserver, tcp_attempt, _) = self.resn.next_nameserver()
+ print(nameserver)
+ assert tcp_attempt
+
def test_query_result_nameserver_removing_exceptions(self):
# add some nameservers so we have enough to remove :)
- self.resolver.nameservers.extend(["10.0.0.3", "10.0.0.4"])
+ new_nameservers = list(self.resolver.nameservers[:])
+ new_nameservers.extend(["10.0.0.3", "10.0.0.4"])
+ self.resolver.nameservers = new_nameservers
(request, _) = self.resn.next_request()
exceptions = [
dns.exception.FormError(),
@@ -334,7 +339,7 @@ class ResolutionTestCase(unittest.TestCase):
dns.message.Truncated(),
]
for i in range(4):
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
if i == 3:
# Truncated is only bad if we're doing TCP, make it look
# like that's the case
@@ -351,7 +356,7 @@ class ResolutionTestCase(unittest.TestCase):
# test_query_result_nameserver_removing_exceptions(), we should
# not remove any nameservers and just continue resolving.
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
nameservers = self.resn.nameservers[:]
(answer, done) = self.resn.query_result(None, dns.exception.Timeout())
self.assertTrue(answer is None)
@@ -360,7 +365,7 @@ class ResolutionTestCase(unittest.TestCase):
def test_query_result_retry_with_tcp(self):
(request, _) = self.resn.next_request()
- (nameserver, _, tcp, _) = self.resn.next_nameserver()
+ (nameserver, tcp, _) = self.resn.next_nameserver()
self.assertFalse(tcp)
(answer, done) = self.resn.query_result(None, dns.message.Truncated())
self.assertTrue(answer is None)
@@ -374,7 +379,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_address_response(q)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertFalse(answer is None)
self.assertTrue(done)
@@ -386,7 +391,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_address_response(q)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertFalse(answer is None)
cache_answer = self.resolver.cache.get(
@@ -398,7 +403,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_negative_response(q)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
def bad():
(answer, done) = self.resn.query_result(r, None)
@@ -409,7 +414,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_negative_response(q, True)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
self.assertTrue(done)
@@ -419,7 +424,7 @@ class ResolutionTestCase(unittest.TestCase):
r = self.make_address_response(q)
r.set_rcode(dns.rcode.NXDOMAIN)
(_, _) = self.resn.next_request()
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertIsNone(answer)
self.assertFalse(done)
@@ -429,7 +434,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_long_chain_response(q, 15)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertIsNotNone(answer)
self.assertTrue(done)
@@ -438,7 +443,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_long_chain_response(q, 16)
(_, _) = self.resn.next_request()
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertIsNone(answer)
self.assertFalse(done)
@@ -449,7 +454,7 @@ class ResolutionTestCase(unittest.TestCase):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_negative_response(q, True)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
self.assertTrue(done)
@@ -463,7 +468,7 @@ class ResolutionTestCase(unittest.TestCase):
r = self.make_address_response(q)
r.set_rcode(dns.rcode.YXDOMAIN)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
def bad():
(answer, done) = self.resn.query_result(r, None)
@@ -475,7 +480,7 @@ class ResolutionTestCase(unittest.TestCase):
r = self.make_address_response(q)
r.set_rcode(dns.rcode.SERVFAIL)
(_, _) = self.resn.next_request()
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
self.assertFalse(done)
@@ -487,7 +492,7 @@ class ResolutionTestCase(unittest.TestCase):
r = self.make_address_response(q)
r.set_rcode(dns.rcode.SERVFAIL)
(_, _) = self.resn.next_request()
- (_, _, _, _) = self.resn.next_nameserver()
+ (_, _, _) = self.resn.next_nameserver()
nameservers = self.resn.nameservers[:]
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
@@ -499,7 +504,7 @@ class ResolutionTestCase(unittest.TestCase):
r = self.make_address_response(q)
r.set_rcode(dns.rcode.REFUSED)
(_, _) = self.resn.next_request()
- (nameserver, _, _, _) = self.resn.next_nameserver()
+ (nameserver, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
self.assertFalse(done)
diff --git a/tests/test_resolver.py b/tests/test_resolver.py
index d21127d..c1a97bf 100644
--- a/tests/test_resolver.py
+++ b/tests/test_resolver.py
@@ -27,6 +27,7 @@ from unittest.mock import patch
import dns.e164
import dns.message
import dns.name
+import dns.quic
import dns.rdataclass
import dns.rdatatype
import dns.resolver
@@ -717,6 +718,27 @@ class LiveResolverTests(unittest.TestCase):
answer2 = res.resolve("dns.google.", "A")
self.assertIs(answer2, answer1)
+ @unittest.skipIf(not tests.util.have_ipv4(), "IPv4 not reachable")
+ def testTLSNameserver(self):
+ res = dns.resolver.Resolver(configure=False)
+ res.nameservers = [dns.nameserver.DoTNameserver("8.8.8.8", 853)]
+ answer = res.resolve("dns.google.", "A")
+ seen = set([rdata.address for rdata in answer])
+ self.assertIn("8.8.8.8", seen)
+ self.assertIn("8.8.4.4", seen)
+
+ @unittest.skipIf(
+ not (tests.util.have_ipv4() and dns.quic.have_quic),
+ "IPv4 not reachable or QUIC not available",
+ )
+ def testQuicNameserver(self):
+ res = dns.resolver.Resolver(configure=False)
+ res.nameservers = [dns.nameserver.DoQNameserver("94.140.14.14", 784)]
+ answer = res.resolve("dns.adguard.com.", "A")
+ seen = set([rdata.address for rdata in answer])
+ self.assertIn("94.140.14.14", seen)
+ self.assertIn("94.140.15.15", seen)
+
def testCanonicalNameNoCNAME(self):
cname = dns.name.from_text("www.google.com")
self.assertEqual(dns.resolver.canonical_name("www.google.com"), cname)
@@ -772,7 +794,6 @@ if hasattr(selectors, "PollSelector"):
class NXDOMAINExceptionTestCase(unittest.TestCase):
-
# pylint: disable=broad-except
def test_nxdomain_compatible(self):
@@ -951,6 +972,7 @@ class ResolverNameserverValidTypeTestCase(unittest.TestCase):
"1.2.3.4",
1234,
(1, 2, 3, 4),
+ (),
{"invalid": "nameserver"},
]
for invalid_nameserver in invalid_nameservers:
@@ -1123,7 +1145,7 @@ def testResolverTimeout():
errors = e.kwargs["errors"]
assert len(errors) > 1
for error in errors:
- assert error[0] == na.udp_address[0] # address
+ assert str(error[0]) == f"Do53:{na.udp_address[0]}@{na.udp_address[1]}"
assert not error[1] # not TCP
assert error[2] == na.udp_address[1] # port
assert isinstance(error[3], dns.exception.Timeout) # exception
@@ -1145,7 +1167,7 @@ def testResolverNoNameservers():
errors = e.kwargs["errors"]
assert len(errors) == 1
for error in errors:
- assert error[0] == na.udp_address[0] # address
+ assert error[0] == f"Do53:{na.udp_address[0]}@{na.udp_address[1]}"
assert not error[1] # not TCP
assert error[2] == na.udp_address[1] # port
assert error[3] == "FORMERR"