Fix case where url is fragmented in httptools protocol (encode#1263)
* Fix fragmented url

* Fixed subtle bug introduced by setting self.url in the init; it should be reinitialized for every new connection

* Ran black formatting

* Adapted failing tests provided in bug report
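
For context, the underlying issue is that httptools can deliver the request target in pieces: when the request line arrives split across TCP reads, the parser invokes `on_url()` once per fragment, so the callback has to accumulate bytes rather than parse the first fragment in isolation. Below is a minimal standalone sketch (not part of this commit; it assumes only that `httptools` is installed, and the `Collector` class is made up for illustration) of that callback behaviour:

```python
# Sketch: feeding a request line in two pieces makes httptools call on_url()
# more than once, so the URL must be accumulated before parsing.
import httptools


class Collector:
    def __init__(self):
        self.url = b""
        self.fragments = []

    def on_message_begin(self):
        # Reset per-request state, mirroring the new on_message_begin in the fix.
        self.url = b""
        self.fragments = []

    def on_url(self, url: bytes):
        # May be called once per fragment of the request target.
        self.fragments.append(url)
        self.url += url


collector = Collector()
parser = httptools.HttpRequestParser(collector)

request = b"GET /?param=" + b"q" * 10 + b" HTTP/1.1\r\nHost: localhost\r\n\r\n"

# Split inside the request target, as a fragmented TCP stream might.
parser.feed_data(request[:10])
parser.feed_data(request[10:])

print(collector.fragments)        # e.g. [b'/?para', b'm=qqqqqqqqqq']
parsed = httptools.parse_url(collector.url)
print(parsed.path, parsed.query)  # b'/' b'param=qqqqqqqqqq'
```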
euri10 authored and Kludex committed Oct 29, 2022
1 parent 7881759 commit 7379c60
Showing 2 changed files with 63 additions and 14 deletions.
47 changes: 47 additions & 0 deletions tests/protocols/test_http.py
@@ -1,10 +1,14 @@
 import asyncio
 import contextlib
 import logging
+import socket
+import threading
+import time
 
 import pytest
 
 from tests.response import Response
+from uvicorn import Server
 from uvicorn.config import Config
 from uvicorn.main import ServerState
 from uvicorn.protocols.http.h11_impl import H11Protocol
@@ -744,3 +748,46 @@ def test_invalid_http_request(request_line, protocol_cls, caplog, event_loop):
     protocol.data_received(request)
     assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer
     assert b"Invalid HTTP request received." in protocol.transport.buffer
+
+
+def test_fragmentation():
+    def receive_all(sock):
+        chunks = []
+        while True:
+            chunk = sock.recv(1024)
+            if not chunk:
+                break
+            chunks.append(chunk)
+        return b"".join(chunks)
+
+    app = Response("Hello, world", media_type="text/plain")
+
+    def send_fragmented_req(path):
+
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.connect(("127.0.0.1", 8000))
+        d = (
+            f"GET {path} HTTP/1.1\r\n" "Host: localhost\r\n" "Connection: close\r\n\r\n"
+        ).encode()
+        split = len(path) // 2
+        sock.sendall(d[:split])
+        time.sleep(0.01)
+        sock.sendall(d[split:])
+        resp = receive_all(sock)
+        sock.shutdown(socket.SHUT_RDWR)
+        sock.close()
+        return resp
+
+    config = Config(app=app, http="httptools")
+    server = Server(config=config)
+    t = threading.Thread(target=server.run)
+    t.daemon = True
+    t.start()
+    time.sleep(1)  # wait for uvicorn to start
+
+    path = "/?param=" + "q" * 10
+    response = send_fragmented_req(path)
+    bad_response = b"HTTP/1.1 400 Bad Request"
+    assert bad_response != response[: len(bad_response)]
+    server.should_exit = True
+    t.join()
30 changes: 16 additions & 14 deletions uvicorn/protocols/http/httptools_impl.py
@@ -74,7 +74,6 @@ def __init__(self, config, server_state, _loop=None):
         self.pipeline = deque()
 
         # Per-request state
-        self.url = None
         self.scope = None
         self.headers = None
         self.expect_100_continue = False
@@ -182,15 +181,8 @@ def send_400_response(self, msg: str):
         self.transport.write(b"".join(content))
         self.transport.close()
 
-    # Parser callbacks
-    def on_url(self, url):
-        method = self.parser.get_method()
-        parsed_url = httptools.parse_url(url)
-        raw_path = parsed_url.path
-        path = raw_path.decode("ascii")
-        if "%" in path:
-            path = urllib.parse.unquote(path)
-        self.url = url
+    def on_message_begin(self):
+        self.url = b""
         self.expect_100_continue = False
         self.headers = []
         self.scope = {
@@ -200,14 +192,14 @@ def on_url(self, url):
             "server": self.server,
             "client": self.client,
             "scheme": self.scheme,
-            "method": method.decode("ascii"),
             "root_path": self.root_path,
-            "path": path,
-            "raw_path": raw_path,
-            "query_string": parsed_url.query if parsed_url.query else b"",
             "headers": self.headers,
         }
 
+    # Parser callbacks
+    def on_url(self, url):
+        self.url += url
+
     def on_header(self, name: bytes, value: bytes):
         name = name.lower()
         if name == b"expect" and value.lower() == b"100-continue":
@@ -216,10 +208,20 @@ def on_header(self, name: bytes, value: bytes):
 
     def on_headers_complete(self):
         http_version = self.parser.get_http_version()
+        method = self.parser.get_method()
+        self.scope["method"] = method.decode("ascii")
         if http_version != "1.1":
             self.scope["http_version"] = http_version
         if self.parser.should_upgrade():
             return
+        parsed_url = httptools.parse_url(self.url)
+        raw_path = parsed_url.path
+        path = raw_path.decode("ascii")
+        if "%" in path:
+            path = urllib.parse.unquote(path)
+        self.scope["path"] = path
+        self.scope["raw_path"] = raw_path
+        self.scope["query_string"] = parsed_url.query if parsed_url.query else b""
 
         # Handle 503 responses when 'limit_concurrency' is exceeded.
         if self.limit_concurrency is not None and (

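Deferring the URL parsing to on_headers_complete is safe because by that point the parser has necessarily seen the entire request line, so the accumulated target is complete. For reference, here is a short self-contained sketch (again not part of the commit; the helper name `split_target` is made up for illustration) of what the relocated parsing produces for the ASGI scope, including the percent-decoding of `path`:

```python
# Sketch: once the full target has been accumulated, httptools.parse_url()
# yields the raw path and query string; the path is percent-decoded for the
# ASGI "path" key while "raw_path" keeps the original bytes.
import urllib.parse

import httptools


def split_target(url: bytes) -> dict:
    parsed = httptools.parse_url(url)
    raw_path = parsed.path
    path = raw_path.decode("ascii")
    if "%" in path:
        path = urllib.parse.unquote(path)
    return {
        "path": path,
        "raw_path": raw_path,
        "query_string": parsed.query if parsed.query else b"",
    }


print(split_target(b"/hello%20world?x=1"))
# {'path': '/hello world', 'raw_path': b'/hello%20world', 'query_string': b'x=1'}
```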