From 7a5e59707b395454db2cb650371bbc2e800e7be4 Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Wed, 8 Jul 2020 15:11:19 -0700 Subject: Add support for receiving UDP queries. The existing receive_udp() methods are only usable for receiving responses, as they require an expected destination and check that the message is from that destination. This change makes the expected destination (and hence the check) optional, and returns the address that the message was received from (in the sync case, this is only done if no destination is provided, for backwards compatibility). New tests are added, which required adding generic getsockname() support to the async backends. --- dns/_asyncio_backend.py | 6 ++++++ dns/_curio_backend.py | 6 ++++++ dns/_trio_backend.py | 9 +++++++++ dns/asyncquery.py | 36 ++++++++++++++++-------------------- dns/query.py | 44 ++++++++++++++++++++++++++++++++------------ tests/test_async.py | 19 +++++++++++++++++++ tests/test_query.py | 12 ++++++++++++ 7 files changed, 100 insertions(+), 32 deletions(-) diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index ba7c2e7..3af34ff 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -75,6 +75,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.transport.get_extra_info('peername') + async def getsockname(self): + return self.transport.get_extra_info('sockname') + class StreamSocket(dns._asyncbackend.DatagramSocket): def __init__(self, af, reader, writer): @@ -102,6 +105,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.writer.get_extra_info('peername') + async def getsockname(self): + return self.writer.get_extra_info('sockname') + class Backend(dns._asyncbackend.Backend): def name(self): diff --git a/dns/_curio_backend.py b/dns/_curio_backend.py index dca966d..300e1b8 100644 --- a/dns/_curio_backend.py +++ b/dns/_curio_backend.py @@ -43,6 +43,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.socket.getpeername() + async def getsockname(self): + return self.socket.getsockname() + class StreamSocket(dns._asyncbackend.DatagramSocket): def __init__(self, socket): @@ -65,6 +68,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.socket.getpeername() + async def getsockname(self): + return self.socket.getsockname() + class Backend(dns._asyncbackend.Backend): def name(self): diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index 0f1378f..92ea879 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -43,6 +43,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.socket.getpeername() + async def getsockname(self): + return self.socket.getsockname() + class StreamSocket(dns._asyncbackend.DatagramSocket): def __init__(self, family, stream, tls=False): @@ -69,6 +72,12 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): else: return self.stream.socket.getpeername() + async def getsockname(self): + if self.tls: + return self.stream.transport_stream.socket.getsockname() + else: + return self.stream.socket.getsockname() + class Backend(dns._asyncbackend.Backend): def name(self): diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 4afe7bc..b792648 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -30,8 +30,7 @@ import dns.rcode import dns.rdataclass import dns.rdatatype -from dns.query import _addresses_equal, _compute_times, UnexpectedSource, \ - BadResponse, ssl +from dns.query import _compute_times, _matches_destination, BadResponse, ssl # for brevity @@ -87,7 +86,7 @@ async def send_udp(sock, what, destination, expiration=None): return (n, sent_time) -async def receive_udp(sock, destination, expiration=None, +async def receive_udp(sock, destination=None, expiration=None, ignore_unexpected=False, one_rr_per_rrset=False, keyring=None, request_mac=b'', ignore_trailing=False, raise_on_truncation=False): @@ -96,7 +95,9 @@ async def receive_udp(sock, destination, expiration=None, *sock*, a ``dns.asyncbackend.DatagramSocket``. *destination*, a destination tuple appropriate for the address family - of the socket, specifying where the associated query was sent. + of the socket, specifying where the message is expected to arrive from. + When receiving a response, this would be where the associated query was + sent. *expiration*, a ``float`` or ``None``, the absolute time at which a timeout exception should be raised. If ``None``, no timeout will @@ -121,27 +122,22 @@ async def receive_udp(sock, destination, expiration=None, Raises if the message is malformed, if network errors occur, of if there is a timeout. - Returns a ``(dns.message.Message, float)`` tuple of the received message - and the received time. + Returns a ``(dns.message.Message, float, tuple)`` tuple of the received + message, the received time, and the address where the message arrived from. """ wire = b'' while 1: (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration)) - if _addresses_equal(sock.family, from_address, destination) or \ - (dns.inet.is_multicast(destination[0]) and - from_address[1:] == destination[1:]): + if _matches_destination(sock.family, from_address, destination, + ignore_unexpected): break - if not ignore_unexpected: - raise UnexpectedSource('got a response from ' - '%s instead of %s' % (from_address, - destination)) received_time = time.time() r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, raise_on_truncation=raise_on_truncation) - return (r, received_time) + return (r, received_time, from_address) async def udp(q, where, timeout=None, port=53, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False, @@ -202,12 +198,12 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0, stuple = _source_tuple(af, source, source_port) s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple) await send_udp(s, wire, destination, expiration) - (r, received_time) = await receive_udp(s, destination, expiration, - ignore_unexpected, - one_rr_per_rrset, - q.keyring, q.mac, - ignore_trailing, - raise_on_truncation) + (r, received_time, _) = await receive_udp(s, destination, expiration, + ignore_unexpected, + one_rr_per_rrset, + q.keyring, q.mac, + ignore_trailing, + raise_on_truncation) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse diff --git a/dns/query.py b/dns/query.py index 13c8246..7df565d 100644 --- a/dns/query.py +++ b/dns/query.py @@ -201,6 +201,21 @@ def _addresses_equal(af, a1, a2): return n1 == n2 and a1[1:] == a2[1:] +def _matches_destination(af, from_address, destination, ignore_unexpected): + # Check that from_address is appropriate for a response to a query + # sent to destination. + if not destination: + return True + if _addresses_equal(af, from_address, destination) or \ + (dns.inet.is_multicast(destination[0]) and + from_address[1:] == destination[1:]): + return True + elif ignore_unexpected: + return False + raise UnexpectedSource(f'got a response from {from_address} instead of ' + f'{destination}') + + def _destination_and_source(where, port, source, source_port, where_must_be_address=True): # Apply defaults and compute destination and source tuples @@ -397,7 +412,7 @@ def send_udp(sock, what, destination, expiration=None): return (n, sent_time) -def receive_udp(sock, destination, expiration=None, +def receive_udp(sock, destination=None, expiration=None, ignore_unexpected=False, one_rr_per_rrset=False, keyring=None, request_mac=b'', ignore_trailing=False, raise_on_truncation=False): @@ -406,7 +421,9 @@ def receive_udp(sock, destination, expiration=None, *sock*, a ``socket``. *destination*, a destination tuple appropriate for the address family - of the socket, specifying where the associated query was sent. + of the socket, specifying where the message is expected to arrive from. + When receiving a response, this would be where the associated query was + sent. *expiration*, a ``float`` or ``None``, the absolute time at which a timeout exception should be raised. If ``None``, no timeout will @@ -431,28 +448,31 @@ def receive_udp(sock, destination, expiration=None, Raises if the message is malformed, if network errors occur, of if there is a timeout. - Returns a ``(dns.message.Message, float)`` tuple of the received message - and the received time. + If *destination* is not ``None``, returns a ``(dns.message.Message, float)`` + tuple of the received message and the received time. + + If *destination* is ``None``, returns a + ``(dns.message.Message, float, tuple)`` + tuple of the received message, the received time, and the address where + the message arrived from. """ wire = b'' while 1: _wait_for_readable(sock, expiration) (wire, from_address) = sock.recvfrom(65535) - if _addresses_equal(sock.family, from_address, destination) or \ - (dns.inet.is_multicast(destination[0]) and - from_address[1:] == destination[1:]): + if _matches_destination(sock.family, from_address, destination, + ignore_unexpected): break - if not ignore_unexpected: - raise UnexpectedSource('got a response from ' - '%s instead of %s' % (from_address, - destination)) received_time = time.time() r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, raise_on_truncation=raise_on_truncation) - return (r, received_time) + if destination: + return (r, received_time) + else: + return (r, received_time, from_address) def udp(q, where, timeout=None, port=53, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, diff --git a/tests/test_async.py b/tests/test_async.py index 2d25434..5faaa6e 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -343,6 +343,25 @@ class AsyncTests(unittest.TestCase): (_, tcp) = self.async_run(run) self.assertFalse(tcp) + def testUDPReceiveQuery(self): + async def run(): + async with await self.backend.make_socket( + socket.AF_INET, socket.SOCK_DGRAM, + source=('127.0.0.1', 0)) as listener: + listener_address = await listener.getsockname() + async with await self.backend.make_socket( + socket.AF_INET, socket.SOCK_DGRAM, + source=('127.0.0.1', 0)) as sender: + sender_address = await sender.getsockname() + q = dns.message.make_query('dns.google', dns.rdatatype.A) + await dns.asyncquery.send_udp(sender, q, listener_address) + expiration = time.time() + 2 + (_, _, recv_address) = await dns.asyncquery.receive_udp( + listener, expiration=expiration) + return (sender_address, recv_address) + (sender_address, recv_address) = self.async_run(run) + self.assertEqual(sender_address, recv_address) + def testUDPReceiveTimeout(self): async def arun(): async with await self.backend.make_socket(socket.AF_INET, diff --git a/tests/test_query.py b/tests/test_query.py index f1ec55c..498128d 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -191,6 +191,18 @@ class QueryTests(unittest.TestCase): (_, tcp) = dns.query.udp_with_fallback(q, address) self.assertFalse(tcp) + def testUDPReceiveQuery(self): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as listener: + listener.bind(('127.0.0.1', 0)) + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender: + sender.bind(('127.0.0.1', 0)) + q = dns.message.make_query('dns.google', dns.rdatatype.A) + dns.query.send_udp(sender, q, listener.getsockname()) + expiration = time.time() + 2 + (q, _, addr) = dns.query.receive_udp(listener, + expiration=expiration) + self.assertEqual(addr, sender.getsockname()) + # for brevity _d_and_s = dns.query._destination_and_source -- cgit v1.2.1 From 5250399a9aeecab9dbf40a65164faf7290a08f5b Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Wed, 8 Jul 2020 15:29:18 -0700 Subject: Add coverage for TCP/TLS async getsockname. --- tests/test_async.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_async.py b/tests/test_async.py index 5faaa6e..db108c8 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -277,6 +277,8 @@ class AsyncTests(unittest.TestCase): socket.SOCK_STREAM, 0, None, (address, 53)) as s: + # for basic coverage + await s.getsockname() q = dns.message.make_query(qname, dns.rdatatype.A) return await dns.asyncquery.tcp(q, address, sock=s) response = self.async_run(run) @@ -315,6 +317,8 @@ class AsyncTests(unittest.TestCase): None, (address, 853), None, ssl_context, None) as s: + # for basic coverage + await s.getsockname() q = dns.message.make_query(qname, dns.rdatatype.A) return await dns.asyncquery.tls(q, '8.8.8.8', sock=s) response = self.async_run(run) -- cgit v1.2.1