Skip to content

Commit

Permalink
Upgrade request handling: ignore http/2 and optionally ignore websock…
Browse files Browse the repository at this point in the history
…et (#1661)
  • Loading branch information
Argannor committed Oct 19, 2022
1 parent 4502717 commit 255dcde
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 57 deletions.
2 changes: 1 addition & 1 deletion docs/settings.md
Expand Up @@ -67,7 +67,7 @@ Using Uvicorn with watchfiles will enable the following options (which are other

* `--loop <str>` - Set the event loop implementation. The uvloop implementation provides greater performance, but is not compatible with Windows or PyPy. **Options:** *'auto', 'asyncio', 'uvloop'.* **Default:** *'auto'*.
* `--http <str>` - Set the HTTP protocol implementation. The httptools implementation provides greater performance, but it not compatible with PyPy. **Options:** *'auto', 'h11', 'httptools'.* **Default:** *'auto'*.
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. Use `'none'` to deny all websocket requests. **Options:** *'auto', 'none', 'websockets', 'wsproto'.* **Default:** *'auto'*.
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'wsproto'.* **Default:** *'auto'*.
* `--ws-max-size <int>` - Set the WebSockets max message size, in bytes. Please note that this can be used only with the default `websockets` protocol.
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Please note that this can be used only with the default `websockets` protocol.
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. Please note that this can be used only with the default `websockets` protocol.
Expand Down
65 changes: 58 additions & 7 deletions tests/protocols/test_http.py
Expand Up @@ -7,7 +7,7 @@

from tests.response import Response
from uvicorn import Server
from uvicorn.config import Config
from uvicorn.config import WS_PROTOCOLS, Config
from uvicorn.main import ServerState
from uvicorn.protocols.http.h11_impl import H11Protocol

Expand All @@ -18,6 +18,7 @@


HTTP_PROTOCOLS = [p for p in [H11Protocol, HttpToolsProtocol] if p is not None]
WEBSOCKET_PROTOCOLS = WS_PROTOCOLS.keys()

SIMPLE_GET_REQUEST = b"\r\n".join([b"GET / HTTP/1.1", b"Host: example.org", b"", b""])

Expand Down Expand Up @@ -76,6 +77,18 @@
]
)

UPGRADE_HTTP2_REQUEST = b"\r\n".join(
[
b"GET / HTTP/1.1",
b"Host: example.org",
b"Connection: upgrade",
b"Upgrade: h2c",
b"Sec-WebSocket-Version: 11",
b"",
b"",
]
)

INVALID_REQUEST_TEMPLATE = b"\r\n".join(
[
b"%s",
Expand Down Expand Up @@ -697,23 +710,61 @@ async def test_100_continue_not_sent_when_body_not_consumed(protocol_cls):

@pytest.mark.anyio
@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS)
async def test_unsupported_upgrade_request(protocol_cls):
async def test_supported_upgrade_request(protocol_cls):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, protocol_cls, ws="wsproto")
protocol.data_received(UPGRADE_REQUEST)
assert b"HTTP/1.1 426 " in protocol.transport.buffer


@pytest.mark.anyio
@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS)
async def test_unsupported_ws_upgrade_request(protocol_cls):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, protocol_cls, ws="none")
protocol.data_received(UPGRADE_REQUEST)
assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer
assert b"Unsupported upgrade request." in protocol.transport.buffer
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer


@pytest.mark.anyio
@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS)
async def test_supported_upgrade_request(protocol_cls):
async def test_unsupported_ws_upgrade_request_warn_on_auto(
caplog: pytest.LogCaptureFixture, protocol_cls
):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, protocol_cls, ws="wsproto")
protocol = get_connected_protocol(app, protocol_cls, ws="auto")
protocol.ws_protocol_class = None
protocol.data_received(UPGRADE_REQUEST)
assert b"HTTP/1.1 426 " in protocol.transport.buffer
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer
warnings = [
record.msg
for record in filter(
lambda record: record.levelname == "WARNING", caplog.records
)
]
assert "Unsupported upgrade request." in warnings
msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501
assert msg in warnings


@pytest.mark.anyio
@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS)
@pytest.mark.parametrize("ws", WEBSOCKET_PROTOCOLS)
async def test_http2_upgrade_request(protocol_cls, ws):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, protocol_cls, ws=ws)
protocol.data_received(UPGRADE_HTTP2_REQUEST)
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer


