Skip to content

Commit

Permalink
Add test for connection lost before handshake is completed
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Oct 29, 2022
1 parent 0cc0e01 commit 608e38d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
29 changes: 28 additions & 1 deletion tests/protocols/test_websocket.py
Expand Up @@ -528,7 +528,6 @@ async def app(scope, receive, send):
while True:
message = await receive()
if message["type"] == "websocket.connect":
print("accepted")
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break
Expand All @@ -551,6 +550,34 @@ async def app(scope, receive, send):
assert got_disconnect_event_before_shutdown is True


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls):
send_accept_task = asyncio.Event()

async def app(scope, receive, send):
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send_accept_task.wait()
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break

async def websocket_session(uri):
async with websockets.client.connect(uri):
while True:
await asyncio.sleep(0.1)

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off")
async with run_server(config):
task = asyncio.create_task(websocket_session("ws://127.0.0.1:8000"))
await asyncio.sleep(0.1)
task.cancel()
send_accept_task.set()


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
Expand Down
3 changes: 3 additions & 0 deletions tests/utils.py
@@ -1,5 +1,6 @@
import asyncio
import os
import traceback
from contextlib import asynccontextmanager, contextmanager
from pathlib import Path

Expand All @@ -13,6 +14,8 @@ async def run_server(config: Config, sockets=None):
await asyncio.sleep(0.1)
try:
yield server
except BaseException:
traceback.print_exc()
finally:
await server.shutdown()
task.cancel()
Expand Down

0 comments on commit 608e38d

Please sign in to comment.