diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index c1cd9dc4c..08a51be3a 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -57,6 +57,12 @@ b"" ]) +INVALID_REQUEST = b"\r\n".join([ + b"GET /?x=y z HTTP/1.1", # bad space character + b"Host: example.org", + b"", + b"" +]) class MockTransport: def __init__(self, sockname=None, peername=None, sslcontext=False): @@ -651,3 +657,14 @@ def app(scope): protocol.data_received(UPGRADE_REQUEST) assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer assert b"Missing Sec-WebSocket-Version header" in protocol.transport.buffer + + +@pytest.mark.parametrize("protocol_cls", [HttpToolsProtocol, H11Protocol]) +def test_invalid_http_request(protocol_cls): + def app(scope): + return Response("Hello, world", media_type="text/plain") + + protocol = get_connected_protocol(app, protocol_cls) + protocol.data_received(INVALID_REQUEST) + assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer + assert b"Invalid HTTP request received." in protocol.transport.buffer diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index 91269d735..a80458519 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -179,7 +179,7 @@ def handle_events(self): except h11.RemoteProtocolError as exc: msg = "Invalid HTTP request received." self.logger.warning(msg) - self.transport.close() + self.send_400_response(msg) return event_type = type(event) @@ -268,21 +268,7 @@ def handle_upgrade(self, event): if upgrade_value != b'websocket' or self.ws_protocol_class is None: msg = "Unsupported upgrade request." self.logger.warning(msg) - reason = STATUS_PHRASES[400] - headers = [ - (b"content-type", b"text/plain; charset=utf-8"), - (b"connection", b"close"), - ] - event = h11.Response(status_code=400, headers=headers, reason=reason) - output = self.conn.send(event) - self.transport.write(output) - event = h11.Data(data=b'Unsupported upgrade request.') - output = self.conn.send(event) - self.transport.write(output) - event = h11.EndOfMessage() - output = self.conn.send(event) - self.transport.write(output) - self.transport.close() + self.send_400_response(msg) return self.connections.discard(self) @@ -301,6 +287,23 @@ def handle_upgrade(self, event): protocol.data_received(b''.join(output)) self.transport.set_protocol(protocol) + def send_400_response(self, msg): + reason = STATUS_PHRASES[400] + headers = [ + (b"content-type", b"text/plain; charset=utf-8"), + (b"connection", b"close"), + ] + event = h11.Response(status_code=400, headers=headers, reason=reason) + output = self.conn.send(event) + self.transport.write(output) + event = h11.Data(data=bytes(msg, 'utf-8')) + output = self.conn.send(event) + self.transport.write(output) + event = h11.EndOfMessage() + output = self.conn.send(event) + self.transport.write(output) + self.transport.close() + def on_response_complete(self): self.state["total_requests"] += 1 diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index bfffe9ede..39423d1bb 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -169,7 +169,8 @@ def data_received(self, data): except httptools.parser.errors.HttpParserError as exc: msg = "Invalid HTTP request received." self.logger.warning(msg) - self.transport.close() + self.send_400_response(msg) + return except httptools.HttpParserUpgrade as exc: self.handle_upgrade() @@ -182,16 +183,7 @@ def handle_upgrade(self): if upgrade_value != b'websocket' or self.ws_protocol_class is None: msg = "Unsupported upgrade request." self.logger.warning(msg) - content = [STATUS_LINE[400], DEFAULT_HEADERS] - content.extend([ - b"content-type: text/plain; charset=utf-8\r\n", - b"content-length: " + str(len(msg)).encode('ascii') + b"\r\n", - b"connection: close\r\n", - b"\r\n", - msg.encode('ascii') - ]) - self.transport.write(b"".join(content)) - self.transport.close() + self.send_400_response(msg) return self.connections.discard(self) @@ -211,6 +203,18 @@ def handle_upgrade(self): protocol.data_received(b''.join(output)) self.transport.set_protocol(protocol) + def send_400_response(self, msg): + content = [STATUS_LINE[400], DEFAULT_HEADERS] + content.extend([ + b"content-type: text/plain; charset=utf-8\r\n", + b"content-length: " + str(len(msg)).encode('ascii') + b"\r\n", + b"connection: close\r\n", + b"\r\n", + msg.encode('ascii') + ]) + self.transport.write(b"".join(content)) + self.transport.close() + # Parser callbacks def on_url(self, url): method = self.parser.get_method()