Skip to content

Commit

Permalink
Support passing a custom server name parameter on HTTPS connection
Browse files Browse the repository at this point in the history
This add the missing support to set the `server_hostname` setting when
creating TCP connection, when the underlying connection is authenticated
using TLS.

See the documentation for the 2 stdlib functions:

* https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_connection
* https://docs.python.org/3/library/asyncio-eventloop.html#opening-network-connections

The implemention is similar to what was done in urllib3 in urllib3/urllib3#1397

This would be needed to support features in clients using aiohttp, such as tomplus/kubernetes_asyncio#267

Closes: aio-libs#7114
  • Loading branch information
multani committed Aug 20, 2023
1 parent 0a9bc32 commit 13c4465
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGES/7114.feature
@@ -0,0 +1 @@
Support passing a custom server name parameter to HTTPS connection
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Expand Up @@ -178,6 +178,7 @@ Joel Watts
Jon Nabozny
Jonas Krüger Svensson
Jonas Obrist
Jonathan Ballet
Jonathan Wright
Jonny Tan
Joongi Kim
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/client.py
Expand Up @@ -348,6 +348,7 @@ async def _request(
proxy_auth: Optional[BasicAuth] = None,
timeout: Union[ClientTimeout, _SENTINEL, None] = sentinel,
ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None,
server_hostname: Optional[str] = None,
proxy_headers: Optional[LooseHeaders] = None,
trace_request_ctx: Optional[SimpleNamespace] = None,
read_bufsize: Optional[int] = None,
Expand Down Expand Up @@ -494,6 +495,7 @@ async def _request(
timer=timer,
session=self,
ssl=ssl,
server_hostname=server_hostname,
proxy_headers=proxy_headers,
traces=traces,
trust_env=self.trust_env,
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/client_reqrep.py
Expand Up @@ -215,6 +215,7 @@ def __init__(
proxy_headers: Optional[LooseHeaders] = None,
traces: Optional[List["Trace"]] = None,
trust_env: bool = False,
server_hostname: Optional[str] = None,
):
match = _CONTAINS_CONTROL_CHAR_RE.search(method)
if match:
Expand Down Expand Up @@ -246,6 +247,7 @@ def __init__(
self.response_class: Type[ClientResponse] = real_response_class
self._timer = timer if timer is not None else TimerNoop()
self._ssl = ssl
self.server_hostname = server_hostname

if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
Expand Down
8 changes: 6 additions & 2 deletions aiohttp/connector.py
Expand Up @@ -1034,7 +1034,7 @@ async def _start_tls_connection(
underlying_transport,
tls_proto,
sslcontext,
server_hostname=req.host,
server_hostname=req.server_hostname or req.host,
ssl_handshake_timeout=timeout.total,
)
except BaseException:
Expand Down Expand Up @@ -1116,6 +1116,10 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
host = hinfo["host"]
port = hinfo["port"]

server_hostname = (
(req.server_hostname or hinfo["hostname"]) if sslcontext else None
)

try:
transp, proto = await self._wrap_create_connection(
self._factory,
Expand All @@ -1126,7 +1130,7 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
family=hinfo["family"],
proto=hinfo["proto"],
flags=hinfo["flags"],
server_hostname=hinfo["hostname"] if sslcontext else None,
server_hostname=server_hostname,
local_addr=self._local_addr,
req=req,
client_error=client_error,
Expand Down
9 changes: 8 additions & 1 deletion docs/client_reference.rst
Expand Up @@ -347,7 +347,7 @@ The client session supports the context manager protocol for self closing.
timeout=sentinel, ssl=None, \
verify_ssl=None, fingerprint=None, \
ssl_context=None, proxy_headers=None, \
auto_decompress=None)
server_hostname=None, auto_decompress=None)
:async:
:noindexentry:

Expand Down Expand Up @@ -497,6 +497,13 @@ The client session supports the context manager protocol for self closing.

Use ``ssl=aiohttp.Fingerprint(digest)``

:param str server_hostname: Sets or overrides the hostname that the
target server’s certificate will be matched against.

See :method:`asyncio.loop.create_connection` for more information.

.. versionadded:: 3.9

:param ssl.SSLContext ssl_context: ssl context used for processing
*HTTPS* requests (optional).

Expand Down
30 changes: 30 additions & 0 deletions tests/test_connector.py
Expand Up @@ -11,6 +11,7 @@
import uuid
from collections import deque
from typing import Any, Optional
from contextlib import closing
from unittest import mock

import pytest
Expand Down Expand Up @@ -554,6 +555,35 @@ async def certificate_error(*args, **kwargs):
assert isinstance(ctx.value, aiohttp.ClientSSLError)


async def test_tcp_connector_server_hostname_default(loop: Any) -> None:
conn = aiohttp.TCPConnector()

with mock.patch.object(
conn._loop, "create_connection", autospec=True, spec_set=True
) as create_connection:
create_connection.return_value = mock.Mock(), mock.Mock()

req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop)

with closing(await conn.connect(req, [], ClientTimeout())):
assert create_connection.call_args.kwargs["server_hostname"] == "127.0.0.1"


async def test_tcp_connector_server_hostname_override(loop: Any) -> None:
conn = aiohttp.TCPConnector()

with mock.patch.object(
conn._loop, "create_connection", autospec=True, spec_set=True
) as create_connection:
create_connection.return_value = mock.Mock(), mock.Mock()

req = ClientRequest(
"GET", URL("https://127.0.0.1:443"), loop=loop, server_hostname="localhost"
)

with closing(await conn.connect(req, [], ClientTimeout())):
assert create_connection.call_args.kwargs["server_hostname"] == "localhost"

async def test_tcp_connector_multiple_hosts_errors(loop: Any) -> None:
conn = aiohttp.TCPConnector()

Expand Down
124 changes: 124 additions & 0 deletions tests/test_proxy.py
Expand Up @@ -188,6 +188,130 @@ async def make_conn():
connector.connect(req, None, aiohttp.ClientTimeout())
)

