Skip to content

Commit

Permalink
Fix all codes
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Nov 1, 2022
1 parent 71b0541 commit 1a1a0b8
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/protocols/test_websocket.py
Expand Up @@ -553,7 +553,7 @@ async def app(scope, receive, send):
@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls):
async def test_connection_lost_before_handshake_complete(ws_protocol_cls, http_protocol_cls):
send_accept_task = asyncio.Event()
disconnect_message = {}

Expand Down
12 changes: 9 additions & 3 deletions uvicorn/protocols/websockets/websockets_impl.py
Expand Up @@ -88,6 +88,7 @@ def __init__(
self.closed_event = asyncio.Event()
self.initial_response: Optional[HTTPResponse] = None
self.connect_sent = False
self.lost_connection_before_handshake = False
self.accepted_subprotocol: Optional[Subprotocol] = None
self.transfer_data_task: asyncio.Task = None # type: ignore[assignment]

Expand Down Expand Up @@ -134,6 +135,9 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

self.lost_connection_before_handshake = (
not self.handshake_completed_event.is_set()
)
self.handshake_completed_event.set()
super().connection_lost(exc)
if exc is None:
Expand Down Expand Up @@ -335,11 +339,13 @@ async def asgi_receive(

await self.handshake_completed_event.wait()

if self.closed_event.is_set():
# If client disconnected, use WebSocketServerProtocol.close_code property.
if self.lost_connection_before_handshake:
# If the handshake failed or the app closed before handshake completion,
# use 1006 Abnormal Closure.
return {"type": "websocket.disconnect", "code": self.close_code or 1006}
return {"type": "websocket.disconnect", "code": 1006}

if self.closed_event.is_set():
return {"type": "websocket.disconnect", "code": 1005}

try:
data = await self.recv()
Expand Down
6 changes: 4 additions & 2 deletions uvicorn/protocols/websockets/wsproto_impl.py
Expand Up @@ -70,13 +70,15 @@ def connection_made(self, transport):
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

def connection_lost(self, exc):
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1005})
code = 1005 if self.handshake_complete else 1006
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
self.connections.remove(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

self.handshake_complete = True
if exc is None:
self.transport.close()

Expand Down Expand Up @@ -250,13 +252,13 @@ async def send(self, message):
self.scope["client"],
get_path_with_query_string(self.scope),
)
self.handshake_complete = True
subprotocol = message.get("subprotocol")
extra_headers = self.default_headers + list(message.get("headers", []))
extensions = []
if self.config.ws_per_message_deflate:
extensions.append(PerMessageDeflate())
if not self.transport.is_closing():
self.handshake_complete = True
output = self.conn.send(
wsproto.events.AcceptConnection(
subprotocol=subprotocol,
Expand Down

0 comments on commit 1a1a0b8

Please sign in to comment.