diff options
author | Bob Halley <halley@dnspython.org> | 2023-02-25 11:40:08 -0800 |
---|---|---|
committer | Bob Halley <halley@dnspython.org> | 2023-02-25 11:40:08 -0800 |
commit | f893fcbb7c4438bdcc785b5d99ed3b5476513fa1 (patch) | |
tree | 3a6e060c9f0b0bddbeff4cf96e54a417c0c31418 | |
parent | c76de3e1f2de416694353aa158545689f72cedca (diff) | |
download | dnspython-quic-899.tar.gz |
Fix hangs when QUIC connection fails [#899].quic-899
This also fixes problems with computing the wait_for() timeout for
the sync and asyncio ports, and fixes delivery of the timeout for
the sync port.
-rw-r--r-- | dns/quic/_asyncio.py | 12 | ||||
-rw-r--r-- | dns/quic/_sync.py | 48 | ||||
-rw-r--r-- | dns/quic/_trio.py | 37 |
3 files changed, 62 insertions, 35 deletions
diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py index bcce048..80f244d 100644 --- a/dns/quic/_asyncio.py +++ b/dns/quic/_asyncio.py @@ -17,6 +17,7 @@ from dns.quic._common import ( AsyncQuicConnection, AsyncQuicManager, QUIC_MAX_DATAGRAM, + UnexpectedEOF, ) @@ -30,8 +31,8 @@ class AsyncioQuicStream(BaseQuicStream): await self._wake_up.wait() async def wait_for(self, amount, expiration): - timeout = self._timeout_from_expiration(expiration) while True: + timeout = self._timeout_from_expiration(expiration) if self._buffer.have(amount): return self._expecting = amount @@ -106,6 +107,11 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._wake_timer.notify_all() except Exception: pass + finally: + self._done = True + async with self._wake_timer: + self._wake_timer.notify_all() + self._handshake_complete.set() async def _wait_for_wake_timer(self): async with self._wake_timer: @@ -115,7 +121,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): await self._socket_created.wait() while not self._done: datagrams = self._connection.datagrams_to_send(time.time()) - for (datagram, address) in datagrams: + for datagram, address in datagrams: assert address == self._peer[0] await self._socket.sendto(datagram, self._peer, None) (expiration, interval) = self._get_timer_values() @@ -162,6 +168,8 @@ class AsyncioQuicConnection(AsyncQuicConnection): async def make_stream(self): await self._handshake_complete.wait() + if self._done: + raise UnexpectedEOF stream_id = self._connection.get_next_available_stream_id(False) stream = AsyncioQuicStream(self, stream_id) self._streams[stream_id] = stream diff --git a/dns/quic/_sync.py b/dns/quic/_sync.py index 8cc606a..bc034fa 100644 --- a/dns/quic/_sync.py +++ b/dns/quic/_sync.py @@ -17,6 +17,7 @@ from dns.quic._common import ( BaseQuicConnection, BaseQuicManager, QUIC_MAX_DATAGRAM, + UnexpectedEOF, ) # Avoid circularity with dns.query @@ -33,14 +34,15 @@ class SyncQuicStream(BaseQuicStream): self._lock = threading.Lock() def wait_for(self, amount, expiration): - timeout = self._timeout_from_expiration(expiration) while True: + timeout = self._timeout_from_expiration(expiration) with self._lock: if self._buffer.have(amount): return self._expecting = amount with self._wake_up: - self._wake_up.wait(timeout) + if not self._wake_up.wait(timeout): + raise TimeoutError self._expecting = 0 def receive(self, timeout=None): @@ -114,24 +116,30 @@ class SyncQuicConnection(BaseQuicConnection): return def _worker(self): - sel = _selector_class() - sel.register(self._socket, selectors.EVENT_READ, self._read) - sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) - while not self._done: - (expiration, interval) = self._get_timer_values(False) - items = sel.select(interval) - for (key, _) in items: - key.data() + try: + sel = _selector_class() + sel.register(self._socket, selectors.EVENT_READ, self._read) + sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + items = sel.select(interval) + for key, _ in items: + key.data() + with self._lock: + self._handle_timer(expiration) + datagrams = self._connection.datagrams_to_send(time.time()) + for datagram, _ in datagrams: + try: + self._socket.send(datagram) + except BlockingIOError: + # we let QUIC handle any lossage + pass + self._handle_events() + finally: with self._lock: - self._handle_timer(expiration) - datagrams = self._connection.datagrams_to_send(time.time()) - for (datagram, _) in datagrams: - try: - self._socket.send(datagram) - except BlockingIOError: - # we let QUIC handle any lossage - pass - self._handle_events() + self._done = True + # Ensure anyone waiting for this gets woken up. + self._handshake_complete.set() def _handle_events(self): while True: @@ -166,6 +174,8 @@ class SyncQuicConnection(BaseQuicConnection): def make_stream(self): self._handshake_complete.wait() with self._lock: + if self._done: + raise UnexpectedEOF stream_id = self._connection.get_next_available_stream_id(False) stream = SyncQuicStream(self, stream_id) self._streams[stream_id] = stream diff --git a/dns/quic/_trio.py b/dns/quic/_trio.py index 543e3cb..7f81061 100644 --- a/dns/quic/_trio.py +++ b/dns/quic/_trio.py @@ -17,6 +17,7 @@ from dns.quic._common import ( AsyncQuicConnection, AsyncQuicManager, QUIC_MAX_DATAGRAM, + UnexpectedEOF, ) @@ -80,20 +81,26 @@ class TrioQuicConnection(AsyncQuicConnection): self._worker_scope = None async def _worker(self): - await self._socket.connect(self._peer) - while not self._done: - (expiration, interval) = self._get_timer_values(False) - with trio.CancelScope( - deadline=trio.current_time() + interval - ) as self._worker_scope: - datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) - self._connection.receive_datagram(datagram, self._peer[0], time.time()) - self._worker_scope = None - self._handle_timer(expiration) - datagrams = self._connection.datagrams_to_send(time.time()) - for (datagram, _) in datagrams: - await self._socket.send(datagram) - await self._handle_events() + try: + await self._socket.connect(self._peer) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + with trio.CancelScope( + deadline=trio.current_time() + interval + ) as self._worker_scope: + datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) + self._connection.receive_datagram( + datagram, self._peer[0], time.time() + ) + self._worker_scope = None + self._handle_timer(expiration) + datagrams = self._connection.datagrams_to_send(time.time()) + for datagram, _ in datagrams: + await self._socket.send(datagram) + await self._handle_events() + finally: + self._done = True + self._handshake_complete.set() async def _handle_events(self): count = 0 @@ -132,6 +139,8 @@ class TrioQuicConnection(AsyncQuicConnection): async def make_stream(self): await self._handshake_complete.wait() + if self._done: + raise UnexpectedEOF stream_id = self._connection.get_next_available_stream_id(False) stream = TrioQuicStream(self, stream_id) self._streams[stream_id] = stream |