Skip to content

Commit

Permalink
Use correct WebSocket error codes
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Nov 1, 2022
1 parent ec3aac3 commit 71b0541
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
20 changes: 12 additions & 8 deletions tests/protocols/test_websocket.py
Expand Up @@ -555,15 +555,15 @@ async def app(scope, receive, send):
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls):
send_accept_task = asyncio.Event()
disconnect_message = {}

async def app(scope, receive, send):
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send_accept_task.wait()
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break
nonlocal disconnect_message
message = await receive()
if message["type"] == "websocket.connect":
await send_accept_task.wait()
await send({"type": "websocket.accept"})
disconnect_message = await receive()

async def websocket_session(uri):
async with websockets.client.connect(uri):
Expand All @@ -577,6 +577,8 @@ async def websocket_session(uri):
task.cancel()
send_accept_task.set()

assert disconnect_message == {"type": "websocket.disconnect", "code": 1006}


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
Expand Down Expand Up @@ -729,14 +731,15 @@ async def test_server_can_read_messages_in_buffer_after_close(
ws_protocol_cls, http_protocol_cls
):
frames = []
client_close_connection = asyncio.Event()

class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})
# Ensure server doesn't start reading frames from read buffer until
# after client has sent close frame, but server is still able to
# read these frames
await asyncio.sleep(0.2)
await client_close_connection.wait()

async def websocket_receive(self, message):
frames.append(message.get("bytes"))
Expand All @@ -750,6 +753,7 @@ async def send_text(url):
config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off")
async with run_server(config):
await send_text("ws://127.0.0.1:8000")
client_close_connection.set()

assert frames == [b"abc", b"abc", b"abc"]

Expand Down
4 changes: 2 additions & 2 deletions uvicorn/protocols/websockets/wsproto_impl.py
Expand Up @@ -70,7 +70,7 @@ def connection_made(self, transport):
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

def connection_lost(self, exc):
self.queue.put_nowait({"type": "websocket.disconnect"})
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1005})
self.connections.remove(self)

if self.logger.level <= TRACE_LOG_LEVEL:
Expand Down Expand Up @@ -267,7 +267,7 @@ async def send(self, message):
self.transport.write(output)

elif message_type == "websocket.close":
self.queue.put_nowait({"type": "websocket.disconnect", "code": None})
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
Expand Down

0 comments on commit 71b0541

Please sign in to comment.