diff --git a/docs/settings.md b/docs/settings.md index 14b2882792..9d9ff3698b 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -90,7 +90,7 @@ connecting IPs in the `forwarded-allow-ips` configuration. * `--date-header` / `--no-date-header` - Enable/Disable default `Date` header. !!! note - The `--no-server-header` flag doesn't have effect on the WebSockets implementations. + The `--no-date-header` flag doesn't have effect on the `websockets` implementation. ## HTTPS diff --git a/pyproject.toml b/pyproject.toml index f8791dd97e..13a02b7b6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ standard = [ "PyYAML>=5.1", "uvloop>=0.14.0,!=0.15.0,!=0.15.1; sys_platform != 'win32' and (sys_platform != 'cygwin' and platform_python_implementation != 'PyPy')", "watchfiles>=0.13", - "websockets>=10.0", + "websockets>=10.4", ] [project.scripts] diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 9048755566..7b0619cae2 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -10,6 +10,7 @@ try: import websockets + import websockets.exceptions from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol @@ -18,8 +19,8 @@ WebSocketProtocol = None ClientPerMessageDeflateFactory = None - ONLY_WEBSOCKETPROTOCOL = [p for p in [WebSocketProtocol] if p is not None] +ONLY_WS_PROTOCOL = [p for p in [WSProtocol] if p is not None] WS_PROTOCOLS = [p for p in [WSProtocol, WebSocketProtocol] if p is not None] pytestmark = pytest.mark.skipif( websockets is None, reason="This test needs the websockets module" @@ -658,3 +659,100 @@ async def send_text(url): await send_text("ws://127.0.0.1:8000") assert frames == [b"abc", b"abc", b"abc"] + + +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_default_server_headers(ws_protocol_cls, http_protocol_cls): + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send({"type": "websocket.accept"}) + + async def open_connection(url): + async with websockets.connect(url) as websocket: + return websocket.response_headers + + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") + async with run_server(config): + headers = await open_connection("ws://127.0.0.1:8000") + assert headers.get("server") == "uvicorn" and "date" in headers + + +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_no_server_headers(ws_protocol_cls, http_protocol_cls): + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send({"type": "websocket.accept"}) + + async def open_connection(url): + async with websockets.connect(url) as websocket: + return websocket.response_headers + + config = Config( + app=App, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + server_header=False, + ) + async with run_server(config): + headers = await open_connection("ws://127.0.0.1:8000") + assert "server" not in headers + + +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", ONLY_WS_PROTOCOL) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_no_date_header(ws_protocol_cls, http_protocol_cls): + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send({"type": "websocket.accept"}) + + async def open_connection(url): + async with websockets.connect(url) as websocket: + return websocket.response_headers + + config = Config( + app=App, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + date_header=False, + ) + async with run_server(config): + headers = await open_connection("ws://127.0.0.1:8000") + assert "date" not in headers + + +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_multiple_server_header(ws_protocol_cls, http_protocol_cls): + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send( + { + "type": "websocket.accept", + "headers": [ + (b"Server", b"over-ridden"), + (b"Server", b"another-value"), + ], + } + ) + + async def open_connection(url): + async with websockets.connect(url) as websocket: + return websocket.response_headers + + config = Config( + app=App, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + ) + async with run_server(config): + headers = await open_connection("ws://127.0.0.1:8000") + assert headers.get_all("Server") == ["uvicorn", "over-ridden", "another-value"] diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index b77485136b..01133b7e29 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -23,9 +23,9 @@ ) from uvicorn.server import ServerState -if sys.version_info < (3, 8): +if sys.version_info < (3, 8): # pragma: py-gte-38 from typing_extensions import Literal -else: +else: # pragma: py-lt-38 from typing import Literal if TYPE_CHECKING: @@ -103,9 +103,13 @@ def __init__( max_size=self.config.ws_max_size, ping_interval=self.config.ws_ping_interval, ping_timeout=self.config.ws_ping_timeout, + server_header=None, extensions=extensions, logger=logging.getLogger("uvicorn.error"), - extra_headers=[], + extra_headers=[ + (name.decode("latin-1"), value.decode("latin-1")) + for name, value in server_state.default_headers + ], ) def connection_made( # type: ignore[override] diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 8170b15164..55ba386dfb 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -32,6 +32,7 @@ def __init__(self, config, server_state, _loop=None): # Shared server state self.connections = server_state.connections self.tasks = server_state.tasks + self.default_headers = server_state.default_headers # Connection state self.transport = None @@ -255,7 +256,7 @@ async def send(self, message): ) self.handshake_complete = True subprotocol = message.get("subprotocol") - extra_headers = message.get("headers", []) + extra_headers = self.default_headers + list(message.get("headers", [])) extensions = [] if self.config.ws_per_message_deflate: extensions.append(PerMessageDeflate())