Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send HTTP 400 response for invalid request #205

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions tests/protocols/test_http.py
Expand Up @@ -57,6 +57,12 @@
b""
])

INVALID_REQUEST = b"\r\n".join([
b"GET /?x=y z HTTP/1.1", # bad space character
b"Host: example.org",
b"",
b""
])

class MockTransport:
def __init__(self, sockname=None, peername=None, sslcontext=False):
Expand Down Expand Up @@ -651,3 +657,14 @@ def app(scope):
protocol.data_received(UPGRADE_REQUEST)
assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer
assert b"Missing Sec-WebSocket-Version header" in protocol.transport.buffer


@pytest.mark.parametrize("protocol_cls", [HttpToolsProtocol, H11Protocol])
def test_invalid_http_request(protocol_cls):
def app(scope):
return Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(INVALID_REQUEST)
assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer
assert b"Invalid HTTP request received." in protocol.transport.buffer
35 changes: 19 additions & 16 deletions uvicorn/protocols/http/h11_impl.py
Expand Up @@ -179,7 +179,7 @@ def handle_events(self):
except h11.RemoteProtocolError as exc:
msg = "Invalid HTTP request received."
self.logger.warning(msg)
self.transport.close()
self.send_400_response(msg)
return
event_type = type(event)

Expand Down Expand Up @@ -268,21 +268,7 @@ def handle_upgrade(self, event):
if upgrade_value != b'websocket' or self.ws_protocol_class is None:
msg = "Unsupported upgrade request."
self.logger.warning(msg)
reason = STATUS_PHRASES[400]
headers = [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
]
event = h11.Response(status_code=400, headers=headers, reason=reason)
output = self.conn.send(event)
self.transport.write(output)
event = h11.Data(data=b'Unsupported upgrade request.')
output = self.conn.send(event)
self.transport.write(output)
event = h11.EndOfMessage()
output = self.conn.send(event)
self.transport.write(output)
self.transport.close()
self.send_400_response(msg)
return

self.connections.discard(self)
Expand All @@ -301,6 +287,23 @@ def handle_upgrade(self, event):
protocol.data_received(b''.join(output))
self.transport.set_protocol(protocol)

def send_400_response(self, msg):
reason = STATUS_PHRASES[400]
headers = [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
]
event = h11.Response(status_code=400, headers=headers, reason=reason)
output = self.conn.send(event)
self.transport.write(output)
event = h11.Data(data=bytes(msg, 'utf-8'))
output = self.conn.send(event)
self.transport.write(output)
event = h11.EndOfMessage()
output = self.conn.send(event)
self.transport.write(output)
self.transport.close()

def on_response_complete(self):
self.state["total_requests"] += 1

Expand Down
26 changes: 15 additions & 11 deletions uvicorn/protocols/http/httptools_impl.py
Expand Up @@ -169,7 +169,8 @@ def data_received(self, data):
except httptools.parser.errors.HttpParserError as exc:
msg = "Invalid HTTP request received."
self.logger.warning(msg)
self.transport.close()
self.send_400_response(msg)
return
except httptools.HttpParserUpgrade as exc:
self.handle_upgrade()

Expand All @@ -182,16 +183,7 @@ def handle_upgrade(self):
if upgrade_value != b'websocket' or self.ws_protocol_class is None:
msg = "Unsupported upgrade request."
self.logger.warning(msg)
content = [STATUS_LINE[400], DEFAULT_HEADERS]
content.extend([
b"content-type: text/plain; charset=utf-8\r\n",
b"content-length: " + str(len(msg)).encode('ascii') + b"\r\n",
b"connection: close\r\n",
b"\r\n",
msg.encode('ascii')
])
self.transport.write(b"".join(content))
self.transport.close()
self.send_400_response(msg)
return

self.connections.discard(self)
Expand All @@ -211,6 +203,18 @@ def handle_upgrade(self):
protocol.data_received(b''.join(output))
self.transport.set_protocol(protocol)

def send_400_response(self, msg):
content = [STATUS_LINE[400], DEFAULT_HEADERS]
content.extend([
b"content-type: text/plain; charset=utf-8\r\n",
b"content-length: " + str(len(msg)).encode('ascii') + b"\r\n",
b"connection: close\r\n",
b"\r\n",
msg.encode('ascii')
])
self.transport.write(b"".join(content))
self.transport.close()

# Parser callbacks
def on_url(self, url):
method = self.parser.get_method()
Expand Down