Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check if handshake is completed before sending frame on wsproto shutdown #1737

Merged
merged 6 commits into from Oct 31, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Expand Up @@ -82,7 +82,7 @@ plugins =

[coverage:report]
precision = 2
fail_under = 97.82
fail_under = 97.87
Kludex marked this conversation as resolved.
Show resolved Hide resolved
show_missing = true
skip_covered = true
exclude_lines =
Expand Down
61 changes: 60 additions & 1 deletion tests/protocols/test_websocket.py
Expand Up @@ -528,7 +528,6 @@ async def app(scope, receive, send):
while True:
message = await receive()
if message["type"] == "websocket.connect":
print("accepted")
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break
Expand All @@ -551,6 +550,66 @@ async def app(scope, receive, send):
assert got_disconnect_event_before_shutdown is True


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls):
send_accept_task = asyncio.Event()

async def app(scope, receive, send):
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send_accept_task.wait()
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break

async def websocket_session(uri):
async with websockets.client.connect(uri):
while True:
await asyncio.sleep(0.1)

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off")
async with run_server(config):
task = asyncio.create_task(websocket_session("ws://127.0.0.1:8000"))
await asyncio.sleep(0.1)
task.cancel()
send_accept_task.set()
Comment on lines +553 to +578
Copy link
Sponsor Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test actually fails on the websockets implementation.

This fails before and after my changes, so it is actually not this PR who broke it.

There's a single small change on the websockets implementation on this PR, which actually handles the next test.



@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_send_close_on_server_shutdown(ws_protocol_cls, http_protocol_cls):
disconnect_message = {}

async def app(scope, receive, send):
nonlocal disconnect_message
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
disconnect_message = message
break

async def websocket_session(uri):
async with websockets.client.connect(uri):
while True:
await asyncio.sleep(0.1)

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off")
async with run_server(config):
task = asyncio.create_task(websocket_session("ws://127.0.0.1:8000"))
await asyncio.sleep(0.1)
disconnect_message_before_shutdown = disconnect_message

assert disconnect_message_before_shutdown == {}
assert disconnect_message == {"type": "websocket.disconnect", "code": 1012}
task.cancel()


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
Expand Down
3 changes: 3 additions & 0 deletions tests/utils.py
@@ -1,5 +1,6 @@
import asyncio
import os
import traceback
Kludex marked this conversation as resolved.
Show resolved Hide resolved
from contextlib import asynccontextmanager, contextmanager
from pathlib import Path

Expand All @@ -13,6 +14,8 @@ async def run_server(config: Config, sockets=None):
await asyncio.sleep(0.1)
try:
yield server
except BaseException: # pragma: no cover
traceback.print_exc()
Kludex marked this conversation as resolved.
Show resolved Hide resolved
finally:
await server.shutdown()
task.cancel()
Expand Down
2 changes: 2 additions & 0 deletions uvicorn/protocols/websockets/websockets_impl.py
Expand Up @@ -345,6 +345,8 @@ async def asgi_receive(
data = await self.recv()
except ConnectionClosed as exc:
self.closed_event.set()
if self.ws_server.closing:
return {"type": "websocket.disconnect", "code": 1012}
Comment on lines +348 to +349
Copy link
Sponsor Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to match the behavior on both implementations.

If we close the connection on shutdown, we should be sending 1012.

return {"type": "websocket.disconnect", "code": exc.code}

msg: WebSocketReceiveEvent = { # type: ignore[typeddict-item]
Expand Down
29 changes: 16 additions & 13 deletions uvicorn/protocols/websockets/wsproto_impl.py
Expand Up @@ -125,9 +125,12 @@ def resume_writing(self):
self.writable.set()

def shutdown(self):
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
output = self.conn.send(wsproto.events.CloseConnection(code=1012))
self.transport.write(output)
if self.handshake_complete:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
output = self.conn.send(wsproto.events.CloseConnection(code=1012))
self.transport.write(output)
Comment on lines +128 to +131
Copy link
Sponsor Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the core here.

We want to check if the handshake was completed, because CloseConnection is not a valid event for wsproto.WSConnection.send() when the handshake is not completed.

This is the error these lines solve:

wsproto.utilities.LocalProtocolError: Event CloseConnection(code=1012, reason=None) cannot be sent during the handshake

More about it on #596.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me. 馃憤

else:
self.send_500_response()
Comment on lines +132 to +133
Copy link
Sponsor Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing is... If the above is False, we still need to send something to the client, otherwise we are being rude. 馃槥

Without these lines, the client will have a:

curl: (52) Empty reply from server

With these lines:

HTTP/1.1 500 
content-length: 0

The curl command used was:

curl --include \
     --no-buffer \
     --header "Connection: Upgrade" \
     --header "Upgrade: websocket" \
     --header "Host: example.com:80" \
     --header "Origin: http://example.com:80" \
     --header "Sec-WebSocket-Key: SGVsbG8sIHdvcmxkIQ==" \
     --header "Sec-WebSocket-Version: 13" \
     http://localhost:8000/

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much neater. 馃槍

I did spend some time trying to figure out if we should we also include a textual description here, but I think it's probably okay as it currently stands.

self.transport.close()

def on_task_complete(self, task):
Expand Down Expand Up @@ -222,9 +225,8 @@ def send_500_response(self):
async def run_asgi(self):
try:
result = await self.app(self.scope, self.receive, self.send)
except BaseException as exc:
msg = "Exception in ASGI application\n"
self.logger.error(msg, exc_info=exc)
except BaseException:
self.logger.exception("Exception in ASGI application\n")
Comment on lines +228 to +229
Copy link
Sponsor Member Author

@Kludex Kludex Oct 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are analogous. Just making it more readable.

if not self.handshake_complete:
self.send_500_response()
self.transport.close()
Expand Down Expand Up @@ -257,14 +259,15 @@ async def send(self, message):
extensions = []
if self.config.ws_per_message_deflate:
extensions.append(PerMessageDeflate())
output = self.conn.send(
wsproto.events.AcceptConnection(
subprotocol=subprotocol,
extensions=extensions,
extra_headers=extra_headers,
if not self.transport.is_closing():
output = self.conn.send(
wsproto.events.AcceptConnection(
subprotocol=subprotocol,
extensions=extensions,
extra_headers=extra_headers,
)
)
)
self.transport.write(output)
self.transport.write(output)
Comment on lines -260 to +270
Copy link
Sponsor Member Author

@Kludex Kludex Oct 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is necessary on this scenario, because the transport was closed on the shutdown(), it can also be that the client disconnected before accepting the connection...


elif message_type == "websocket.close":
self.queue.put_nowait({"type": "websocket.disconnect", "code": None})
Expand Down