diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index d5f343268..a1eacbb0a 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -1,10 +1,14 @@ import asyncio import contextlib import logging +import socket +import threading +import time import pytest from tests.response import Response +from uvicorn import Server from uvicorn.config import Config from uvicorn.main import ServerState from uvicorn.protocols.http.h11_impl import H11Protocol @@ -744,3 +748,46 @@ def test_invalid_http_request(request_line, protocol_cls, caplog, event_loop): protocol.data_received(request) assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer assert b"Invalid HTTP request received." in protocol.transport.buffer + + +def test_fragmentation(): + def receive_all(sock): + chunks = [] + while True: + chunk = sock.recv(1024) + if not chunk: + break + chunks.append(chunk) + return b"".join(chunks) + + app = Response("Hello, world", media_type="text/plain") + + def send_fragmented_req(path): + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(("127.0.0.1", 8000)) + d = ( + f"GET {path} HTTP/1.1\r\n" "Host: localhost\r\n" "Connection: close\r\n\r\n" + ).encode() + split = len(path) // 2 + sock.sendall(d[:split]) + time.sleep(0.01) + sock.sendall(d[split:]) + resp = receive_all(sock) + sock.shutdown(socket.SHUT_RDWR) + sock.close() + return resp + + config = Config(app=app, http="httptools") + server = Server(config=config) + t = threading.Thread(target=server.run) + t.daemon = True + t.start() + time.sleep(1) # wait for uvicorn to start + + path = "/?param=" + "q" * 10 + response = send_fragmented_req(path) + bad_response = b"HTTP/1.1 400 Bad Request" + assert bad_response != response[: len(bad_response)] + server.should_exit = True + t.join() diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index 838c177bb..eb6cc00bb 100644 --- 
a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -74,7 +74,6 @@ def __init__(self, config, server_state, _loop=None): self.pipeline = deque() # Per-request state - self.url = None self.scope = None self.headers = None self.expect_100_continue = False @@ -182,15 +181,8 @@ def send_400_response(self, msg: str): self.transport.write(b"".join(content)) self.transport.close() - # Parser callbacks - def on_url(self, url): - method = self.parser.get_method() - parsed_url = httptools.parse_url(url) - raw_path = parsed_url.path - path = raw_path.decode("ascii") - if "%" in path: - path = urllib.parse.unquote(path) - self.url = url + def on_message_begin(self): + self.url = b"" self.expect_100_continue = False self.headers = [] self.scope = { @@ -200,14 +192,14 @@ def on_url(self, url): "server": self.server, "client": self.client, "scheme": self.scheme, - "method": method.decode("ascii"), "root_path": self.root_path, - "path": path, - "raw_path": raw_path, - "query_string": parsed_url.query if parsed_url.query else b"", "headers": self.headers, } + # Parser callbacks + def on_url(self, url): + self.url += url + def on_header(self, name: bytes, value: bytes): name = name.lower() if name == b"expect" and value.lower() == b"100-continue": @@ -216,10 +208,20 @@ def on_header(self, name: bytes, value: bytes): def on_headers_complete(self): http_version = self.parser.get_http_version() + method = self.parser.get_method() + self.scope["method"] = method.decode("ascii") if http_version != "1.1": self.scope["http_version"] = http_version if self.parser.should_upgrade(): return + parsed_url = httptools.parse_url(self.url) + raw_path = parsed_url.path + path = raw_path.decode("ascii") + if "%" in path: + path = urllib.parse.unquote(path) + self.scope["path"] = path + self.scope["raw_path"] = raw_path + self.scope["query_string"] = parsed_url.query if parsed_url.query else b"" # Handle 503 responses when 'limit_concurrency' is exceeded. 
if self.limit_concurrency is not None and (