Skip to content

Commit

Permalink
Improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Dec 26, 2022
1 parent 782f396 commit ed7d669
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 9 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ plugins =

[coverage:report]
precision = 2
fail_under = 98.50
fail_under = 98.70
show_missing = true
skip_covered = true
exclude_lines =
Expand Down
36 changes: 30 additions & 6 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import typing

import httpx
import pytest
Expand Down Expand Up @@ -713,11 +714,23 @@ async def app(scope, receive, send):
message = await receive()
if message["type"] == "websocket.connect":
await send_accept_task.wait()
await send({"type": "websocket.accept"})
disconnect_message = await receive()

response: typing.Optional[httpx.Response] = None

async def websocket_session(uri):
await websockets.client.connect(uri)
# await websockets.client.connect(uri)
nonlocal response
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{unused_tcp_port}",
headers={
"upgrade": "websocket",
"connection": "upgrade",
"sec-websocket-version": "13",
"sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==",
},
)

config = Config(
app=app,
Expand All @@ -731,9 +744,13 @@ async def websocket_session(uri):
websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
)
await asyncio.sleep(0.1)
task.cancel()
send_accept_task.set()

task.cancel()

assert response is not None
assert response.status_code == 500, response.text
assert response.text == "Internal Server Error"
assert disconnect_message == {"type": "websocket.disconnect", "code": 1006}


Expand All @@ -744,6 +761,7 @@ async def test_send_close_on_server_shutdown(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
):
disconnect_message = {}
server_shutdown_event = asyncio.Event()

async def app(scope, receive, send):
nonlocal disconnect_message
Expand All @@ -755,10 +773,13 @@ async def app(scope, receive, send):
disconnect_message = message
break

websocket: typing.Optional[websockets.client.WebSocketClientProtocol] = None

async def websocket_session(uri):
async with websockets.client.connect(uri):
while True:
await asyncio.sleep(0.1)
nonlocal websocket
async with websockets.client.connect(uri) as ws_connection:
websocket = ws_connection
await server_shutdown_event.wait()

config = Config(
app=app,
Expand All @@ -773,7 +794,10 @@ async def websocket_session(uri):
)
await asyncio.sleep(0.1)
disconnect_message_before_shutdown = disconnect_message
server_shutdown_event.set()

assert websocket is not None
assert websocket.close_code == 1012
assert disconnect_message_before_shutdown == {}
assert disconnect_message == {"type": "websocket.disconnect", "code": 1012}
task.cancel()
Expand Down
5 changes: 4 additions & 1 deletion uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,11 @@ def connection_lost(self, exc: Optional[Exception]) -> None:

def shutdown(self) -> None:
self.ws_server.closing = True
if not self.transport.is_closing():
if self.handshake_completed_event.is_set():
self.fail_connection(1012)
else:
self.send_500_response()
self.transport.close()

def on_task_complete(self, task: asyncio.Task) -> None:
self.tasks.discard(task)
Expand Down
7 changes: 6 additions & 1 deletion uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,12 @@ def send_500_response(self) -> None:
(b"connection", b"close"),
]
if self.conn.connection is None:
output = self.conn.send(wsproto.events.RejectConnection(status_code=500))
output = self.conn.send(
wsproto.events.RejectConnection(status_code=500, has_body=True)
)
output += self.conn.send(
wsproto.events.RejectData(data=b"Internal Server Error")
)
else:
msg = h11.Response(
status_code=500, headers=headers, reason="Internal Server Error"
Expand Down

0 comments on commit ed7d669

Please sign in to comment.