From 6a7e2406631898e241d5f64d951c94c6b46d212f Mon Sep 17 00:00:00 2001 From: Argannor <4489279+Argannor@users.noreply.github.com> Date: Wed, 19 Oct 2022 08:17:48 +0200 Subject: [PATCH] Upgrade request handling: ignore http/2 and optionally ignore websocket (#1661) --- docs/settings.md | 2 +- tests/protocols/test_http.py | 65 +++++++++++++++++++++--- uvicorn/config.py | 3 -- uvicorn/protocols/http/h11_impl.py | 52 ++++++++++--------- uvicorn/protocols/http/httptools_impl.py | 60 ++++++++++++++-------- 5 files changed, 125 insertions(+), 57 deletions(-) diff --git a/docs/settings.md b/docs/settings.md index c40905f45a..14b2882792 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -67,7 +67,7 @@ Using Uvicorn with watchfiles will enable the following options (which are other * `--loop ` - Set the event loop implementation. The uvloop implementation provides greater performance, but is not compatible with Windows or PyPy. **Options:** *'auto', 'asyncio', 'uvloop'.* **Default:** *'auto'*. * `--http ` - Set the HTTP protocol implementation. The httptools implementation provides greater performance, but it not compatible with PyPy. **Options:** *'auto', 'h11', 'httptools'.* **Default:** *'auto'*. -* `--ws ` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. Use `'none'` to deny all websocket requests. **Options:** *'auto', 'none', 'websockets', 'wsproto'.* **Default:** *'auto'*. +* `--ws ` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'wsproto'.* **Default:** *'auto'*. * `--ws-max-size ` - Set the WebSockets max message size, in bytes. Please note that this can be used only with the default `websockets` protocol. * `--ws-ping-interval ` - Set the WebSockets ping interval, in seconds. Please note that this can be used only with the default `websockets` protocol. * `--ws-ping-timeout ` - Set the WebSockets ping timeout, in seconds. Please note that this can be used only with the default `websockets` protocol. diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index 839b09441e..def8a3a88a 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -7,7 +7,7 @@ from tests.response import Response from uvicorn import Server -from uvicorn.config import Config +from uvicorn.config import WS_PROTOCOLS, Config from uvicorn.main import ServerState from uvicorn.protocols.http.h11_impl import H11Protocol @@ -18,6 +18,7 @@ HTTP_PROTOCOLS = [p for p in [H11Protocol, HttpToolsProtocol] if p is not None] +WEBSOCKET_PROTOCOLS = WS_PROTOCOLS.keys() SIMPLE_GET_REQUEST = b"\r\n".join([b"GET / HTTP/1.1", b"Host: example.org", b"", b""]) @@ -76,6 +77,18 @@ ] ) +UPGRADE_HTTP2_REQUEST = b"\r\n".join( + [ + b"GET / HTTP/1.1", + b"Host: example.org", + b"Connection: upgrade", + b"Upgrade: h2c", + b"Sec-WebSocket-Version: 11", + b"", + b"", + ] +) + INVALID_REQUEST_TEMPLATE = b"\r\n".join( [ b"%s", @@ -697,23 +710,61 @@ async def test_100_continue_not_sent_when_body_not_consumed(protocol_cls): @pytest.mark.anyio @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_unsupported_upgrade_request(protocol_cls): +async def test_supported_upgrade_request(protocol_cls): + app = Response("Hello, world", media_type="text/plain") + + protocol = get_connected_protocol(app, protocol_cls, ws="wsproto") + protocol.data_received(UPGRADE_REQUEST) + assert b"HTTP/1.1 426 " in protocol.transport.buffer + + +@pytest.mark.anyio +@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) +async def test_unsupported_ws_upgrade_request(protocol_cls): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, protocol_cls, ws="none") protocol.data_received(UPGRADE_REQUEST) - assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer - assert b"Unsupported upgrade request." in protocol.transport.buffer + await protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b"Hello, world" in protocol.transport.buffer @pytest.mark.anyio @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_supported_upgrade_request(protocol_cls): +async def test_unsupported_ws_upgrade_request_warn_on_auto( + caplog: pytest.LogCaptureFixture, protocol_cls +): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls, ws="wsproto") + protocol = get_connected_protocol(app, protocol_cls, ws="auto") + protocol.ws_protocol_class = None protocol.data_received(UPGRADE_REQUEST) - assert b"HTTP/1.1 426 " in protocol.transport.buffer + await protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b"Hello, world" in protocol.transport.buffer + warnings = [ + record.msg + for record in filter( + lambda record: record.levelname == "WARNING", caplog.records + ) + ] + assert "Unsupported upgrade request." in warnings + msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501 + assert msg in warnings + + +@pytest.mark.anyio +@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) +@pytest.mark.parametrize("ws", WEBSOCKET_PROTOCOLS) +async def test_http2_upgrade_request(protocol_cls, ws): + app = Response("Hello, world", media_type="text/plain") + + protocol = get_connected_protocol(app, protocol_cls, ws=ws) + protocol.data_received(UPGRADE_HTTP2_REQUEST) + await protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b"Hello, world" in protocol.transport.buffer async def asgi3app(scope, receive, send): diff --git a/uvicorn/config.py b/uvicorn/config.py index 589d72d375..df91a6c4d1 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -87,10 +87,8 @@ } INTERFACES: List[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"] - SSL_PROTOCOL_VERSION: int = ssl.PROTOCOL_TLS_SERVER - LOGGING_CONFIG: Dict[str, Any] = { "version": 1, "disable_existing_loggers": False, @@ -159,7 +157,6 @@ def is_dir(path: Path) -> bool: def resolve_reload_patterns( patterns_list: List[str], directories_list: List[str] ) -> Tuple[List[str], List[Path]]: - directories: List[Path] = list(set(map(Path, directories_list.copy()))) patterns: List[str] = patterns_list.copy() diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index 5fff70bb7d..b7a1dca425 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -155,6 +155,28 @@ def _unset_keepalive_if_required(self) -> None: self.timeout_keep_alive_task.cancel() self.timeout_keep_alive_task = None + def _get_upgrade(self) -> Optional[bytes]: + connection = [] + upgrade = None + for name, value in self.headers: + if name == b"connection": + connection = [token.lower().strip() for token in value.split(b",")] + if name == b"upgrade": + upgrade = value.lower() + if b"upgrade" in connection: + return upgrade + return None + + def _should_upgrade_to_ws(self) -> bool: + if self.ws_protocol_class is None: + if self.config.ws == "auto": + msg = "Unsupported upgrade request." + self.logger.warning(msg) + msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501 + self.logger.warning(msg) + return False + return True + def data_received(self, data: bytes) -> None: self._unset_keepalive_if_required() @@ -204,12 +226,10 @@ def handle_events(self) -> None: "headers": self.headers, } - for name, value in self.headers: - if name == b"connection": - tokens = [token.lower().strip() for token in value.split(b",")] - if b"upgrade" in tokens: - self.handle_upgrade(event) - return + upgrade = self._get_upgrade() + if upgrade == b"websocket" and self._should_upgrade_to_ws(): + self.handle_websocket_upgrade(event) + return # Handle 503 responses when 'limit_concurrency' is exceeded. if self.limit_concurrency is not None and ( @@ -254,23 +274,7 @@ def handle_events(self) -> None: self.cycle.more_body = False self.cycle.message_event.set() - def handle_upgrade(self, event: H11Event) -> None: - upgrade_value = None - for name, value in self.headers: - if name == b"upgrade": - upgrade_value = value.lower() - - if upgrade_value != b"websocket" or self.ws_protocol_class is None: - msg = "Unsupported upgrade request." - self.logger.warning(msg) - from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol - - if AutoWebSocketsProtocol is None: # pragma: no cover - msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501 - self.logger.warning(msg) - self.send_400_response(msg) - return - + def handle_websocket_upgrade(self, event: H11Event) -> None: if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix) @@ -280,7 +284,7 @@ def handle_upgrade(self, event: H11Event) -> None: for name, value in self.headers: output += [name, b": ", value, b"\r\n"] output.append(b"\r\n") - protocol = self.ws_protocol_class( # type: ignore[call-arg] + protocol = self.ws_protocol_class( # type: ignore[call-arg, misc] config=self.config, server_state=self.server_state ) protocol.connection_made(self.transport) diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index f018c59af3..62db4a813c 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -149,6 +149,32 @@ def _unset_keepalive_if_required(self) -> None: self.timeout_keep_alive_task.cancel() self.timeout_keep_alive_task = None + def _get_upgrade(self) -> Optional[bytes]: + connection = [] + upgrade = None + for name, value in self.headers: + if name == b"connection": + connection = [token.lower().strip() for token in value.split(b",")] + if name == b"upgrade": + upgrade = value.lower() + if b"upgrade" in connection: + return upgrade + return None + + def _should_upgrade_to_ws(self, upgrade: Optional[bytes]) -> bool: + if upgrade == b"websocket" and self.ws_protocol_class is not None: + return True + if self.config.ws == "auto": + msg = "Unsupported upgrade request." + self.logger.warning(msg) + msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501 + self.logger.warning(msg) + return False + + def _should_upgrade(self) -> bool: + upgrade = self._get_upgrade() + return self._should_upgrade_to_ws(upgrade) + def data_received(self, data: bytes) -> None: self._unset_keepalive_if_required() @@ -160,25 +186,11 @@ def data_received(self, data: bytes) -> None: self.send_400_response(msg) return except httptools.HttpParserUpgrade: - self.handle_upgrade() - - def handle_upgrade(self) -> None: - upgrade_value = None - for name, value in self.headers: - if name == b"upgrade": - upgrade_value = value.lower() - - if upgrade_value != b"websocket" or self.ws_protocol_class is None: - msg = "Unsupported upgrade request." - self.logger.warning(msg) - from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol - - if AutoWebSocketsProtocol is None: # pragma: no cover - msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501 - self.logger.warning(msg) - self.send_400_response(msg) - return + upgrade = self._get_upgrade() + if self._should_upgrade_to_ws(upgrade): + self.handle_websocket_upgrade() + def handle_websocket_upgrade(self) -> None: if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix) @@ -189,7 +201,7 @@ def handle_upgrade(self) -> None: for name, value in self.scope["headers"]: output += [name, b": ", value, b"\r\n"] output.append(b"\r\n") - protocol = self.ws_protocol_class( # type: ignore[call-arg] + protocol = self.ws_protocol_class( # type: ignore[call-arg, misc] config=self.config, server_state=self.server_state ) protocol.connection_made(self.transport) @@ -244,7 +256,7 @@ def on_headers_complete(self) -> None: self.scope["method"] = method.decode("ascii") if http_version != "1.1": self.scope["http_version"] = http_version - if self.parser.should_upgrade(): + if self.parser.should_upgrade() and self._should_upgrade(): return parsed_url = httptools.parse_url(self.url) raw_path = parsed_url.path @@ -291,7 +303,9 @@ def on_headers_complete(self) -> None: self.pipeline.appendleft((self.cycle, app)) def on_body(self, body: bytes) -> None: - if self.parser.should_upgrade() or self.cycle.response_complete: + if ( + self.parser.should_upgrade() and self._should_upgrade() + ) or self.cycle.response_complete: return self.cycle.body += body if len(self.cycle.body) > HIGH_WATER_LIMIT: @@ -299,7 +313,9 @@ def on_body(self, body: bytes) -> None: self.cycle.message_event.set() def on_message_complete(self) -> None: - if self.parser.should_upgrade() or self.cycle.response_complete: + if ( + self.parser.should_upgrade() and self._should_upgrade() + ) or self.cycle.response_complete: return self.cycle.more_body = False self.cycle.message_event.set()