diff --git a/CHANGES/7114.feature b/CHANGES/7114.feature new file mode 100644 index 0000000000..697335618a --- /dev/null +++ b/CHANGES/7114.feature @@ -0,0 +1 @@ +Support passing a custom server name parameter to HTTPS connection diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 6c2fabbdec..e4b8d1805e 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -165,6 +165,7 @@ Joel Watts Jon Nabozny Jonas Krüger Svensson Jonas Obrist +Jonathan Ballet Jonathan Wright Jonny Tan Joongi Kim diff --git a/aiohttp/client.py b/aiohttp/client.py index 0d0f4c16c0..afb033fb09 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -387,6 +387,7 @@ async def _request( fingerprint: Optional[bytes] = None, ssl_context: Optional[SSLContext] = None, ssl: Optional[Union[SSLContext, bool, Fingerprint]] = None, + server_hostname: Optional[str] = None, proxy_headers: Optional[LooseHeaders] = None, trace_request_ctx: Optional[SimpleNamespace] = None, read_bufsize: Optional[int] = None, @@ -525,6 +526,7 @@ async def _request( timer=timer, session=self, ssl=ssl, + server_hostname=server_hostname, proxy_headers=proxy_headers, traces=traces, ) diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 28b8a28d0d..07cff00402 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -268,6 +268,7 @@ def __init__( ssl: Union[SSLContext, bool, Fingerprint, None] = None, proxy_headers: Optional[LooseHeaders] = None, traces: Optional[List["Trace"]] = None, + server_hostname: Optional[str] = None, ): if loop is None: @@ -297,6 +298,7 @@ def __init__( self.response_class: Type[ClientResponse] = real_response_class self._timer = timer if timer is not None else TimerNoop() self._ssl = ssl + self.server_hostname = server_hostname if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 2499a2dabe..5bfb17f057 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -1085,6 +1085,10 @@ async def _start_tls_connection( # maintainability-wise but this is to be solved separately. sslcontext = cast(ssl.SSLContext, self._get_ssl_context(req)) + server_hostname = req.host + if req.server_hostname is not None: + server_hostname = req.server_hostname + try: async with ceil_timeout(timeout.sock_connect): try: @@ -1092,7 +1096,7 @@ async def _start_tls_connection( underlying_transport, tls_proto, sslcontext, - server_hostname=req.host, + server_hostname=server_hostname, ssl_handshake_timeout=timeout.total, ) except BaseException: @@ -1174,6 +1178,10 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: host = hinfo["host"] port = hinfo["port"] + server_hostname = hinfo["hostname"] if sslcontext else None + if req.server_hostname is not None: + server_hostname = req.server_hostname + try: transp, proto = await self._wrap_create_connection( self._factory, @@ -1184,7 +1192,7 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: family=hinfo["family"], proto=hinfo["proto"], flags=hinfo["flags"], - server_hostname=hinfo["hostname"] if sslcontext else None, + server_hostname=server_hostname, local_addr=self._local_addr, req=req, client_error=client_error, diff --git a/tests/test_connector.py b/tests/test_connector.py index 0b992df98c..bb38605cdd 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -557,6 +557,34 @@ async def certificate_error(*args, **kwargs): assert isinstance(ctx.value, aiohttp.ClientSSLError) +async def test_tcp_connector_server_hostname_default(loop) -> None: + conn = aiohttp.TCPConnector(loop=loop) + + conn._loop.create_connection = mock.AsyncMock() + conn._loop.create_connection.return_value = mock.Mock(), mock.Mock() + + req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop) + + established_connection = await conn.connect(req, [], ClientTimeout()) + assert conn._loop.create_connection.call_args.kwargs["server_hostname"] == "127.0.0.1" + + established_connection.close() + + +async def test_tcp_connector_server_hostname_override(loop) -> None: + conn = aiohttp.TCPConnector(loop=loop) + + conn._loop.create_connection = mock.AsyncMock() + conn._loop.create_connection.return_value = mock.Mock(), mock.Mock() + + req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop, server_hostname="localhost") + + established_connection = await conn.connect(req, [], ClientTimeout()) + assert conn._loop.create_connection.call_args.kwargs["server_hostname"] == "localhost" + + established_connection.close() + + async def test_tcp_connector_multiple_hosts_errors(loop) -> None: conn = aiohttp.TCPConnector(loop=loop) diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 25bcb647fa..66b1385918 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -186,6 +186,123 @@ async def make_conn(): connector.connect(req, None, aiohttp.ClientTimeout()) ) + @mock.patch("aiohttp.connector.ClientRequest") + def test_proxy_server_hostname_default(self, ClientRequestMock) -> None: + proxy_req = ClientRequest( + "GET", URL("http://proxy.example.com"), loop=self.loop + ) + ClientRequestMock.return_value = proxy_req + + proxy_resp = ClientResponse( + "get", + URL("http://proxy.example.com"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=self.loop, + session=mock.Mock(), + ) + proxy_req.send = make_mocked_coro(proxy_resp) + proxy_resp.start = make_mocked_coro(mock.Mock(status=200)) + + async def make_conn(): + return aiohttp.TCPConnector() + + connector = self.loop.run_until_complete(make_conn()) + connector._resolve_host = make_mocked_coro( + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) + + tr, proto = mock.Mock(), mock.Mock() + self.loop.create_connection = make_mocked_coro((tr, proto)) + self.loop.start_tls = make_mocked_coro(mock.Mock()) + + req = ClientRequest( + "GET", + URL("https://www.python.org"), + proxy=URL("http://proxy.example.com"), + loop=self.loop, + ) + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout()) + ) + + self.assertEqual(self.loop.start_tls.call_args.kwargs["server_hostname"], "www.python.org") + + self.loop.run_until_complete(proxy_req.close()) + proxy_resp.close() + self.loop.run_until_complete(req.close()) + + @mock.patch("aiohttp.connector.ClientRequest") + def test_proxy_server_hostname_override(self, ClientRequestMock) -> None: + proxy_req = ClientRequest( + "GET", URL("http://proxy.example.com"), loop=self.loop, + ) + ClientRequestMock.return_value = proxy_req + + proxy_resp = ClientResponse( + "get", + URL("http://proxy.example.com"), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=self.loop, + session=mock.Mock(), + ) + proxy_req.send = make_mocked_coro(proxy_resp) + proxy_resp.start = make_mocked_coro(mock.Mock(status=200)) + + async def make_conn(): + return aiohttp.TCPConnector() + + connector = self.loop.run_until_complete(make_conn()) + connector._resolve_host = make_mocked_coro( + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) + + tr, proto = mock.Mock(), mock.Mock() + self.loop.create_connection = make_mocked_coro((tr, proto)) + self.loop.start_tls = make_mocked_coro(mock.Mock()) + + req = ClientRequest( + "GET", + URL("https://www.python.org"), + proxy=URL("http://proxy.example.com"), + server_hostname="server-hostname.example.com", + loop=self.loop, + ) + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout()) + ) + + self.assertEqual(self.loop.start_tls.call_args.kwargs["server_hostname"], "server-hostname.example.com") + + self.loop.run_until_complete(proxy_req.close()) + proxy_resp.close() + self.loop.run_until_complete(req.close()) + @mock.patch("aiohttp.connector.ClientRequest") def test_https_connect(self, ClientRequestMock) -> None: proxy_req = ClientRequest(