async def asgi3app(scope, receive, send):
Expand Down
3 changes: 0 additions & 3 deletions uvicorn/config.py
Expand Up @@ -87,10 +87,8 @@
}
INTERFACES: List[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"]


SSL_PROTOCOL_VERSION: int = ssl.PROTOCOL_TLS_SERVER


LOGGING_CONFIG: Dict[str, Any] = {
"version": 1,
"disable_existing_loggers": False,
Expand Down Expand Up @@ -159,7 +157,6 @@ def is_dir(path: Path) -> bool:
def resolve_reload_patterns(
patterns_list: List[str], directories_list: List[str]
) -> Tuple[List[str], List[Path]]:

directories: List[Path] = list(set(map(Path, directories_list.copy())))
patterns: List[str] = patterns_list.copy()

Expand Down
52 changes: 28 additions & 24 deletions uvicorn/protocols/http/h11_impl.py
Expand Up @@ -155,6 +155,28 @@ def _unset_keepalive_if_required(self) -> None:
self.timeout_keep_alive_task.cancel()
self.timeout_keep_alive_task = None

def _get_upgrade(self) -> Optional[bytes]:
connection = []
upgrade = None
for name, value in self.headers:
if name == b"connection":
connection = [token.lower().strip() for token in value.split(b",")]
if name == b"upgrade":
upgrade = value.lower()
if b"upgrade" in connection:
return upgrade
return None

def _should_upgrade_to_ws(self) -> bool:
if self.ws_protocol_class is None:
if self.config.ws == "auto":
msg = "Unsupported upgrade request."
self.logger.warning(msg)
msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
return False
return True

def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()

Expand Down Expand Up @@ -204,12 +226,10 @@ def handle_events(self) -> None:
"headers": self.headers,
}

for name, value in self.headers:
if name == b"connection":
tokens = [token.lower().strip() for token in value.split(b",")]
if b"upgrade" in tokens:
self.handle_upgrade(event)
return
upgrade = self._get_upgrade()
if upgrade == b"websocket" and self._should_upgrade_to_ws():
self.handle_websocket_upgrade(event)
return

# Handle 503 responses when 'limit_concurrency' is exceeded.
if self.limit_concurrency is not None and (
Expand Down Expand Up @@ -254,23 +274,7 @@ def handle_events(self) -> None:
self.cycle.more_body = False
self.cycle.message_event.set()

def handle_upgrade(self, event: H11Event) -> None:
upgrade_value = None
for name, value in self.headers:
if name == b"upgrade":
upgrade_value = value.lower()

if upgrade_value != b"websocket" or self.ws_protocol_class is None:
msg = "Unsupported upgrade request."
self.logger.warning(msg)
from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol

if AutoWebSocketsProtocol is None: # pragma: no cover
msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
self.send_400_response(msg)
return

def handle_websocket_upgrade(self, event: H11Event) -> None:
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
Expand All @@ -280,7 +284,7 @@ def handle_upgrade(self, event: H11Event) -> None:
for name, value in self.headers:
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class( # type: ignore[call-arg]
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
config=self.config, server_state=self.server_state
)
protocol.connection_made(self.transport)
Expand Down
60 changes: 38 additions & 22 deletions uvicorn/protocols/http/httptools_impl.py
Expand Up @@ -149,6 +149,32 @@ def _unset_keepalive_if_required(self) -> None:
self.timeout_keep_alive_task.cancel()
self.timeout_keep_alive_task = None

def _get_upgrade(self) -> Optional[bytes]:
connection = []
upgrade = None
for name, value in self.headers:
if name == b"connection":
connection = [token.lower().strip() for token in value.split(b",")]
if name == b"upgrade":
upgrade = value.lower()
if b"upgrade" in connection:
return upgrade
return None

def _should_upgrade_to_ws(self, upgrade: Optional[bytes]) -> bool:
if upgrade == b"websocket" and self.ws_protocol_class is not None:
return True
if self.config.ws == "auto":
msg = "Unsupported upgrade request."
self.logger.warning(msg)
msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
return False

def _should_upgrade(self) -> bool:
upgrade = self._get_upgrade()
return self._should_upgrade_to_ws(upgrade)

def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()

Expand All @@ -160,25 +186,11 @@ def data_received(self, data: bytes) -> None:
self.send_400_response(msg)
return
except httptools.HttpParserUpgrade:
self.handle_upgrade()

def handle_upgrade(self) -> None:
upgrade_value = None
for name, value in self.headers:
if name == b"upgrade":
upgrade_value = value.lower()

if upgrade_value != b"websocket" or self.ws_protocol_class is None:
msg = "Unsupported upgrade request."
self.logger.warning(msg)
from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol

if AutoWebSocketsProtocol is None: # pragma: no cover
msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
self.send_400_response(msg)
return
upgrade = self._get_upgrade()
if self._should_upgrade_to_ws(upgrade):
self.handle_websocket_upgrade()

def handle_websocket_upgrade(self) -> None:
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
Expand All @@ -189,7 +201,7 @@ def handle_upgrade(self) -> None:
for name, value in self.scope["headers"]:
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class( # type: ignore[call-arg]
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
config=self.config, server_state=self.server_state
)
protocol.connection_made(self.transport)
Expand Down Expand Up @@ -244,7 +256,7 @@ def on_headers_complete(self) -> None:
self.scope["method"] = method.decode("ascii")
if http_version != "1.1":
self.scope["http_version"] = http_version
if self.parser.should_upgrade():
if self.parser.should_upgrade() and self._should_upgrade():
return
parsed_url = httptools.parse_url(self.url)
raw_path = parsed_url.path
Expand Down Expand Up @@ -291,15 +303,19 @@ def on_headers_complete(self) -> None:
self.pipeline.appendleft((self.cycle, app))

def on_body(self, body: bytes) -> None:
if self.parser.should_upgrade() or self.cycle.response_complete:
if (
self.parser.should_upgrade() and self._should_upgrade()
) or self.cycle.response_complete:
return
self.cycle.body += body
if len(self.cycle.body) > HIGH_WATER_LIMIT:
self.flow.pause_reading()
self.cycle.message_event.set()

def on_message_complete(self) -> None:
if self.parser.should_upgrade() or self.cycle.response_complete:
if (
self.parser.should_upgrade() and self._should_upgrade()
) or self.cycle.response_complete:
return
self.cycle.more_body = False
self.cycle.message_event.set()
Expand Down

0 comments on commit 255dcde

Please sign in to comment.