diff --git a/src/urllib3/connection.py b/src/urllib3/connection.py index c2beb53433..66320a766a 100644 --- a/src/urllib3/connection.py +++ b/src/urllib3/connection.py @@ -43,7 +43,6 @@ class BrokenPipeError(Exception): pass -from ._collections import HTTPHeaderDict from ._version import __version__ from .exceptions import ( ConnectTimeoutError, @@ -52,7 +51,7 @@ class BrokenPipeError(Exception): SystemTimeWarning, ) from .packages.ssl_match_hostname import CertificateError, match_hostname -from .util import SKIP_USER_AGENT, connection +from .util import SKIP_HEADER, SKIPPABLE_HEADERS, connection from .util.ssl_ import ( assert_fingerprint, create_urllib3_context, @@ -213,12 +212,20 @@ def putrequest(self, method, url, *args, **kwargs): return _HTTPConnection.putrequest(self, method, url, *args, **kwargs) + def putheader(self, header, *values): + """Emit a header unless one of its values is SKIP_HEADER; skipping a header outside SKIPPABLE_HEADERS raises ValueError.""" + if SKIP_HEADER not in values: + _HTTPConnection.putheader(self, header, *values) + elif six.ensure_str(header.lower()) not in SKIPPABLE_HEADERS: + raise ValueError( + "urllib3.util.SKIP_HEADER only supports 'Accept-Encoding', 'Host', and 'User-Agent'" + ) + def request(self, method, url, body=None, headers=None): - headers = HTTPHeaderDict(headers if headers is not None else {}) - if "user-agent" not in headers: + if headers is None: + headers = {"User-Agent": _get_default_user_agent()} + elif "user-agent" not in (k.lower() for k in headers): headers["User-Agent"] = _get_default_user_agent() - elif SKIP_USER_AGENT in headers.get_all("user-agent"): - del headers["user-agent"] super(HTTPConnection, self).request(method, url, body=body, headers=headers) def request_chunked(self, method, url, body=None, headers=None): @@ -226,16 +233,14 @@ def request_chunked(self, method, url, body=None, headers=None): Alternative to the common request method, which sends the body with chunked encoding and not as one block """ - headers = HTTPHeaderDict(headers if headers is not None else {}) - skip_accept_encoding = "accept-encoding" 
in headers - skip_host = "host" in headers + header_keys = set([k.lower() for k in headers or ()]) + skip_accept_encoding = "accept-encoding" in header_keys + skip_host = "host" in header_keys self.putrequest( method, url, skip_accept_encoding=skip_accept_encoding, skip_host=skip_host ) - if "user-agent" not in headers: - headers["User-Agent"] = _get_default_user_agent() - elif SKIP_USER_AGENT in headers.get_all("user-agent"): - del headers["user-agent"] + if "user-agent" not in header_keys: + self.putheader("User-Agent", _get_default_user_agent()) for header, value in headers.items(): self.putheader(header, value) if "transfer-encoding" not in headers: diff --git a/src/urllib3/util/__init__.py b/src/urllib3/util/__init__.py index b928052147..4547fc522b 100644 --- a/src/urllib3/util/__init__.py +++ b/src/urllib3/util/__init__.py @@ -2,7 +2,7 @@ # For backwards compatibility, provide imports that used to be here. from .connection import is_connection_dropped -from .request import SKIP_USER_AGENT, make_headers +from .request import SKIP_HEADER, SKIPPABLE_HEADERS, make_headers from .response import is_fp_closed from .retry import Retry from .ssl_ import ( @@ -44,5 +44,6 @@ "ssl_wrap_socket", "wait_for_read", "wait_for_write", - "SKIP_USER_AGENT", + "SKIP_HEADER", + "SKIPPABLE_HEADERS", ) diff --git a/src/urllib3/util/request.py b/src/urllib3/util/request.py index 0399c736c6..25103383ec 100644 --- a/src/urllib3/util/request.py +++ b/src/urllib3/util/request.py @@ -5,10 +5,13 @@ from ..exceptions import UnrewindableBodyError from ..packages.six import b, integer_types -# Use an invalid User-Agent to represent suppressing of default user agent. -# See https://tools.ietf.org/html/rfc7231#section-5.5.3 and -# https://tools.ietf.org/html/rfc7230#section-3.2.6 -SKIP_USER_AGENT = "@@@SKIP_USER_AGENT@@@" +# Pass as a value within ``headers`` to skip +# emitting some HTTP headers that are added automatically. 
+# The only headers that are supported are ``Accept-Encoding``, +# ``Host``, and ``User-Agent``. +SKIP_HEADER = "@@@SKIP_HEADER@@@" +SKIPPABLE_HEADERS = frozenset(["accept-encoding", "host", "user-agent"]) + ACCEPT_ENCODING = "gzip,deflate" try: import brotli as _unused_module_brotli # noqa: F401 diff --git a/test/with_dummyserver/test_chunked_transfer.py b/test/with_dummyserver/test_chunked_transfer.py index d78813e340..3cef108300 100644 --- a/test/with_dummyserver/test_chunked_transfer.py +++ b/test/with_dummyserver/test_chunked_transfer.py @@ -8,7 +8,7 @@ consume_socket, ) from urllib3 import HTTPConnectionPool -from urllib3.util import SKIP_USER_AGENT +from urllib3.util import SKIP_HEADER from urllib3.util.retry import Retry # Retry failed tests @@ -123,7 +123,7 @@ def test_remove_user_agent_header(self): "GET", "/", chunks, - headers={"User-Agent": SKIP_USER_AGENT}, + headers={"User-Agent": SKIP_HEADER}, chunked=True, ) diff --git a/test/with_dummyserver/test_connectionpool.py b/test/with_dummyserver/test_connectionpool.py index 2a5b368a03..2b60dd60ec 100644 --- a/test/with_dummyserver/test_connectionpool.py +++ b/test/with_dummyserver/test_connectionpool.py @@ -17,7 +17,6 @@ from dummyserver.server import HAS_IPV6_AND_DNS, NoIPv6Warning from dummyserver.testcase import HTTPDummyServerTestCase, SocketDummyServerTestCase from urllib3 import HTTPConnectionPool, encode_multipart_formdata -from urllib3._collections import HTTPHeaderDict from urllib3.connection import _get_default_user_agent from urllib3.exceptions import ( ConnectTimeoutError, @@ -30,7 +29,7 @@ ) from urllib3.packages.six import b, u from urllib3.packages.six.moves.urllib.parse import urlencode -from urllib3.util import SKIP_USER_AGENT +from urllib3.util import SKIP_HEADER, SKIPPABLE_HEADERS from urllib3.util.retry import RequestHistory, Retry from urllib3.util.timeout import Timeout @@ -834,18 +833,18 @@ def test_no_user_agent_header(self): custom_ua = "I'm not a web scraper, what are you talking 
about?" with HTTPConnectionPool(self.host, self.port) as pool: # Suppress user agent in the request headers. - no_ua_headers = {"User-Agent": SKIP_USER_AGENT} + no_ua_headers = {"User-Agent": SKIP_HEADER} r = pool.request("GET", "/headers", headers=no_ua_headers) request_headers = json.loads(r.data.decode("utf8")) assert "User-Agent" not in request_headers - assert no_ua_headers["User-Agent"] == SKIP_USER_AGENT + assert no_ua_headers["User-Agent"] == SKIP_HEADER # Suppress user agent in the pool headers. pool.headers = no_ua_headers r = pool.request("GET", "/headers") request_headers = json.loads(r.data.decode("utf8")) assert "User-Agent" not in request_headers - assert no_ua_headers["User-Agent"] == SKIP_USER_AGENT + assert no_ua_headers["User-Agent"] == SKIP_HEADER # Request headers override pool headers. pool_headers = {"User-Agent": custom_ua} @@ -853,22 +852,60 @@ def test_no_user_agent_header(self): r = pool.request("GET", "/headers", headers=no_ua_headers) request_headers = json.loads(r.data.decode("utf8")) assert "User-Agent" not in request_headers - assert no_ua_headers["User-Agent"] == SKIP_USER_AGENT + assert no_ua_headers["User-Agent"] == SKIP_HEADER assert pool_headers.get("User-Agent") == custom_ua - # Suppress user agent when multiple user agents are sent - # if 'SKIP_USER_AGENT' is one of the values. 
- multi_ua_headers = HTTPHeaderDict() - multi_ua_headers.add("User-Agent", custom_ua) - multi_ua_headers.extend(no_ua_headers) - pool.headers = multi_ua_headers - r = pool.request("GET", "/headers") - request_headers = json.loads(r.data.decode("utf8")) - assert "User-Agent" not in request_headers - assert multi_ua_headers.get_all("User-Agent") == [ - custom_ua, - SKIP_USER_AGENT, - ] + @pytest.mark.parametrize( + "accept_encoding", ["Accept-Encoding", "accept-encoding", None] + ) + @pytest.mark.parametrize("host", ["Host", "host", None]) + @pytest.mark.parametrize("user_agent", ["User-Agent", "user-agent", None]) + @pytest.mark.parametrize("chunked", [True, False]) + def test_skip_header(self, accept_encoding, host, user_agent, chunked): + headers = {} + + if accept_encoding is not None: + headers[accept_encoding] = SKIP_HEADER + if host is not None: + headers[host] = SKIP_HEADER + if user_agent is not None: + headers[user_agent] = SKIP_HEADER + + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("GET", "/headers", headers=headers, chunked=chunked) + request_headers = json.loads(r.data.decode("utf8")) + + if accept_encoding is None: + assert "Accept-Encoding" in request_headers + else: + assert accept_encoding not in request_headers + if host is None: + assert "Host" in request_headers + else: + assert host not in request_headers + if user_agent is None: + assert "User-Agent" in request_headers + else: + assert user_agent not in request_headers + + @pytest.mark.parametrize("header", ["Content-Length", "content-length"]) + @pytest.mark.parametrize("chunked", [True, False]) + def test_skip_header_non_supported(self, header, chunked): + with HTTPConnectionPool(self.host, self.port) as pool: + with pytest.raises(ValueError) as e: + pool.request( + "GET", "/headers", headers={header: SKIP_HEADER}, chunked=chunked + ) + assert ( + str(e.value) + == "urllib3.util.SKIP_HEADER only supports 'Accept-Encoding', 'Host', and 'User-Agent'" + ) + + # 
Ensure that the error message stays up to date with 'SKIPPABLE_HEADERS' + assert all( + ("'" + header.title() + "'") in str(e.value) + for header in SKIPPABLE_HEADERS + ) def test_bytes_header(self): with HTTPConnectionPool(self.host, self.port) as pool: