From 4ab2550d667651e11229cf7e0489755fe6d998cc Mon Sep 17 00:00:00 2001 From: Nate Prewitt Date: Thu, 2 Sep 2021 01:37:22 -0700 Subject: [PATCH] Add test for default chunked Host header --- requests/adapters.py | 3 ++- tests/test_lowlevel.py | 47 ++++++++++++++++++++++++++++++------------ 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/requests/adapters.py b/requests/adapters.py index d43c9e2403..fe22ff450e 100644 --- a/requests/adapters.py +++ b/requests/adapters.py @@ -458,10 +458,11 @@ def send(self, request, stream=False, timeout=None, verify=True, cert=None, prox low_conn = conn._get_conn(timeout=DEFAULT_POOL_TIMEOUT) try: + skip_host = 'Host' in request.headers low_conn.putrequest(request.method, url, skip_accept_encoding=True, - skip_host='Host' in request.headers) + skip_host=skip_host) for header, value in request.headers.items(): low_conn.putheader(header, value) diff --git a/tests/test_lowlevel.py b/tests/test_lowlevel.py index 7dcd00151f..7256eda943 100644 --- a/tests/test_lowlevel.py +++ b/tests/test_lowlevel.py @@ -9,6 +9,18 @@ from .utils import override_environ +def echo_response_handler(sock): + """Simple handler that will take request and echo it back to requester.""" + request_content = consume_socket_content(sock, timeout=0.5) + + text_200 = ( + b'HTTP/1.1 200 OK\r\n' + b'Content-Length: %d\r\n\r\n' + b'%s' + ) % (len(request_content), request_content) + sock.send(text_200) + + def test_chunked_upload(): """can safely send generators""" close_server = threading.Event() @@ -48,29 +60,38 @@ def incomplete_chunked_response_handler(sock): def test_chunked_upload_uses_only_specified_host_header(): """Ensure we use only the specified Host header for chunked requests.""" - text_200 = (b'HTTP/1.1 200 OK\r\n' - b'Content-Length: 0\r\n\r\n') - wanted_host = 'sample-host' - expected_header = 'Host: {}'.format(wanted_host).encode('utf-8') - def single_host_resp_handler(sock): - request_content = consume_socket_content(sock, timeout=0.5) - assert expected_header in request_content - assert request_content.count(b'Host: ') == 1 - sock.send(text_200) + close_server = threading.Event() + server = Server(echo_response_handler, wait_to_close_event=close_server) - return request_content + data = iter([b'a', b'b', b'c']) + custom_host = 'sample-host' + + with server as (host, port): + url = 'http://{}:{}/'.format(host, port) + r = requests.post(url, data=data, headers={'Host': custom_host}, stream=True) + close_server.set() # release server block + expected_header = b'Host: %s\r\n' % custom_host.encode('utf-8') + assert expected_header in r.content + assert r.content.count(b'Host: ') == 1 + + +def test_chunked_upload_doesnt_skip_host_header(): + """Ensure we don't omit all Host headers with chunked requests.""" close_server = threading.Event() + server = Server(echo_response_handler, wait_to_close_event=close_server) - server = Server(single_host_resp_handler, wait_to_close_event=close_server) data = iter([b'a', b'b', b'c']) with server as (host, port): + expected_host = '{}:{}'.format(host, port) url = 'http://{}:{}/'.format(host, port) - r = requests.post(url, data=data, headers={'Host': wanted_host}, stream=True) + r = requests.post(url, data=data, stream=True) close_server.set() # release server block - assert r.status_code == 200 + expected_header = b'Host: %s\r\n' % expected_host.encode('utf-8') + assert expected_header in r.content + assert r.content.count(b'Host: ') == 1 def test_conflicting_content_lengths():