diff options
Diffstat (limited to 'dns/quic/_asyncio.py')
-rw-r--r-- | dns/quic/_asyncio.py | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py index bcce048..80f244d 100644 --- a/dns/quic/_asyncio.py +++ b/dns/quic/_asyncio.py @@ -17,6 +17,7 @@ from dns.quic._common import ( AsyncQuicConnection, AsyncQuicManager, QUIC_MAX_DATAGRAM, + UnexpectedEOF, ) @@ -30,8 +31,8 @@ class AsyncioQuicStream(BaseQuicStream): await self._wake_up.wait() async def wait_for(self, amount, expiration): - timeout = self._timeout_from_expiration(expiration) while True: + timeout = self._timeout_from_expiration(expiration) if self._buffer.have(amount): return self._expecting = amount @@ -106,6 +107,11 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._wake_timer.notify_all() except Exception: pass + finally: + self._done = True + async with self._wake_timer: + self._wake_timer.notify_all() + self._handshake_complete.set() async def _wait_for_wake_timer(self): async with self._wake_timer: @@ -115,7 +121,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): await self._socket_created.wait() while not self._done: datagrams = self._connection.datagrams_to_send(time.time()) - for (datagram, address) in datagrams: + for datagram, address in datagrams: assert address == self._peer[0] await self._socket.sendto(datagram, self._peer, None) (expiration, interval) = self._get_timer_values() @@ -162,6 +168,8 @@ class AsyncioQuicConnection(AsyncQuicConnection): async def make_stream(self): await self._handshake_complete.wait() + if self._done: + raise UnexpectedEOF stream_id = self._connection.get_next_available_stream_id(False) stream = AsyncioQuicStream(self, stream_id) self._streams[stream_id] = stream |