Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade request handling: ignore http/2 and optionally ignore websocket #1661

Merged
merged 13 commits into from Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/settings.md
Expand Up @@ -67,7 +67,7 @@ Using Uvicorn with watchfiles will enable the following options (which are other

* `--loop <str>` - Set the event loop implementation. The uvloop implementation provides greater performance, but is not compatible with Windows or PyPy. **Options:** *'auto', 'asyncio', 'uvloop'.* **Default:** *'auto'*.
* `--http <str>` - Set the HTTP protocol implementation. The httptools implementation provides greater performance, but it not compatible with PyPy. **Options:** *'auto', 'h11', 'httptools'.* **Default:** *'auto'*.
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. Use `'none'` to deny all websocket requests. **Options:** *'auto', 'none', 'websockets', 'wsproto'.* **Default:** *'auto'*.
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'wsproto'.* **Default:** *'auto'*.
* `--ws-max-size <int>` - Set the WebSockets max message size, in bytes. Please note that this can be used only with the default `websockets` protocol.
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Please note that this can be used only with the default `websockets` protocol.
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. Please note that this can be used only with the default `websockets` protocol.
Expand Down
68 changes: 61 additions & 7 deletions tests/protocols/test_http.py
Expand Up @@ -2,12 +2,13 @@
import socket
import threading
import time
from typing import List

import pytest

from tests.response import Response
from uvicorn import Server
from uvicorn.config import Config
from uvicorn.config import WS_PROTOCOLS, Config
from uvicorn.main import ServerState
from uvicorn.protocols.http.h11_impl import H11Protocol

Expand All @@ -18,6 +19,7 @@


HTTP_PROTOCOLS = [p for p in [H11Protocol, HttpToolsProtocol] if p is not None]
WEBSOCKET_PROTOCOLS = WS_PROTOCOLS.keys()

SIMPLE_GET_REQUEST = b"\r\n".join([b"GET / HTTP/1.1", b"Host: example.org", b"", b""])

Expand Down Expand Up @@ -76,6 +78,18 @@
]
)

UPGRADE_HTTP2_REQUEST = b"\r\n".join(
[
b"GET / HTTP/1.1",
b"Host: example.org",
b"Connection: upgrade",
b"Upgrade: h2c",
b"Sec-WebSocket-Version: 11",
b"",
b"",
]
)

INVALID_REQUEST_TEMPLATE = b"\r\n".join(
[
b"%s",
Expand Down Expand Up @@ -162,6 +176,15 @@ def run_later(self, with_delay):
self._later = later


class MockLogger(logging.Logger):
def __init__(self, name: str):
super().__init__(name)
self.warnings: List[str] = []

def warning(self, msg: str):
self.warnings += [msg]


class MockTask:
def add_done_callback(self, callback):
pass
Expand Down Expand Up @@ -697,23 +720,54 @@ async def test_100_continue_not_sent_when_body_not_consumed(protocol_cls):

@pytest.mark.anyio
@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS)
async def test_unsupported_upgrade_request(protocol_cls):
async def test_supported_upgrade_request(protocol_cls):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, protocol_cls, ws="wsproto")
protocol.data_received(UPGRADE_REQUEST)
assert b"HTTP/1.1 426 " in protocol.transport.buffer


@pytest.mark.anyio
@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS)
async def test_unsupported_ws_upgrade_request(protocol_cls):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, protocol_cls, ws="none")
protocol.data_received(UPGRADE_REQUEST)
assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer
assert b"Unsupported upgrade request." in protocol.transport.buffer
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer


@pytest.mark.anyio
@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS)
async def test_supported_upgrade_request(protocol_cls):
async def test_unsupported_ws_upgrade_request_warn_on_auto(protocol_cls):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, protocol_cls, ws="wsproto")
protocol = get_connected_protocol(app, protocol_cls, ws="auto")
protocol.ws_protocol_class = None
protocol.logger = MockLogger(protocol.logger.name)
Kludex marked this conversation as resolved.
Show resolved Hide resolved
protocol.data_received(UPGRADE_REQUEST)
assert b"HTTP/1.1 426 " in protocol.transport.buffer
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer
assert "Unsupported upgrade request." in protocol.logger.warnings
msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501
assert msg in protocol.logger.warnings


@pytest.mark.anyio
@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS)
@pytest.mark.parametrize("ws", WEBSOCKET_PROTOCOLS)
async def test_http2_upgrade_request(protocol_cls, ws):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, protocol_cls, ws=ws)
protocol.data_received(UPGRADE_HTTP2_REQUEST)
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer


