summaryrefslogtreecommitdiff
path: root/dns/quic/_asyncio.py
diff options
context:
space:
mode:
Diffstat (limited to 'dns/quic/_asyncio.py')
-rw-r--r--dns/quic/_asyncio.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py
index bcce048..80f244d 100644
--- a/dns/quic/_asyncio.py
+++ b/dns/quic/_asyncio.py
@@ -17,6 +17,7 @@ from dns.quic._common import (
AsyncQuicConnection,
AsyncQuicManager,
QUIC_MAX_DATAGRAM,
+ UnexpectedEOF,
)
@@ -30,8 +31,8 @@ class AsyncioQuicStream(BaseQuicStream):
await self._wake_up.wait()
async def wait_for(self, amount, expiration):
- timeout = self._timeout_from_expiration(expiration)
while True:
+ timeout = self._timeout_from_expiration(expiration)
if self._buffer.have(amount):
return
self._expecting = amount
@@ -106,6 +107,11 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._wake_timer.notify_all()
except Exception:
pass
+ finally:
+ self._done = True
+ async with self._wake_timer:
+ self._wake_timer.notify_all()
+ self._handshake_complete.set()
async def _wait_for_wake_timer(self):
async with self._wake_timer:
@@ -115,7 +121,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
await self._socket_created.wait()
while not self._done:
datagrams = self._connection.datagrams_to_send(time.time())
- for (datagram, address) in datagrams:
+ for datagram, address in datagrams:
assert address == self._peer[0]
await self._socket.sendto(datagram, self._peer, None)
(expiration, interval) = self._get_timer_values()
@@ -162,6 +168,8 @@ class AsyncioQuicConnection(AsyncQuicConnection):
async def make_stream(self):
await self._handshake_complete.wait()
+ if self._done:
+ raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = AsyncioQuicStream(self, stream_id)
self._streams[stream_id] = stream