Skip to content

Commit

Permalink
Only close socket in the main thread
Browse files Browse the repository at this point in the history
This solves a race condition that may exist when attempting to loop over
the open sockets and then calling select() and accidentally have called
close() on the socket in an app thread.
  • Loading branch information
digitalresistor committed May 25, 2022
1 parent 7c3739b commit c7a3d7e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
18 changes: 10 additions & 8 deletions src/waitress/channel.py
Expand Up @@ -126,10 +126,10 @@ def handle_write(self):
if self.will_close:
self.handle_close()

def _flush_exception(self, flush):
def _flush_exception(self, flush, do_close=True):
if flush:
try:
return (flush(), False)
return (flush(do_close=do_close), False)
except OSError:
if self.adj.log_socket_errors:
self.logger.exception("Socket error")
Expand Down Expand Up @@ -240,20 +240,20 @@ def received(self, data):

return True

def _flush_some_if_lockable(self):
def _flush_some_if_lockable(self, do_close=True):
# Since our task may be appending to the outbuf, we try to acquire
# the lock, but we don't block if we can't.

if self.outbuf_lock.acquire(False):
try:
self._flush_some()
self._flush_some(do_close=do_close)

if self.total_outbufs_len < self.adj.outbuf_high_watermark:
self.outbuf_lock.notify()
finally:
self.outbuf_lock.release()

def _flush_some(self):
def _flush_some(self, do_close=True):
# Send as much data as possible to our client

sent = 0
Expand All @@ -267,7 +267,7 @@ def _flush_some(self):

while outbuflen > 0:
chunk = outbuf.get(self.sendbuf_len)
num_sent = self.send(chunk)
num_sent = self.send(chunk, do_close=do_close)

if num_sent:
outbuf.skip(num_sent, True)
Expand Down Expand Up @@ -374,7 +374,9 @@ def write_soon(self, data):
self.total_outbufs_len += num_bytes

if self.total_outbufs_len >= self.adj.send_bytes:
(flushed, exception) = self._flush_exception(self._flush_some)
(flushed, exception) = self._flush_exception(
self._flush_some, do_close=False
)

if (
exception
Expand All @@ -392,7 +394,7 @@ def _flush_outbufs_below_high_watermark(self):

if self.total_outbufs_len > self.adj.outbuf_high_watermark:
with self.outbuf_lock:
(_, exception) = self._flush_exception(self._flush_some)
(_, exception) = self._flush_exception(self._flush_some, do_close=False)

if exception:
# An exception happened while flushing, wake up the main
Expand Down
5 changes: 3 additions & 2 deletions src/waitress/wasyncore.py
Expand Up @@ -426,15 +426,16 @@ def accept(self):
else:
return conn, addr

def send(self, data):
def send(self, data, do_close=True):
try:
result = self.socket.send(data)
return result
except OSError as why:
if why.args[0] == EWOULDBLOCK:
return 0
elif why.args[0] in _DISCONNECTED:
self.handle_close()
if do_close:
self.handle_close()
return 0
else:
raise
Expand Down
4 changes: 2 additions & 2 deletions tests/test_channel.py
Expand Up @@ -376,7 +376,7 @@ def test_handle_write_no_notify_after_flush(self):
inst.total_outbufs_len = len(inst.outbufs[0])
inst.adj.send_bytes = 1
inst.adj.outbuf_high_watermark = 2
sock.send = lambda x: False
sock.send = lambda x, do_close=True: False
inst.will_close = False
inst.last_activity = 0
result = inst.handle_write()
Expand Down Expand Up @@ -453,7 +453,7 @@ def get(self, numbytes):

buf = DummyHugeOutbuffer()
inst.outbufs = [buf]
inst.send = lambda *arg: 0
inst.send = lambda *arg, do_close: 0
result = inst._flush_some()
# we are testing that _flush_some doesn't raise an OverflowError
# when one of its outbufs has a __len__ that returns gt sys.maxint
Expand Down

0 comments on commit c7a3d7e

Please sign in to comment.