@mock.patch("aiohttp.connector.ClientRequest")
def test_proxy_server_hostname_default(self, ClientRequestMock) -> None:
proxy_req = ClientRequest(
"GET", URL("http://proxy.example.com"), loop=self.loop
)
ClientRequestMock.return_value = proxy_req

proxy_resp = ClientResponse(
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
continue100=None,
timer=TimerNoop(),
traces=[],
loop=self.loop,
session=mock.Mock(),
)
proxy_req.send = make_mocked_coro(proxy_resp)
proxy_resp.start = make_mocked_coro(mock.Mock(status=200))

async def make_conn():
return aiohttp.TCPConnector()

connector = self.loop.run_until_complete(make_conn())
connector._resolve_host = make_mocked_coro(
[
{
"hostname": "hostname",
"host": "127.0.0.1",
"port": 80,
"family": socket.AF_INET,
"proto": 0,
"flags": 0,
}
]
)

tr, proto = mock.Mock(), mock.Mock()
self.loop.create_connection = make_mocked_coro((tr, proto))
self.loop.start_tls = make_mocked_coro(mock.Mock())

req = ClientRequest(
"GET",
URL("https://www.python.org"),
proxy=URL("http://proxy.example.com"),
loop=self.loop,
)
self.loop.run_until_complete(
connector._create_connection(req, None, aiohttp.ClientTimeout())
)

self.assertEqual(
self.loop.start_tls.call_args.kwargs["server_hostname"], "www.python.org"
)

self.loop.run_until_complete(proxy_req.close())
proxy_resp.close()
self.loop.run_until_complete(req.close())

@mock.patch("aiohttp.connector.ClientRequest")
def test_proxy_server_hostname_override(self, ClientRequestMock) -> None:
proxy_req = ClientRequest(
"GET",
URL("http://proxy.example.com"),
loop=self.loop,
)
ClientRequestMock.return_value = proxy_req

proxy_resp = ClientResponse(
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
continue100=None,
timer=TimerNoop(),
traces=[],
loop=self.loop,
session=mock.Mock(),
)
proxy_req.send = make_mocked_coro(proxy_resp)
proxy_resp.start = make_mocked_coro(mock.Mock(status=200))

async def make_conn():
return aiohttp.TCPConnector()

connector = self.loop.run_until_complete(make_conn())
connector._resolve_host = make_mocked_coro(
[
{
"hostname": "hostname",
"host": "127.0.0.1",
"port": 80,
"family": socket.AF_INET,
"proto": 0,
"flags": 0,
}
]
)

tr, proto = mock.Mock(), mock.Mock()
self.loop.create_connection = make_mocked_coro((tr, proto))
self.loop.start_tls = make_mocked_coro(mock.Mock())

req = ClientRequest(
"GET",
URL("https://www.python.org"),
proxy=URL("http://proxy.example.com"),
server_hostname="server-hostname.example.com",
loop=self.loop,
)
self.loop.run_until_complete(
connector._create_connection(req, None, aiohttp.ClientTimeout())
)

self.assertEqual(
self.loop.start_tls.call_args.kwargs["server_hostname"],
"server-hostname.example.com",
)

self.loop.run_until_complete(proxy_req.close())
proxy_resp.close()
self.loop.run_until_complete(req.close())

@mock.patch("aiohttp.connector.ClientRequest")
def test_https_connect(self, ClientRequestMock: Any) -> None:
proxy_req = ClientRequest(
Expand Down

0 comments on commit 13c4465

Please sign in to comment.