Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix case where url is fragmented in httptools protocol #1263

Merged
merged 5 commits into from Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 47 additions & 0 deletions tests/protocols/test_http.py
@@ -1,10 +1,14 @@
import asyncio
import contextlib
import logging
import socket
import threading
import time

import pytest

from tests.response import Response
from uvicorn import Server
from uvicorn.config import Config
from uvicorn.main import ServerState
from uvicorn.protocols.http.h11_impl import H11Protocol
Expand Down Expand Up @@ -744,3 +748,46 @@ def test_invalid_http_request(request_line, protocol_cls, caplog, event_loop):
protocol.data_received(request)
assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer
assert b"Invalid HTTP request received." in protocol.transport.buffer


def test_fragmentation():
    """Check that a request whose URL arrives in two TCP writes is not rejected.

    A server using the ``httptools`` HTTP implementation is run in a
    background thread. A GET request is then sent over a raw socket in two
    pieces, with the cut falling inside the request target, and the test
    asserts that the response does not start with ``400 Bad Request``.
    """

    def receive_all(sock):
        """Read from ``sock`` until the peer closes; return everything received."""
        chunks = []
        while True:
            chunk = sock.recv(1024)
            if not chunk:
                break
            chunks.append(chunk)
        return b"".join(chunks)

    app = Response("Hello, world", media_type="text/plain")

    def send_fragmented_req(path):
        """Send ``GET path`` in two writes and return the raw response bytes."""
        d = (
            f"GET {path} HTTP/1.1\r\n" "Host: localhost\r\n" "Connection: close\r\n\r\n"
        ).encode()
        # For the path used below this offset lands inside the request target,
        # so the URL itself is what ends up split across the two writes.
        split = len(path) // 2
        # The context manager guarantees the socket is closed even if one of
        # the calls below raises.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.connect(("127.0.0.1", 8000))
            sock.sendall(d[:split])
            time.sleep(0.01)
            sock.sendall(d[split:])
            resp = receive_all(sock)
            # The peer has already closed at this point ("Connection: close"),
            # and some platforms raise from shutdown() on such a socket.
            with contextlib.suppress(OSError):
                sock.shutdown(socket.SHUT_RDWR)
        return resp

    config = Config(app=app, http="httptools")
    server = Server(config=config)
    t = threading.Thread(target=server.run)
    t.daemon = True
    t.start()
    try:
        time.sleep(1)  # wait for uvicorn to start

        path = "/?param=" + "q" * 10
        response = send_fragmented_req(path)
        bad_response = b"HTTP/1.1 400 Bad Request"
        assert bad_response != response[: len(bad_response)]
    finally:
        # Always stop the server, even when the assertion fails, so the port
        # is released for whatever runs next.
        server.should_exit = True
        t.join()
30 changes: 16 additions & 14 deletions uvicorn/protocols/http/httptools_impl.py
Expand Up @@ -74,7 +74,6 @@ def __init__(self, config, server_state, _loop=None):
self.pipeline = deque()

# Per-request state
self.url = None
self.scope = None
self.headers = None
self.expect_100_continue = False
Expand Down Expand Up @@ -182,15 +181,8 @@ def send_400_response(self, msg: str):
self.transport.write(b"".join(content))
self.transport.close()

# Parser callbacks
def on_url(self, url):
method = self.parser.get_method()
parsed_url = httptools.parse_url(url)
raw_path = parsed_url.path
path = raw_path.decode("ascii")
if "%" in path:
path = urllib.parse.unquote(path)
self.url = url
def on_message_begin(self):
self.url = b""
self.expect_100_continue = False
self.headers = []
self.scope = {
Expand All @@ -200,14 +192,14 @@ def on_url(self, url):
"server": self.server,
"client": self.client,
"scheme": self.scheme,
"method": method.decode("ascii"),
"root_path": self.root_path,
"path": path,
"raw_path": raw_path,
"query_string": parsed_url.query if parsed_url.query else b"",
"headers": self.headers,
}

# Parser callbacks
def on_url(self, url):
    """Append a URL fragment handed over by the parser to ``self.url``.

    The parser may invoke this callback several times for a single request,
    so each piece is accumulated rather than treated as the complete URL.
    """
    self.url = self.url + url

def on_header(self, name: bytes, value: bytes):
name = name.lower()
if name == b"expect" and value.lower() == b"100-continue":
Expand All @@ -216,10 +208,20 @@ def on_header(self, name: bytes, value: bytes):

def on_headers_complete(self):
http_version = self.parser.get_http_version()
method = self.parser.get_method()
self.scope["method"] = method.decode("ascii")
if http_version != "1.1":
self.scope["http_version"] = http_version
if self.parser.should_upgrade():
return
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it matter that "path", "raw_path", "query_string" are now not populated in the upgrade case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be tempted to say no, since `handle_upgrade` does not need them and the ws protocols will build the scope themselves down the road.
That said, I don't think it hurts to populate them before the upgrade check; both ways pass the tests. Would you prefer it that way?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me know if this is OK for you this way, @tomchristie.
I've merged in the latest changes from master.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'm okay with this, yup. 👍

An alternate approach would be to keep the change footprint absolutely as low as possible. I'm almost always in favour of the lowest possible change footprint in PRs, because they're easier to review and lower risk.

In this case we could alternately approach the PR like this...

    def on_message_begin(self):
         self.url = b""

    def on_url(self, url):
         self.url += url

    def on_url_complete(self):
        # This isn't an `httptools` callback, but we need it because `on_url` can actually
        # be called multiple times, and we don't know on each call if it's complete or not.
        # Instead, we call into this method from `on_headers_complete`, so that we've got a single
        # point at which the URL is set.
        ...  # Existing body of `on_url()`

    def on_headers_complete(self):
        self.on_url_complete()
        ...  # Existing body of `on_headers_complete()`

Which would result in a really small changeset. Which as I say, I tend to think is a great thing.

Having said that, it's not a super complex PR. It looks good already, and I don't want to give you extra work, so gonna okay this and then leave the final decision to you.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I'll merge it this way; I don't have the time!

parsed_url = httptools.parse_url(self.url)
raw_path = parsed_url.path
path = raw_path.decode("ascii")
if "%" in path:
path = urllib.parse.unquote(path)
self.scope["path"] = path
self.scope["raw_path"] = raw_path
self.scope["query_string"] = parsed_url.query if parsed_url.query else b""

# Handle 503 responses when 'limit_concurrency' is exceeded.
if self.limit_concurrency is not None and (
Expand Down