diff options
| author | Bob Halley <halley@dnspython.org> | 2023-02-25 11:43:26 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-25 11:43:26 -0800 |
| commit | f7daeb87eac0a2727d5366cdff02fe08843678dd (patch) | |
| tree | c38eab2e322ec02cde716a1aa9b70271a4594174 /tests | |
| parent | c76de3e1f2de416694353aa158545689f72cedca (diff) | |
| download | dnspython-f7daeb87eac0a2727d5366cdff02fe08843678dd.tar.gz | |
Resolver "nameserver" object support. (#897)
* Resolver "nameserver" object support.
This turns the list of nameserver strings in the resolver into a tuple
of nameserver objects, which abstract away making queries to a
nameserver of a given type.
The resolver's legacy nameserver list is "enriched" into a tuple of
nameserver objects whenever it is set. Note that you cannot mutate
the object other than by setting,
e.g. res.nameservers.append("1.2.3.4") will not work.
Error message accumulation has been updated to refer to the
nameservers using a descriptive text form.
* doco fix
* more doco fixes
* do enrichment at Resolution time
* require a later mypy, fix type issues
* add nameserver doc
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/test_resolution.py | 91 | ||||
| -rw-r--r-- | tests/test_resolver.py | 28 |
2 files changed, 73 insertions, 46 deletions
diff --git a/tests/test_resolution.py b/tests/test_resolution.py index d2819a1..d8bdb2c 100644 --- a/tests/test_resolution.py +++ b/tests/test_resolution.py @@ -222,8 +222,8 @@ class ResolutionTestCase(unittest.TestCase): def test_next_request_rotate(self): self.resolver.rotate = True - order1 = ["10.0.0.1", "10.0.0.2"] - order2 = ["10.0.0.2", "10.0.0.1"] + order1 = ["Do53:10.0.0.1@53", "Do53:10.0.0.2@53"] + order2 = ["Do53:10.0.0.2@53", "Do53:10.0.0.1@53"] seen1 = False seen2 = False # We're not interested in testing the randomness, but we'd @@ -235,9 +235,11 @@ class ResolutionTestCase(unittest.TestCase): self.resolver, self.qname, "A", "IN", False, True, False ) self.resn.next_request() - if self.resn.nameservers == order1: + text_form = [str(n) for n in self.resn.nameservers] + print(text_form) + if text_form == order1: seen1 = True - elif self.resn.nameservers == order2: + elif text_form == order2: seen2 = True else: raise ValueError # should not happen! @@ -264,68 +266,71 @@ class ResolutionTestCase(unittest.TestCase): def test_next_nameserver_udp(self): (request, answer) = self.resn.next_request() - (nameserver1, port, tcp, backoff) = self.resn.next_nameserver() - self.assertTrue(nameserver1 in self.resolver.nameservers) - self.assertEqual(port, 53) + (nameserver1, tcp, backoff) = self.resn.next_nameserver() + self.assertEqual(nameserver1.port, 53) self.assertFalse(tcp) self.assertEqual(backoff, 0.0) - (nameserver2, port, tcp, backoff) = self.resn.next_nameserver() - self.assertTrue(nameserver2 in self.resolver.nameservers) + (nameserver2, tcp, backoff) = self.resn.next_nameserver() self.assertTrue(nameserver2 != nameserver1) - self.assertEqual(port, 53) + self.assertEqual(nameserver2.port, 53) self.assertFalse(tcp) self.assertEqual(backoff, 0.0) - (nameserver3, port, tcp, backoff) = self.resn.next_nameserver() + (nameserver3, tcp, backoff) = self.resn.next_nameserver() self.assertTrue(nameserver3 is nameserver1) - self.assertEqual(port, 53) self.assertFalse(tcp) self.assertEqual(backoff, 0.1) - (nameserver4, port, tcp, backoff) = self.resn.next_nameserver() + (nameserver4, tcp, backoff) = self.resn.next_nameserver() self.assertTrue(nameserver4 is nameserver2) - self.assertEqual(port, 53) self.assertFalse(tcp) self.assertEqual(backoff, 0.0) - (nameserver5, port, tcp, backoff) = self.resn.next_nameserver() + (nameserver5, tcp, backoff) = self.resn.next_nameserver() self.assertTrue(nameserver5 is nameserver1) - self.assertEqual(port, 53) self.assertFalse(tcp) self.assertEqual(backoff, 0.2) def test_next_nameserver_retry_with_tcp(self): (request, answer) = self.resn.next_request() - (nameserver1, port, tcp, backoff) = self.resn.next_nameserver() - self.assertTrue(nameserver1 in self.resolver.nameservers) - self.assertEqual(port, 53) + (nameserver1, tcp, backoff) = self.resn.next_nameserver() + self.assertEqual(nameserver1.port, 53) self.assertFalse(tcp) self.assertEqual(backoff, 0.0) self.resn.retry_with_tcp = True - (nameserver2, port, tcp, backoff) = self.resn.next_nameserver() + (nameserver2, tcp, backoff) = self.resn.next_nameserver() self.assertTrue(nameserver2 is nameserver1) - self.assertEqual(port, 53) self.assertTrue(tcp) self.assertEqual(backoff, 0.0) - (nameserver3, port, tcp, backoff) = self.resn.next_nameserver() - self.assertTrue(nameserver3 in self.resolver.nameservers) + (nameserver3, tcp, backoff) = self.resn.next_nameserver() self.assertTrue(nameserver3 != nameserver1) - self.assertEqual(port, 53) + self.assertEqual(nameserver3.port, 53) self.assertFalse(tcp) self.assertEqual(backoff, 0.0) def test_next_nameserver_no_nameservers(self): (request, answer) = self.resn.next_request() - (nameserver, _, _, _) = self.resn.next_nameserver() + (nameserver, _, _) = self.resn.next_nameserver() self.resn.nameservers.remove(nameserver) - (nameserver, _, _, _) = self.resn.next_nameserver() + (nameserver, _, _) = self.resn.next_nameserver() self.resn.nameservers.remove(nameserver) def bad(): - (nameserver, _, _, _) = self.resn.next_nameserver() + (nameserver, _, _) = self.resn.next_nameserver() self.assertRaises(dns.resolver.NoNameservers, bad) + def test_next_nameserver_max_size_nameserver(self): + # A query to a nameserver that always supports a maximum size query + # always counts as a "tcp attempt" for the state machine + self.resolver.nameservers = ["https://127.0.0.1:443/bogus"] + (_, _) = self.resn.next_request() + (nameserver, tcp_attempt, _) = self.resn.next_nameserver() + print(nameserver) + assert tcp_attempt + def test_query_result_nameserver_removing_exceptions(self): # add some nameservers so we have enough to remove :) - self.resolver.nameservers.extend(["10.0.0.3", "10.0.0.4"]) + new_nameservers = list(self.resolver.nameservers[:]) + new_nameservers.extend(["10.0.0.3", "10.0.0.4"]) + self.resolver.nameservers = new_nameservers (request, _) = self.resn.next_request() exceptions = [ dns.exception.FormError(), @@ -334,7 +339,7 @@ class ResolutionTestCase(unittest.TestCase): dns.message.Truncated(), ] for i in range(4): - (nameserver, _, _, _) = self.resn.next_nameserver() + (nameserver, _, _) = self.resn.next_nameserver() if i == 3: # Truncated is only bad if we're doing TCP, make it look # like that's the case @@ -351,7 +356,7 @@ class ResolutionTestCase(unittest.TestCase): # test_query_result_nameserver_removing_exceptions(), we should # not remove any nameservers and just continue resolving. (_, _) = self.resn.next_request() - (_, _, _, _) = self.resn.next_nameserver() + (_, _, _) = self.resn.next_nameserver() nameservers = self.resn.nameservers[:] (answer, done) = self.resn.query_result(None, dns.exception.Timeout()) self.assertTrue(answer is None) @@ -360,7 +365,7 @@ class ResolutionTestCase(unittest.TestCase): def test_query_result_retry_with_tcp(self): (request, _) = self.resn.next_request() - (nameserver, _, tcp, _) = self.resn.next_nameserver() + (nameserver, tcp, _) = self.resn.next_nameserver() self.assertFalse(tcp) (answer, done) = self.resn.query_result(None, dns.message.Truncated()) self.assertTrue(answer is None) @@ -374,7 +379,7 @@ class ResolutionTestCase(unittest.TestCase): q = dns.message.make_query(self.qname, dns.rdatatype.A) r = self.make_address_response(q) (_, _) = self.resn.next_request() - (_, _, _, _) = self.resn.next_nameserver() + (_, _, _) = self.resn.next_nameserver() (answer, done) = self.resn.query_result(r, None) self.assertFalse(answer is None) self.assertTrue(done) @@ -386,7 +391,7 @@ class ResolutionTestCase(unittest.TestCase): q = dns.message.make_query(self.qname, dns.rdatatype.A) r = self.make_address_response(q) (_, _) = self.resn.next_request() - (_, _, _, _) = self.resn.next_nameserver() + (_, _, _) = self.resn.next_nameserver() (answer, done) = self.resn.query_result(r, None) self.assertFalse(answer is None) cache_answer = self.resolver.cache.get( @@ -398,7 +403,7 @@ class ResolutionTestCase(unittest.TestCase): q = dns.message.make_query(self.qname, dns.rdatatype.A) r = self.make_negative_response(q) (_, _) = self.resn.next_request() - (_, _, _, _) = self.resn.next_nameserver() + (_, _, _) = self.resn.next_nameserver() def bad(): (answer, done) = self.resn.query_result(r, None) @@ -409,7 +414,7 @@ class ResolutionTestCase(unittest.TestCase): q = dns.message.make_query(self.qname, dns.rdatatype.A) r = self.make_negative_response(q, True) (_, _) = self.resn.next_request() - (_, _, _, _) = self.resn.next_nameserver() + (_, _, _) = self.resn.next_nameserver() (answer, done) = self.resn.query_result(r, None) self.assertTrue(answer is None) self.assertTrue(done) @@ -419,7 +424,7 @@ class ResolutionTestCase(unittest.TestCase): r = self.make_address_response(q) r.set_rcode(dns.rcode.NXDOMAIN) (_, _) = self.resn.next_request() - (nameserver, _, _, _) = self.resn.next_nameserver() + (nameserver, _, _) = self.resn.next_nameserver() (answer, done) = self.resn.query_result(r, None) self.assertIsNone(answer) self.assertFalse(done) @@ -429,7 +434,7 @@ class ResolutionTestCase(unittest.TestCase): q = dns.message.make_query(self.qname, dns.rdatatype.A) r = self.make_long_chain_response(q, 15) (_, _) = self.resn.next_request() - (_, _, _, _) = self.resn.next_nameserver() + (_, _, _) = self.resn.next_nameserver() (answer, done) = self.resn.query_result(r, None) self.assertIsNotNone(answer) self.assertTrue(done) @@ -438,7 +443,7 @@ class ResolutionTestCase(unittest.TestCase): q = dns.message.make_query(self.qname, dns.rdatatype.A) r = self.make_long_chain_response(q, 16) (_, _) = self.resn.next_request() - (nameserver, _, _, _) = self.resn.next_nameserver() + (nameserver, _, _) = self.resn.next_nameserver() (answer, done) = self.resn.query_result(r, None) self.assertIsNone(answer) self.assertFalse(done) @@ -449,7 +454,7 @@ class ResolutionTestCase(unittest.TestCase): q = dns.message.make_query(self.qname, dns.rdatatype.A) r = self.make_negative_response(q, True) (_, _) = self.resn.next_request() - (_, _, _, _) = self.resn.next_nameserver() + (_, _, _) = self.resn.next_nameserver() (answer, done) = self.resn.query_result(r, None) self.assertTrue(answer is None) self.assertTrue(done) @@ -463,7 +468,7 @@ class ResolutionTestCase(unittest.TestCase): r = self.make_address_response(q) r.set_rcode(dns.rcode.YXDOMAIN) (_, _) = self.resn.next_request() - (_, _, _, _) = self.resn.next_nameserver() + (_, _, _) = self.resn.next_nameserver() def bad(): (answer, done) = self.resn.query_result(r, None) @@ -475,7 +480,7 @@ class ResolutionTestCase(unittest.TestCase): r = self.make_address_response(q) r.set_rcode(dns.rcode.SERVFAIL) (_, _) = self.resn.next_request() - (nameserver, _, _, _) = self.resn.next_nameserver() + (nameserver, _, _) = self.resn.next_nameserver() (answer, done) = self.resn.query_result(r, None) self.assertTrue(answer is None) self.assertFalse(done) @@ -487,7 +492,7 @@ class ResolutionTestCase(unittest.TestCase): r = self.make_address_response(q) r.set_rcode(dns.rcode.SERVFAIL) (_, _) = self.resn.next_request() - (_, _, _, _) = self.resn.next_nameserver() + (_, _, _) = self.resn.next_nameserver() nameservers = self.resn.nameservers[:] (answer, done) = self.resn.query_result(r, None) self.assertTrue(answer is None) @@ -499,7 +504,7 @@ class ResolutionTestCase(unittest.TestCase): r = self.make_address_response(q) r.set_rcode(dns.rcode.REFUSED) (_, _) = self.resn.next_request() - (nameserver, _, _, _) = self.resn.next_nameserver() + (nameserver, _, _) = self.resn.next_nameserver() (answer, done) = self.resn.query_result(r, None) self.assertTrue(answer is None) self.assertFalse(done) diff --git a/tests/test_resolver.py b/tests/test_resolver.py index d21127d..c1a97bf 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -27,6 +27,7 @@ from unittest.mock import patch import dns.e164 import dns.message import dns.name +import dns.quic import dns.rdataclass import dns.rdatatype import dns.resolver @@ -717,6 +718,27 @@ class LiveResolverTests(unittest.TestCase): answer2 = res.resolve("dns.google.", "A") self.assertIs(answer2, answer1) + @unittest.skipIf(not tests.util.have_ipv4(), "IPv4 not reachable") + def testTLSNameserver(self): + res = dns.resolver.Resolver(configure=False) + res.nameservers = [dns.nameserver.DoTNameserver("8.8.8.8", 853)] + answer = res.resolve("dns.google.", "A") + seen = set([rdata.address for rdata in answer]) + self.assertIn("8.8.8.8", seen) + self.assertIn("8.8.4.4", seen) + + @unittest.skipIf( + not (tests.util.have_ipv4() and dns.quic.have_quic), + "IPv4 not reachable or QUIC not available", + ) + def testQuicNameserver(self): + res = dns.resolver.Resolver(configure=False) + res.nameservers = [dns.nameserver.DoQNameserver("94.140.14.14", 784)] + answer = res.resolve("dns.adguard.com.", "A") + seen = set([rdata.address for rdata in answer]) + self.assertIn("94.140.14.14", seen) + self.assertIn("94.140.15.15", seen) + def testCanonicalNameNoCNAME(self): cname = dns.name.from_text("www.google.com") self.assertEqual(dns.resolver.canonical_name("www.google.com"), cname) @@ -772,7 +794,6 @@ if hasattr(selectors, "PollSelector"): class NXDOMAINExceptionTestCase(unittest.TestCase): - # pylint: disable=broad-except def test_nxdomain_compatible(self): @@ -951,6 +972,7 @@ class ResolverNameserverValidTypeTestCase(unittest.TestCase): "1.2.3.4", 1234, (1, 2, 3, 4), + (), {"invalid": "nameserver"}, ] for invalid_nameserver in invalid_nameservers: @@ -1123,7 +1145,7 @@ def testResolverTimeout(): errors = e.kwargs["errors"] assert len(errors) > 1 for error in errors: - assert error[0] == na.udp_address[0] # address + assert str(error[0]) == f"Do53:{na.udp_address[0]}@{na.udp_address[1]}" assert not error[1] # not TCP assert error[2] == na.udp_address[1] # port assert isinstance(error[3], dns.exception.Timeout) # exception @@ -1145,7 +1167,7 @@ def testResolverNoNameservers(): errors = e.kwargs["errors"] assert len(errors) == 1 for error in errors: - assert error[0] == na.udp_address[0] # address + assert error[0] == f"Do53:{na.udp_address[0]}@{na.udp_address[1]}" assert not error[1] # not TCP assert error[2] == na.udp_address[1] # port assert error[3] == "FORMERR" |
