diff --git a/CHANGES/7114.feature b/CHANGES/7114.feature new file mode 100644 index 0000000000..697335618a --- /dev/null +++ b/CHANGES/7114.feature @@ -0,0 +1 @@ +Support passing a custom server name parameter to HTTPS connection diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 8a91cb4950..503b7f129c 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -178,6 +178,7 @@ Joel Watts Jon Nabozny Jonas Krüger Svensson Jonas Obrist +Jonathan Ballet Jonathan Wright Jonny Tan Joongi Kim diff --git a/aiohttp/client.py b/aiohttp/client.py index 37b35997f5..314e51b019 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -348,6 +348,7 @@ async def _request( proxy_auth: Optional[BasicAuth] = None, timeout: Union[ClientTimeout, _SENTINEL, None] = sentinel, ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, + server_hostname: Optional[str] = None, proxy_headers: Optional[LooseHeaders] = None, trace_request_ctx: Optional[SimpleNamespace] = None, read_bufsize: Optional[int] = None, @@ -494,6 +495,7 @@ async def _request( timer=timer, session=self, ssl=ssl, + server_hostname=server_hostname, proxy_headers=proxy_headers, traces=traces, trust_env=self.trust_env, diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index d479851396..c9864b3417 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -215,6 +215,7 @@ def __init__( proxy_headers: Optional[LooseHeaders] = None, traces: Optional[List["Trace"]] = None, trust_env: bool = False, + server_hostname: Optional[str] = None, ): match = _CONTAINS_CONTROL_CHAR_RE.search(method) if match: @@ -246,6 +247,7 @@ def __init__( self.response_class: Type[ClientResponse] = real_response_class self._timer = timer if timer is not None else TimerNoop() self._ssl = ssl + self.server_hostname = server_hostname if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 6d087d26d2..dfbdc006f2 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -1034,7 +1034,7 @@ async def _start_tls_connection( underlying_transport, tls_proto, sslcontext, - server_hostname=req.host, + server_hostname=req.server_hostname or req.host, ssl_handshake_timeout=timeout.total, ) except BaseException: @@ -1116,6 +1116,10 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: host = hinfo["host"] port = hinfo["port"] + server_hostname = ( + (req.server_hostname or hinfo["hostname"]) if sslcontext else None + ) + try: transp, proto = await self._wrap_create_connection( self._factory, @@ -1126,7 +1130,7 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: family=hinfo["family"], proto=hinfo["proto"], flags=hinfo["flags"], - server_hostname=hinfo["hostname"] if sslcontext else None, + server_hostname=server_hostname, local_addr=self._local_addr, req=req, client_error=client_error, diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 05c39b540b..62d74ac04f 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -347,7 +347,7 @@ The client session supports the context manager protocol for self closing. timeout=sentinel, ssl=None, \ verify_ssl=None, fingerprint=None, \ ssl_context=None, proxy_headers=None, \ - auto_decompress=None) + server_hostname=None, auto_decompress=None) :async: :noindexentry: @@ -497,6 +497,13 @@ The client session supports the context manager protocol for self closing. Use ``ssl=aiohttp.Fingerprint(digest)`` + :param str server_hostname: Sets or overrides the hostname that the + target server’s certificate will be matched against. + + See :method:`asyncio.loop.create_connection` for more information. + + .. versionadded:: 3.9 + :param ssl.SSLContext ssl_context: ssl context used for processing *HTTPS* requests (optional). diff --git a/tests/test_connector.py b/tests/test_connector.py index 9ef71882f8..7e3a345eb8 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -11,6 +11,7 @@ import uuid from collections import deque from typing import Any, Optional +from contextlib import closing from unittest import mock import pytest @@ -554,6 +555,35 @@ async def certificate_error(*args, **kwargs): assert isinstance(ctx.value, aiohttp.ClientSSLError) +async def test_tcp_connector_server_hostname_default(loop: Any) -> None: + conn = aiohttp.TCPConnector() + + with mock.patch.object( + conn._loop, "create_connection", autospec=True, spec_set=True + ) as create_connection: + create_connection.return_value = mock.Mock(), mock.Mock() + + req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop) + + with closing(await conn.connect(req, [], ClientTimeout())): + assert create_connection.call_args.kwargs["server_hostname"] == "127.0.0.1" + + +async def test_tcp_connector_server_hostname_override(loop: Any) -> None: + conn = aiohttp.TCPConnector() + + with mock.patch.object( + conn._loop, "create_connection", autospec=True, spec_set=True + ) as create_connection: + create_connection.return_value = mock.Mock(), mock.Mock() + + req = ClientRequest( + "GET", URL("https://127.0.0.1:443"), loop=loop, server_hostname="localhost" + ) + + with closing(await conn.connect(req, [], ClientTimeout())): + assert create_connection.call_args.kwargs["server_hostname"] == "localhost" + async def test_tcp_connector_multiple_hosts_errors(loop: Any) -> None: conn = aiohttp.TCPConnector() diff --git a/tests/test_proxy.py b/tests/test_proxy.py index af869ee88f..58396eeb65 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -188,6 +188,130 @@ async def make_conn(): connector.connect(req, None, aiohttp.ClientTimeout()) ) + @mock.patch("aiohttp.connector.ClientRequest") + def test_proxy_server_hostname_default(self, ClientRequestMock) -> None: + proxy_req = ClientRequest( + "GET", URL("http://proxy.example.com"), loop=self.loop + ) + ClientRequestMock.return_value = proxy_req + + proxy_resp = ClientResponse( + "get", + URL("http://proxy.example.com"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=self.loop, + session=mock.Mock(), + ) + proxy_req.send = make_mocked_coro(proxy_resp) + proxy_resp.start = make_mocked_coro(mock.Mock(status=200)) + + async def make_conn(): + return aiohttp.TCPConnector() + + connector = self.loop.run_until_complete(make_conn()) + connector._resolve_host = make_mocked_coro( + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) + + tr, proto = mock.Mock(), mock.Mock() + self.loop.create_connection = make_mocked_coro((tr, proto)) + self.loop.start_tls = make_mocked_coro(mock.Mock()) + + req = ClientRequest( + "GET", + URL("https://www.python.org"), + proxy=URL("http://proxy.example.com"), + loop=self.loop, + ) + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout()) + ) + + self.assertEqual( + self.loop.start_tls.call_args.kwargs["server_hostname"], "www.python.org" + ) + + self.loop.run_until_complete(proxy_req.close()) + proxy_resp.close() + self.loop.run_until_complete(req.close()) + + @mock.patch("aiohttp.connector.ClientRequest") + def test_proxy_server_hostname_override(self, ClientRequestMock) -> None: + proxy_req = ClientRequest( + "GET", + URL("http://proxy.example.com"), + loop=self.loop, + ) + ClientRequestMock.return_value = proxy_req + + proxy_resp = ClientResponse( + "get", + URL("http://proxy.example.com"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=self.loop, + session=mock.Mock(), + ) + proxy_req.send = make_mocked_coro(proxy_resp) + proxy_resp.start = make_mocked_coro(mock.Mock(status=200)) + + async def make_conn(): + return aiohttp.TCPConnector() + + connector = self.loop.run_until_complete(make_conn()) + connector._resolve_host = make_mocked_coro( + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) + + tr, proto = mock.Mock(), mock.Mock() + self.loop.create_connection = make_mocked_coro((tr, proto)) + self.loop.start_tls = make_mocked_coro(mock.Mock()) + + req = ClientRequest( + "GET", + URL("https://www.python.org"), + proxy=URL("http://proxy.example.com"), + server_hostname="server-hostname.example.com", + loop=self.loop, + ) + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout()) + ) + + self.assertEqual( + self.loop.start_tls.call_args.kwargs["server_hostname"], + "server-hostname.example.com", + ) + + self.loop.run_until_complete(proxy_req.close()) + proxy_resp.close() + self.loop.run_until_complete(req.close()) + @mock.patch("aiohttp.connector.ClientRequest") def test_https_connect(self, ClientRequestMock: Any) -> None: proxy_req = ClientRequest(