async def asgi3app(scope, receive, send):
Expand Down
3 changes: 0 additions & 3 deletions uvicorn/config.py
Expand Up @@ -87,10 +87,8 @@
}
INTERFACES: List[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"]


SSL_PROTOCOL_VERSION: int = ssl.PROTOCOL_TLS_SERVER


LOGGING_CONFIG: Dict[str, Any] = {
"version": 1,
"disable_existing_loggers": False,
Expand Down Expand Up @@ -159,7 +157,6 @@ def is_dir(path: Path) -> bool:
def resolve_reload_patterns(
patterns_list: List[str], directories_list: List[str]
) -> Tuple[List[str], List[Path]]:

directories: List[Path] = list(set(map(Path, directories_list.copy())))
patterns: List[str] = patterns_list.copy()

Expand Down
52 changes: 28 additions & 24 deletions uvicorn/protocols/http/h11_impl.py
Expand Up @@ -155,6 +155,28 @@ def _unset_keepalive_if_required(self) -> None:
self.timeout_keep_alive_task.cancel()
self.timeout_keep_alive_task = None

def _get_upgrade(self) -> Optional[bytes]:
connection = []
upgrade = None
for name, value in self.headers:
if name == b"connection":
connection = [token.lower().strip() for token in value.split(b",")]
if name == b"upgrade":
upgrade = value.lower()
if b"upgrade" in connection:
return upgrade
return None

def _should_upgrade_to_ws(self) -> bool:
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the one in the other protocol is implemented in one way, and this one in another? 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both protocols work a bit different, and in the case of httptools the "check and return" is repeated in several places, but the actual upgrade handling is only done once in the except block. While the h11 implementation only does it in one single place and both the returning and handling is in the same place.

if self.ws_protocol_class is None:
if self.config.ws == "auto":
msg = "Unsupported upgrade request."
self.logger.warning(msg)
msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
return False
return True

def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()

Expand Down Expand Up @@ -204,12 +226,10 @@ def handle_events(self) -> None:
"headers": self.headers,
}

for name, value in self.headers:
if name == b"connection":
tokens = [token.lower().strip() for token in value.split(b",")]
if b"upgrade" in tokens:
self.handle_upgrade(event)
return
upgrade = self._get_upgrade()
if upgrade == b"websocket" and self._should_upgrade_to_ws():
self.handle_websocket_upgrade(event)
return

# Handle 503 responses when 'limit_concurrency' is exceeded.
if self.limit_concurrency is not None and (
Expand Down Expand Up @@ -254,23 +274,7 @@ def handle_events(self) -> None:
self.cycle.more_body = False
self.cycle.message_event.set()

def handle_upgrade(self, event: H11Event) -> None:
upgrade_value = None
for name, value in self.headers:
if name == b"upgrade":
upgrade_value = value.lower()

if upgrade_value != b"websocket" or self.ws_protocol_class is None:
msg = "Unsupported upgrade request."
self.logger.warning(msg)
from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol

if AutoWebSocketsProtocol is None: # pragma: no cover
msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
self.send_400_response(msg)
return

def handle_websocket_upgrade(self, event: H11Event) -> None:
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
Expand All @@ -280,7 +284,7 @@ def handle_upgrade(self, event: H11Event) -> None:
for name, value in self.headers:
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class( # type: ignore[call-arg]
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
config=self.config, server_state=self.server_state
)
protocol.connection_made(self.transport)
Expand Down
60 changes: 38 additions & 22 deletions uvicorn/protocols/http/httptools_impl.py
Expand Up @@ -149,6 +149,32 @@ def _unset_keepalive_if_required(self) -> None:
self.timeout_keep_alive_task.cancel()
self.timeout_keep_alive_task = None

def _get_upgrade(self) -> Optional[bytes]:
connection = []
upgrade = None
for name, value in self.headers:
if name == b"connection":
connection = [token.lower().strip() for token in value.split(b",")]
if name == b"upgrade":
upgrade = value.lower()
if b"upgrade" in connection:
return upgrade
return None

def _should_upgrade_to_ws(self, upgrade: Optional[bytes]) -> bool:
if upgrade == b"websocket" and self.ws_protocol_class is not None:
return True
if self.config.ws == "auto":
msg = "Unsupported upgrade request."
self.logger.warning(msg)
msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
return False

def _should_upgrade(self) -> bool:
upgrade = self._get_upgrade()
return self._should_upgrade_to_ws(upgrade)

def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()

Expand All @@ -160,25 +186,11 @@ def data_received(self, data: bytes) -> None:
self.send_400_response(msg)
return
except httptools.HttpParserUpgrade:
self.handle_upgrade()

def handle_upgrade(self) -> None:
upgrade_value = None
for name, value in self.headers:
if name == b"upgrade":
upgrade_value = value.lower()

if upgrade_value != b"websocket" or self.ws_protocol_class is None:
msg = "Unsupported upgrade request."
self.logger.warning(msg)
from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol

if AutoWebSocketsProtocol is None: # pragma: no cover
msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
self.send_400_response(msg)
return
upgrade = self._get_upgrade()
if self._should_upgrade_to_ws(upgrade):
self.handle_websocket_upgrade()

def handle_websocket_upgrade(self) -> None:
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
Expand All @@ -189,7 +201,7 @@ def handle_upgrade(self) -> None:
for name, value in self.scope["headers"]:
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class( # type: ignore[call-arg]
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
Kludex marked this conversation as resolved.
Show resolved Hide resolved
config=self.config, server_state=self.server_state
)
protocol.connection_made(self.transport)
Expand Down Expand Up @@ -244,7 +256,7 @@ def on_headers_complete(self) -> None:
self.scope["method"] = method.decode("ascii")
if http_version != "1.1":
self.scope["http_version"] = http_version
if self.parser.should_upgrade():
if self.parser.should_upgrade() and self._should_upgrade():
Kludex marked this conversation as resolved.
Show resolved Hide resolved
return
parsed_url = httptools.parse_url(self.url)
raw_path = parsed_url.path
Expand Down Expand Up @@ -291,15 +303,19 @@ def on_headers_complete(self) -> None:
self.pipeline.appendleft((self.cycle, app))

def on_body(self, body: bytes) -> None:
if self.parser.should_upgrade() or self.cycle.response_complete:
if (
self.parser.should_upgrade() and self._should_upgrade()
) or self.cycle.response_complete:
return
self.cycle.body += body
if len(self.cycle.body) > HIGH_WATER_LIMIT:
self.flow.pause_reading()
self.cycle.message_event.set()

def on_message_complete(self) -> None:
if self.parser.should_upgrade() or self.cycle.response_complete:
if (
self.parser.should_upgrade() and self._should_upgrade()
) or self.cycle.response_complete:
return
self.cycle.more_body = False
self.cycle.message_event.set()
Expand Down