Skip to content

Commit

Permalink
Fix more races around long waits.
Browse files Browse the repository at this point in the history
  • Loading branch information
rthalley committed Mar 23, 2024
1 parent 3238267 commit 0aa713d
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions dns/quic/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def __init__(self, connection, address, port, source, source_port, manager=None)
self._wake_timer = asyncio.Condition()
self._receiver_task = None
self._sender_task = None
self._send_pending = False
self._check_for_events = False

async def _receiver(self):
try:
Expand All @@ -117,7 +119,10 @@ async def _receiver(self):
continue
self._connection.receive_datagram(datagram, address, time.time())
# Wake up the timer in case the sender is sleeping, as there may be
# stuff to send now.
# stuff to send now. We need to set a flag as well as wake up the
# timer to avoid a race where we get a datagram and generate an
# event right before the sender is going to sleep.
self._check_for_events = True
async with self._wake_timer:
self._wake_timer.notify_all()
except Exception:
Expand All @@ -135,16 +140,19 @@ async def _wait_for_wake_timer(self):
async def _sender(self):
await self._socket_created.wait()
while not self._done:
self._send_pending = False
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, address in datagrams:
assert address == self._peer
await self._socket.sendto(datagram, self._peer, None)
(expiration, interval) = self._get_timer_values()
try:
await asyncio.wait_for(self._wait_for_wake_timer(), interval)
except Exception:
pass
if not (self._check_for_events or self._send_pending):
try:
await asyncio.wait_for(self._wait_for_wake_timer(), interval)
except Exception:
pass
self._handle_timer(expiration)
self._check_for_events = False
await self._handle_events()

async def _handle_events(self):
Expand Down Expand Up @@ -194,6 +202,7 @@ async def _handle_events(self):

async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
self._send_pending = True
async with self._wake_timer:
self._wake_timer.notify_all()

Expand Down

0 comments on commit 0aa713d

Please sign in to comment.