Skip to content

Commit

Permalink
Support passing a custom server name parameter on HTTPS connection (a…
Browse files Browse the repository at this point in the history
…io-libs#7541)

This adds the missing support to set the `server_hostname` setting when
creating TCP connection, when the underlying connection is authenticated
using TLS.

See the documentation for the 2 stdlib functions:

*
https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_connection
*
https://docs.python.org/3/library/asyncio-eventloop.html#opening-network-connections

This would be needed to support features in clients using aiohttp, such
as tomplus/kubernetes_asyncio#267

The default behavior should not change, but this would allow on a
per-connection basis to specify a custom server name to check the
certificate name against.

Closes: aio-libs#7114

(for reference, similar implementation in urllib3:
urllib3/urllib3#1397)

- [x] I think the code is well written
- [x] Unit tests for the changes exist
- [x] Documentation reflects the changes
- [x] If you provide code modification, please add yourself to
`CONTRIBUTORS.txt`
  * The format is <Name> <Surname>.
  * Please keep alphabetical order, the file is sorted by names.
- [x] Add a new news fragment into the `CHANGES` folder
  * name it `<issue_id>.<type>` for example (588.bugfix)
* if you don't have an `issue_id` change it to the pr id after creating
the pr
  * ensure type is one of the following:
    * `.feature`: Signifying a new feature.
    * `.bugfix`: Signifying a bug fix.
    * `.doc`: Signifying a documentation improvement.
    * `.removal`: Signifying a deprecation or removal of public API.
* `.misc`: A ticket has been closed, but it is not of interest to users.
* Make sure to use full sentences with correct case and punctuation, for
example: "Fix issue with non-ascii contents in doctest text files."

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sam Bull <aa6bs0@sambull.org>
(cherry picked from commit ac29dea)
  • Loading branch information
multani committed Aug 20, 2023
1 parent bdeca03 commit 7d7aef1
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGES/7114.feature
@@ -0,0 +1 @@
Support passing a custom server name parameter to HTTPS connection
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Expand Up @@ -172,6 +172,7 @@ Joel Watts
Jon Nabozny
Jonas Krüger Svensson
Jonas Obrist
Jonathan Ballet
Jonathan Wright
Jonny Tan
Joongi Kim
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/client.py
Expand Up @@ -398,6 +398,7 @@ async def _request(
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None,
server_hostname: Optional[str] = None,
proxy_headers: Optional[LooseHeaders] = None,
trace_request_ctx: Optional[SimpleNamespace] = None,
read_bufsize: Optional[int] = None,
Expand Down Expand Up @@ -551,6 +552,7 @@ async def _request(
timer=timer,
session=self,
ssl=ssl,
server_hostname=server_hostname,
proxy_headers=proxy_headers,
traces=traces,
trust_env=self.trust_env,
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/client_reqrep.py
Expand Up @@ -277,6 +277,7 @@ def __init__(
proxy_headers: Optional[LooseHeaders] = None,
traces: Optional[List["Trace"]] = None,
trust_env: bool = False,
server_hostname: Optional[str] = None,
):

if loop is None:
Expand Down Expand Up @@ -306,6 +307,7 @@ def __init__(
self.response_class: Type[ClientResponse] = real_response_class
self._timer = timer if timer is not None else TimerNoop()
self._ssl = ssl
self.server_hostname = server_hostname

if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
Expand Down
8 changes: 6 additions & 2 deletions aiohttp/connector.py
Expand Up @@ -1101,7 +1101,7 @@ async def _start_tls_connection(
underlying_transport,
tls_proto,
sslcontext,
server_hostname=req.host,
server_hostname=req.server_hostname or req.host,
ssl_handshake_timeout=timeout.total,
)
except BaseException:
Expand Down Expand Up @@ -1183,6 +1183,10 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
host = hinfo["host"]
port = hinfo["port"]

server_hostname = (
(req.server_hostname or hinfo["hostname"]) if sslcontext else None
)

try:
transp, proto = await self._wrap_create_connection(
self._factory,
Expand All @@ -1193,7 +1197,7 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
family=hinfo["family"],
proto=hinfo["proto"],
flags=hinfo["flags"],
server_hostname=hinfo["hostname"] if sslcontext else None,
server_hostname=server_hostname,
local_addr=self._local_addr,
req=req,
client_error=client_error,
Expand Down
9 changes: 8 additions & 1 deletion docs/client_reference.rst
Expand Up @@ -365,7 +365,7 @@ The client session supports the context manager protocol for self closing.
timeout=sentinel, ssl=None, \
verify_ssl=None, fingerprint=None, \
ssl_context=None, proxy_headers=None, \
auto_decompress=None)
server_hostname=None, auto_decompress=None)
:async:
:noindexentry:

Expand Down Expand Up @@ -515,6 +515,13 @@ The client session supports the context manager protocol for self closing.

Use ``ssl=aiohttp.Fingerprint(digest)``

:param str server_hostname: Sets or overrides the host name that the
target server’s certificate will be matched against.

See :py:meth:`asyncio.loop.create_connection` for more information.

.. versionadded:: 3.9

:param ssl.SSLContext ssl_context: ssl context used for processing
*HTTPS* requests (optional).

Expand Down
31 changes: 31 additions & 0 deletions tests/test_connector.py
Expand Up @@ -9,6 +9,7 @@
import sys
import uuid
from collections import deque
from contextlib import closing
from unittest import mock

import pytest
Expand Down Expand Up @@ -555,6 +556,36 @@ async def certificate_error(*args, **kwargs):
assert isinstance(ctx.value, aiohttp.ClientSSLError)


async def test_tcp_connector_server_hostname_default(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)

with mock.patch.object(
conn._loop, "create_connection", autospec=True, spec_set=True
) as create_connection:
create_connection.return_value = mock.Mock(), mock.Mock()

req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop)

with closing(await conn.connect(req, [], ClientTimeout())):
assert create_connection.call_args.kwargs["server_hostname"] == "127.0.0.1"


async def test_tcp_connector_server_hostname_override(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)

with mock.patch.object(
conn._loop, "create_connection", autospec=True, spec_set=True
) as create_connection:
create_connection.return_value = mock.Mock(), mock.Mock()

req = ClientRequest(
"GET", URL("https://127.0.0.1:443"), loop=loop, server_hostname="localhost"
)

with closing(await conn.connect(req, [], ClientTimeout())):
assert create_connection.call_args.kwargs["server_hostname"] == "localhost"


async def test_tcp_connector_multiple_hosts_errors(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)

Expand Down
124 changes: 124 additions & 0 deletions tests/test_proxy.py
Expand Up @@ -191,6 +191,130 @@ async def make_conn():
connector.connect(req, None, aiohttp.ClientTimeout())
)

@mock.patch("aiohttp.connector.ClientRequest")
def test_proxy_server_hostname_default(self, ClientRequestMock) -> None:
proxy_req = ClientRequest(
"GET", URL("http://proxy.example.com"), loop=self.loop
)
ClientRequestMock.return_value = proxy_req

proxy_resp = ClientResponse(
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
continue100=None,
timer=TimerNoop(),
traces=[],
loop=self.loop,
session=mock.Mock(),
)
proxy_req.send = make_mocked_coro(proxy_resp)
proxy_resp.start = make_mocked_coro(mock.Mock(status=200))

async def make_conn():
return aiohttp.TCPConnector()

connector = self.loop.run_until_complete(make_conn())
connector._resolve_host = make_mocked_coro(
[
{
"hostname": "hostname",
"host": "127.0.0.1",
"port": 80,
"family": socket.AF_INET,
"proto": 0,
"flags": 0,
}
]
)

tr, proto = mock.Mock(), mock.Mock()
self.loop.create_connection = make_mocked_coro((tr, proto))
self.loop.start_tls = make_mocked_coro(mock.Mock())

req = ClientRequest(
"GET",
URL("https://www.python.org"),
proxy=URL("http://proxy.example.com"),
loop=self.loop,
)
self.loop.run_until_complete(
connector._create_connection(req, None, aiohttp.ClientTimeout())
)

self.assertEqual(
self.loop.start_tls.call_args.kwargs["server_hostname"], "www.python.org"
)

self.loop.run_until_complete(proxy_req.close())
proxy_resp.close()
self.loop.run_until_complete(req.close())

@mock.patch("aiohttp.connector.ClientRequest")
def test_proxy_server_hostname_override(self, ClientRequestMock) -> None:
proxy_req = ClientRequest(
"GET",
URL("http://proxy.example.com"),
loop=self.loop,
)
ClientRequestMock.return_value = proxy_req

proxy_resp = ClientResponse(
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
continue100=None,
timer=TimerNoop(),
traces=[],
loop=self.loop,
session=mock.Mock(),
)
proxy_req.send = make_mocked_coro(proxy_resp)
proxy_resp.start = make_mocked_coro(mock.Mock(status=200))

async def make_conn():
return aiohttp.TCPConnector()

connector = self.loop.run_until_complete(make_conn())
connector._resolve_host = make_mocked_coro(
[
{
"hostname": "hostname",
"host": "127.0.0.1",
"port": 80,
"family": socket.AF_INET,
"proto": 0,
"flags": 0,
}
]
)

tr, proto = mock.Mock(), mock.Mock()
self.loop.create_connection = make_mocked_coro((tr, proto))
self.loop.start_tls = make_mocked_coro(mock.Mock())

req = ClientRequest(
"GET",
URL("https://www.python.org"),
proxy=URL("http://proxy.example.com"),
server_hostname="server-hostname.example.com",
loop=self.loop,
)
self.loop.run_until_complete(
connector._create_connection(req, None, aiohttp.ClientTimeout())
)

self.assertEqual(
self.loop.start_tls.call_args.kwargs["server_hostname"],
"server-hostname.example.com",
)

self.loop.run_until_complete(proxy_req.close())
proxy_resp.close()
self.loop.run_until_complete(req.close())

@mock.patch("aiohttp.connector.ClientRequest")
def test_https_connect(self, ClientRequestMock) -> None:
proxy_req = ClientRequest(
Expand Down

0 comments on commit 7d7aef1

Please sign in to comment.