Skip to content

Commit

Permalink
Support passing a custom server name parameter on HTTPS connection
Browse files Browse the repository at this point in the history
This add the missing support to set the `server_hostname` setting when
creating TCP connection, when the underlying connection is authenticated
using TLS.

See the documentation for the 2 stdlib functions:

* https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_connection
* https://docs.python.org/3/library/asyncio-eventloop.html#opening-network-connections

The implemention is similar to what was done in urllib3 in urllib3/urllib3#1397

This would be needed to support features in clients using aiohttp, such as tomplus/kubernetes_asyncio#267

Closes: aio-libs#7114
  • Loading branch information
multani committed Aug 19, 2023
1 parent 9c13a52 commit a6bd740
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGES/7114.feature
@@ -0,0 +1 @@
Support passing a custom server name parameter to HTTPS connection
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Expand Up @@ -165,6 +165,7 @@ Joel Watts
Jon Nabozny
Jonas Krüger Svensson
Jonas Obrist
Jonathan Ballet
Jonathan Wright
Jonny Tan
Joongi Kim
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/client.py
Expand Up @@ -387,6 +387,7 @@ async def _request(
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
ssl: Optional[Union[SSLContext, bool, Fingerprint]] = None,
server_hostname: Optional[str] = None,
proxy_headers: Optional[LooseHeaders] = None,
trace_request_ctx: Optional[SimpleNamespace] = None,
read_bufsize: Optional[int] = None,
Expand Down Expand Up @@ -525,6 +526,7 @@ async def _request(
timer=timer,
session=self,
ssl=ssl,
server_hostname=server_hostname,
proxy_headers=proxy_headers,
traces=traces,
)
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/client_reqrep.py
Expand Up @@ -268,6 +268,7 @@ def __init__(
ssl: Union[SSLContext, bool, Fingerprint, None] = None,
proxy_headers: Optional[LooseHeaders] = None,
traces: Optional[List["Trace"]] = None,
server_hostname: Optional[str] = None,
):

if loop is None:
Expand Down Expand Up @@ -297,6 +298,7 @@ def __init__(
self.response_class: Type[ClientResponse] = real_response_class
self._timer = timer if timer is not None else TimerNoop()
self._ssl = ssl
self.server_hostname = server_hostname

if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
Expand Down
12 changes: 10 additions & 2 deletions aiohttp/connector.py
Expand Up @@ -1085,14 +1085,18 @@ async def _start_tls_connection(
# maintainability-wise but this is to be solved separately.
sslcontext = cast(ssl.SSLContext, self._get_ssl_context(req))

server_hostname = req.host
if req.server_hostname is not None:
server_hostname = req.server_hostname

try:
async with ceil_timeout(timeout.sock_connect):
try:
tls_transport = await self._loop.start_tls(
underlying_transport,
tls_proto,
sslcontext,
server_hostname=req.host,
server_hostname=server_hostname,
ssl_handshake_timeout=timeout.total,
)
except BaseException:
Expand Down Expand Up @@ -1174,6 +1178,10 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
host = hinfo["host"]
port = hinfo["port"]

server_hostname = hinfo["hostname"] if sslcontext else None
if req.server_hostname is not None:
server_hostname = req.server_hostname

try:
transp, proto = await self._wrap_create_connection(
self._factory,
Expand All @@ -1184,7 +1192,7 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
family=hinfo["family"],
proto=hinfo["proto"],
flags=hinfo["flags"],
server_hostname=hinfo["hostname"] if sslcontext else None,
server_hostname=server_hostname,
local_addr=self._local_addr,
req=req,
client_error=client_error,
Expand Down
28 changes: 28 additions & 0 deletions tests/test_connector.py
Expand Up @@ -557,6 +557,34 @@ async def certificate_error(*args, **kwargs):
assert isinstance(ctx.value, aiohttp.ClientSSLError)


async def test_tcp_connector_server_hostname_default(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)

conn._loop.create_connection = mock.AsyncMock()
conn._loop.create_connection.return_value = mock.Mock(), mock.Mock()

req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop)

established_connection = await conn.connect(req, [], ClientTimeout())
assert conn._loop.create_connection.call_args.kwargs["server_hostname"] == "127.0.0.1"

established_connection.close()


async def test_tcp_connector_server_hostname_override(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)

conn._loop.create_connection = mock.AsyncMock()
conn._loop.create_connection.return_value = mock.Mock(), mock.Mock()

req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop, server_hostname="localhost")

established_connection = await conn.connect(req, [], ClientTimeout())
assert conn._loop.create_connection.call_args.kwargs["server_hostname"] == "localhost"

established_connection.close()


async def test_tcp_connector_multiple_hosts_errors(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)

Expand Down
117 changes: 117 additions & 0 deletions tests/test_proxy.py
Expand Up @@ -186,6 +186,123 @@ async def make_conn():
connector.connect(req, None, aiohttp.ClientTimeout())
)

@mock.patch("aiohttp.connector.ClientRequest")
def test_proxy_server_hostname_default(self, ClientRequestMock) -> None:
proxy_req = ClientRequest(
"GET", URL("http://proxy.example.com"), loop=self.loop
)
ClientRequestMock.return_value = proxy_req

proxy_resp = ClientResponse(
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
continue100=None,
timer=TimerNoop(),
traces=[],
loop=self.loop,
session=mock.Mock(),
)
proxy_req.send = make_mocked_coro(proxy_resp)
proxy_resp.start = make_mocked_coro(mock.Mock(status=200))

async def make_conn():
return aiohttp.TCPConnector()

connector = self.loop.run_until_complete(make_conn())
connector._resolve_host = make_mocked_coro(
[
{
"hostname": "hostname",
"host": "127.0.0.1",
"port": 80,
"family": socket.AF_INET,
"proto": 0,
"flags": 0,
}
]
)

tr, proto = mock.Mock(), mock.Mock()
self.loop.create_connection = make_mocked_coro((tr, proto))
self.loop.start_tls = make_mocked_coro(mock.Mock())

req = ClientRequest(
"GET",
URL("https://www.python.org"),
proxy=URL("http://proxy.example.com"),
loop=self.loop,
)
self.loop.run_until_complete(
connector._create_connection(req, None, aiohttp.ClientTimeout())
)

self.assertEqual(self.loop.start_tls.call_args.kwargs["server_hostname"], "www.python.org")

self.loop.run_until_complete(proxy_req.close())
proxy_resp.close()
self.loop.run_until_complete(req.close())

@mock.patch("aiohttp.connector.ClientRequest")
def test_proxy_server_hostname_override(self, ClientRequestMock) -> None:
proxy_req = ClientRequest(
"GET", URL("http://proxy.example.com"), loop=self.loop,
)
ClientRequestMock.return_value = proxy_req

proxy_resp = ClientResponse(
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
continue100=None,
timer=TimerNoop(),
traces=[],
loop=self.loop,
session=mock.Mock(),
)
proxy_req.send = make_mocked_coro(proxy_resp)
proxy_resp.start = make_mocked_coro(mock.Mock(status=200))

async def make_conn():
return aiohttp.TCPConnector()

connector = self.loop.run_until_complete(make_conn())
connector._resolve_host = make_mocked_coro(
[
{
"hostname": "hostname",
"host": "127.0.0.1",
"port": 80,
"family": socket.AF_INET,
"proto": 0,
"flags": 0,
}
]
)

tr, proto = mock.Mock(), mock.Mock()
self.loop.create_connection = make_mocked_coro((tr, proto))
self.loop.start_tls = make_mocked_coro(mock.Mock())

req = ClientRequest(
"GET",
URL("https://www.python.org"),
proxy=URL("http://proxy.example.com"),
server_hostname="server-hostname.example.com",
loop=self.loop,
)
self.loop.run_until_complete(
connector._create_connection(req, None, aiohttp.ClientTimeout())
)

self.assertEqual(self.loop.start_tls.call_args.kwargs["server_hostname"], "server-hostname.example.com")

self.loop.run_until_complete(proxy_req.close())
proxy_resp.close()
self.loop.run_until_complete(req.close())

@mock.patch("aiohttp.connector.ClientRequest")
def test_https_connect(self, ClientRequestMock) -> None:
proxy_req = ClientRequest(
Expand Down

0 comments on commit a6bd740

Please sign in to comment.