Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

custom dispatcher pings/timeouts #795

Merged
merged 11 commits into from
Feb 25, 2022
21 changes: 18 additions & 3 deletions websocket/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,20 @@ def select(self):
return r[0][0]


class WrappedDispatcher:
"""
WrappedDispatcher
"""
def __init__(self, app, ping_timeout, dispatcher):
self.app = app
self.ping_timeout = ping_timeout
self.dispatcher = dispatcher

def read(self, sock, read_callback, check_callback):
self.dispatcher.read(sock, read_callback)
self.ping_timeout and self.dispatcher.timeout(self.ping_timeout, check_callback)


class WebSocketApp:
"""
Higher level of APIs are provided. The interface is like JavaScript WebSocket object.
Expand Down Expand Up @@ -316,8 +330,7 @@ def teardown(close_frame=None):
http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols,
host=host, origin=origin, suppress_origin=suppress_origin,
proxy_type=proxy_type)
if not dispatcher:
dispatcher = self.create_dispatcher(ping_timeout)
dispatcher = self.create_dispatcher(ping_timeout, dispatcher)

self._callback(self.on_open)

Expand Down Expand Up @@ -375,7 +388,9 @@ def check():
teardown()
return not isinstance(e, KeyboardInterrupt)

def create_dispatcher(self, ping_timeout):
def create_dispatcher(self, ping_timeout, dispatcher=None):
if dispatcher:
return WrappedDispatcher(self, ping_timeout, dispatcher)
timeout = ping_timeout or 10
if self.sock.is_ssl():
return SSLDispatcher(self, timeout)
Expand Down
6 changes: 3 additions & 3 deletions websocket/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def _recv():
pass
except socket.error as exc:
error_code = extract_error_code(exc)
if error_code is None:
raise
if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK:
if error_code != errno.EAGAIN and error_code != errno.EWOULDBLOCK:
raise

sel = selectors.DefaultSelector()
Expand All @@ -111,6 +109,8 @@ def _recv():
bytes_ = sock.recv(bufsize)
else:
bytes_ = _recv()
except TimeoutError:
raise WebSocketTimeoutException("Connection timed out")
except socket.timeout as e:
message = extract_err_message(e)
raise WebSocketTimeoutException(message)
Expand Down