Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send disconnect event on connection lost for wsproto #996

Merged
merged 5 commits into from Oct 29, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Expand Up @@ -82,7 +82,7 @@ plugins =

[coverage:report]
precision = 2
fail_under = 97.79
fail_under = 97.85
Kludex marked this conversation as resolved.
Show resolved Hide resolved
show_missing = true
skip_covered = true
exclude_lines =
Expand Down
36 changes: 35 additions & 1 deletion tests/protocols/test_websocket.py
Expand Up @@ -10,6 +10,7 @@

try:
import websockets
import websockets.client
import websockets.exceptions
from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory

Expand Down Expand Up @@ -64,7 +65,6 @@ def app(scope):
"connection": "upgrade",
"sec-webSocket-version": "11",
},
timeout=5,
)
if response.status_code == 426:
# response.text == ""
Expand Down Expand Up @@ -517,6 +517,40 @@ async def websocket_session(url):
await websocket_session("ws://127.0.0.1:8000")


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_client_connection_lost(ws_protocol_cls, http_protocol_cls):
got_disconnect_event = False

async def app(scope, receive, send):
nonlocal got_disconnect_event
while True:
message = await receive()
if message["type"] == "websocket.connect":
print("accepted")
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break

got_disconnect_event = True

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
ws_ping_interval=0.0,
)
async with run_server(config):
async with websockets.client.connect("ws://127.0.0.1:8000") as websocket:
websocket.transport.close()
await asyncio.sleep(0.1)
got_disconnect_event_before_shutdown = got_disconnect_event

assert got_disconnect_event_before_shutdown is True


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
Expand Down
3 changes: 1 addition & 2 deletions uvicorn/protocols/websockets/wsproto_impl.py
Expand Up @@ -70,8 +70,7 @@ def connection_made(self, transport):
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

def connection_lost(self, exc):
if exc is not None:
self.queue.put_nowait({"type": "websocket.disconnect"})
self.queue.put_nowait({"type": "websocket.disconnect"})
self.connections.remove(self)

if self.logger.level <= TRACE_LOG_LEVEL:
Expand Down