diff --git a/docs/websockets.md b/docs/websockets.md index 807496188..43406aced 100644 --- a/docs/websockets.md +++ b/docs/websockets.md @@ -51,7 +51,7 @@ For example: `websocket.path_params['username']` ### Accepting the connection -* `await websocket.accept(subprotocol=None)` +* `await websocket.accept(subprotocol=None, headers=None)` ### Sending data diff --git a/starlette/testclient.py b/starlette/testclient.py index 40220fb4d..0b4bc78d1 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -298,6 +298,7 @@ def __init__( self.app = app self.scope = scope self.accepted_subprotocol = None + self.extra_headers = None self.portal_factory = portal_factory self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() @@ -315,6 +316,7 @@ def __enter__(self) -> "WebSocketTestSession": self.exit_stack.close() raise self.accepted_subprotocol = message.get("subprotocol", None) + self.extra_headers = message.get("headers", None) return self def __exit__(self, *args: typing.Any) -> None: diff --git a/starlette/websockets.py b/starlette/websockets.py index b9b8844d6..7632b28cf 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -69,11 +69,17 @@ async def send(self, message: Message) -> None: else: raise RuntimeError('Cannot call "send" once a close message has been sent.') - async def accept(self, subprotocol: str = None) -> None: + async def accept( + self, + subprotocol: str = None, + headers: typing.Iterable[typing.Tuple[bytes, bytes]] = None, + ) -> None: if self.client_state == WebSocketState.CONNECTING: # If we haven't yet seen the 'connect' message, then wait for it first. await self.receive() - await self.send({"type": "websocket.accept", "subprotocol": subprotocol}) + await self.send( + {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers} + ) def _raise_on_disconnect(self, message: Message) -> None: if message["type"] == "websocket.disconnect": diff --git a/tests/test_websockets.py b/tests/test_websockets.py index e02d433d5..bf0253309 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -301,6 +301,20 @@ async def asgi(receive, send): assert websocket.accepted_subprotocol == "wamp" +def test_additional_headers(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept(headers=[(b"additional", b"header")]) + await websocket.close() + + return asgi + + client = test_client_factory(app) + with client.websocket_connect("/") as websocket: + assert websocket.extra_headers == [(b"additional", b"header")] + + def test_websocket_exception(test_client_factory): def app(scope): async def asgi(receive, send):