Skip to content

Commit

Permalink
Send HTTP 400 response for invalid request
Browse files Browse the repository at this point in the history
Given an invalid request, respond with an HTTP 400 error instead of
closing the connection without a response.
  • Loading branch information
Mark Breedlove committed Sep 28, 2018
1 parent 4b2598e commit 09577e1
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 27 deletions.
17 changes: 17 additions & 0 deletions tests/protocols/test_http.py
Expand Up @@ -57,6 +57,12 @@
b""
])

INVALID_REQUEST = b"\r\n".join([
b"GET /?x=y z HTTP/1.1", # bad space character
b"Host: example.org",
b"",
b""
])

class MockTransport:
def __init__(self, sockname=None, peername=None, sslcontext=False):
Expand Down Expand Up @@ -651,3 +657,14 @@ def app(scope):
protocol.data_received(UPGRADE_REQUEST)
assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer
assert b"Missing Sec-WebSocket-Version header" in protocol.transport.buffer


@pytest.mark.parametrize("protocol_cls", [HttpToolsProtocol, H11Protocol])
def test_invalid_http_request(protocol_cls):
def app(scope):
return Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(INVALID_REQUEST)
assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer
assert b"Invalid HTTP request received." in protocol.transport.buffer
35 changes: 19 additions & 16 deletions uvicorn/protocols/http/h11_impl.py
Expand Up @@ -179,7 +179,7 @@ def handle_events(self):
except h11.RemoteProtocolError:
msg = "Invalid HTTP request received."
self.logger.warning(msg)
self.transport.close()
self.send_400_response(msg)
return
event_type = type(event)

Expand Down Expand Up @@ -268,21 +268,7 @@ def handle_upgrade(self, event):
if upgrade_value != b'websocket' or self.ws_protocol_class is None:
msg = "Unsupported upgrade request."
self.logger.warning(msg)
reason = STATUS_PHRASES[400]
headers = [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
]
event = h11.Response(status_code=400, headers=headers, reason=reason)
output = self.conn.send(event)
self.transport.write(output)
event = h11.Data(data=b'Unsupported upgrade request.')
output = self.conn.send(event)
self.transport.write(output)
event = h11.EndOfMessage()
output = self.conn.send(event)
self.transport.write(output)
self.transport.close()
self.send_400_response(msg)
return

self.connections.discard(self)
Expand All @@ -301,6 +287,23 @@ def handle_upgrade(self, event):
protocol.data_received(b''.join(output))
self.transport.set_protocol(protocol)

def send_400_response(self, msg):
reason = STATUS_PHRASES[400]
headers = [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
]
event = h11.Response(status_code=400, headers=headers, reason=reason)
output = self.conn.send(event)
self.transport.write(output)
event = h11.Data(data=bytes(msg, 'utf-8'))
output = self.conn.send(event)
self.transport.write(output)
event = h11.EndOfMessage()
output = self.conn.send(event)
self.transport.write(output)
self.transport.close()

def on_response_complete(self):
self.state["total_requests"] += 1

Expand Down
26 changes: 15 additions & 11 deletions uvicorn/protocols/http/httptools_impl.py
Expand Up @@ -170,7 +170,8 @@ def data_received(self, data):
except httptools.parser.errors.HttpParserError:
msg = "Invalid HTTP request received."
self.logger.warning(msg)
self.transport.close()
self.send_400_response(msg)
return
except httptools.HttpParserUpgrade:
self.handle_upgrade()

Expand All @@ -183,16 +184,7 @@ def handle_upgrade(self):
if upgrade_value != b'websocket' or self.ws_protocol_class is None:
msg = "Unsupported upgrade request."
self.logger.warning(msg)
content = [STATUS_LINE[400], DEFAULT_HEADERS]
content.extend([
b"content-type: text/plain; charset=utf-8\r\n",
b"content-length: " + str(len(msg)).encode('ascii') + b"\r\n",
b"connection: close\r\n",
b"\r\n",
msg.encode('ascii')
])
self.transport.write(b"".join(content))
self.transport.close()
self.send_400_response(msg)
return

self.connections.discard(self)
Expand All @@ -212,6 +204,18 @@ def handle_upgrade(self):
protocol.data_received(b''.join(output))
self.transport.set_protocol(protocol)

def send_400_response(self, msg):
content = [STATUS_LINE[400], DEFAULT_HEADERS]
content.extend([
b"content-type: text/plain; charset=utf-8\r\n",
b"content-length: " + str(len(msg)).encode('ascii') + b"\r\n",
b"connection: close\r\n",
b"\r\n",
msg.encode('ascii')
])
self.transport.write(b"".join(content))
self.transport.close()

# Parser callbacks
def on_url(self, url):
method = self.parser.get_method()
Expand Down

0 comments on commit 09577e1

Please sign in to comment.