From 128e35f45dc9baff9096d6a3b05645e66d7985eb Mon Sep 17 00:00:00 2001 From: matiuszka Date: Wed, 15 Dec 2021 14:01:17 +0100 Subject: [PATCH 1/4] Additional headers for WS accept message. --- docs/websockets.md | 3 +-- starlette/testclient.py | 2 ++ starlette/websockets.py | 10 ++++++++-- tests/test_websockets.py | 14 ++++++++++++++ 4 files changed, 25 insertions(+), 4 deletions(-) diff --git a/docs/websockets.md b/docs/websockets.md index 807496188..5eedca3c4 100644 --- a/docs/websockets.md +++ b/docs/websockets.md @@ -1,4 +1,3 @@ - Starlette includes a `WebSocket` class that fulfils a similar role to the HTTP request, but that allows sending and receiving data on a websocket. @@ -51,7 +50,7 @@ For example: `websocket.path_params['username']` ### Accepting the connection -* `await websocket.accept(subprotocol=None)` +* `await websocket.accept(subprotocol=None, headers=None)` ### Sending data diff --git a/starlette/testclient.py b/starlette/testclient.py index 08d03fa5c..9137dbbf4 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -296,6 +296,7 @@ def __init__( self.app = app self.scope = scope self.accepted_subprotocol = None + self.additional_headers = None self.portal_factory = portal_factory self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() @@ -313,6 +314,7 @@ def __enter__(self) -> "WebSocketTestSession": self.exit_stack.close() raise self.accepted_subprotocol = message.get("subprotocol", None) + self.additional_headers = message.get("headers", None) return self def __exit__(self, *args: typing.Any) -> None: diff --git a/starlette/websockets.py b/starlette/websockets.py index b9b8844d6..7632b28cf 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -69,11 +69,17 @@ async def send(self, message: Message) -> None: else: raise RuntimeError('Cannot call "send" once a close message has been sent.') - async def accept(self, subprotocol: str = None) -> None: + async def accept( + self, + subprotocol: str = None, + headers: typing.Iterable[typing.Tuple[bytes, bytes]] = None, + ) -> None: if self.client_state == WebSocketState.CONNECTING: # If we haven't yet seen the 'connect' message, then wait for it first. await self.receive() - await self.send({"type": "websocket.accept", "subprotocol": subprotocol}) + await self.send( + {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers} + ) def _raise_on_disconnect(self, message: Message) -> None: if message["type"] == "websocket.disconnect": diff --git a/tests/test_websockets.py b/tests/test_websockets.py index e02d433d5..cf44e565c 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -301,6 +301,20 @@ async def asgi(receive, send): assert websocket.accepted_subprotocol == "wamp" +def test_additional_headers(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept(headers=[(b"additional", b"header")]) + await websocket.close() + + return asgi + + client = test_client_factory(app) + with client.websocket_connect("/") as websocket: + websocket.additional_headers = [(b"additional", b"header")] + + def test_websocket_exception(test_client_factory): def app(scope): async def asgi(receive, send): From fa5bbfa128b4b2a0f7b57111239c7f35a2aa2b47 Mon Sep 17 00:00:00 2001 From: matiuszka <40184215+matiuszka@users.noreply.github.com> Date: Wed, 22 Dec 2021 14:24:42 +0100 Subject: [PATCH 2/4] Update tests/test_websockets.py Co-authored-by: Marcelo Trylesinski --- tests/test_websockets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index cf44e565c..01e2d6423 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -312,7 +312,7 @@ async def asgi(receive, send): client = test_client_factory(app) with client.websocket_connect("/") as websocket: - websocket.additional_headers = [(b"additional", b"header")] + assert websocket.additional_headers == [(b"additional", b"header")] def test_websocket_exception(test_client_factory): From 9aa2866d92b09c831a7d46fb2255527b31ebe4e2 Mon Sep 17 00:00:00 2001 From: matiuszka Date: Tue, 4 Jan 2022 08:43:12 +0100 Subject: [PATCH 3/4] fixup! Additional headers for WS accept message. --- docs/websockets.md | 1 + starlette/testclient.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/websockets.md b/docs/websockets.md index 5eedca3c4..43406aced 100644 --- a/docs/websockets.md +++ b/docs/websockets.md @@ -1,3 +1,4 @@ + Starlette includes a `WebSocket` class that fulfils a similar role to the HTTP request, but that allows sending and receiving data on a websocket. diff --git a/starlette/testclient.py b/starlette/testclient.py index bb98990b3..0b4bc78d1 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -298,7 +298,7 @@ def __init__( self.app = app self.scope = scope self.accepted_subprotocol = None - self.additional_headers = None + self.extra_headers = None self.portal_factory = portal_factory self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() @@ -316,7 +316,7 @@ def __enter__(self) -> "WebSocketTestSession": self.exit_stack.close() raise self.accepted_subprotocol = message.get("subprotocol", None) - self.additional_headers = message.get("headers", None) + self.extra_headers = message.get("headers", None) return self def __exit__(self, *args: typing.Any) -> None: From 8cc23b8a5454d42ad79de6602c4403120148c17c Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 4 Jan 2022 10:37:31 +0100 Subject: [PATCH 4/4] Update tests/test_websockets.py --- tests/test_websockets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 01e2d6423..bf0253309 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -312,7 +312,7 @@ async def asgi(receive, send): client = test_client_factory(app) with client.websocket_connect("/") as websocket: - assert websocket.additional_headers == [(b"additional", b"header")] + assert websocket.extra_headers == [(b"additional", b"header")] def test_websocket_exception(test_client_factory):