Skip to content

Commit

Permalink
Changes after @Dreamsorcerer's review
Browse files Browse the repository at this point in the history
  • Loading branch information
multani committed Aug 19, 2023
1 parent 4204af7 commit c21592f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 30 deletions.
10 changes: 2 additions & 8 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,18 +1085,14 @@ async def _start_tls_connection(
# maintainability-wise but this is to be solved separately.
sslcontext = cast(ssl.SSLContext, self._get_ssl_context(req))

server_hostname = req.host
if req.server_hostname is not None:
server_hostname = req.server_hostname

try:
async with ceil_timeout(timeout.sock_connect):
try:
tls_transport = await self._loop.start_tls(
underlying_transport,
tls_proto,
sslcontext,
server_hostname=server_hostname,
server_hostname=req.server_hostname or req.host,
ssl_handshake_timeout=timeout.total,
)
except BaseException:
Expand Down Expand Up @@ -1178,9 +1174,7 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
host = hinfo["host"]
port = hinfo["port"]

server_hostname = hinfo["hostname"] if sslcontext else None
if req.server_hostname is not None:
server_hostname = req.server_hostname
server_hostname = (req.server_hostname or hinfo["hostname"]) if sslcontext else None

try:
transp, proto = await self._wrap_create_connection(
Expand Down
4 changes: 2 additions & 2 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,8 @@ The client session supports the context manager protocol for self closing.

:param str server_hostname: Sets or overrides the hostname that the
target server’s certificate will be matched against.
See `upstream documentation <https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_connection>`_
for more information.

See :method:`asyncio.loop.create_connection` for more information.

.. versionadded:: 3.9

Expand Down
37 changes: 17 additions & 20 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sys
import uuid
from collections import deque
from contextlib import closing
from unittest import mock

import pytest
Expand Down Expand Up @@ -560,35 +561,31 @@ async def certificate_error(*args, **kwargs):
async def test_tcp_connector_server_hostname_default(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)

conn._loop.create_connection = mock.AsyncMock()
conn._loop.create_connection.return_value = mock.Mock(), mock.Mock()
with mock.patch.object(
conn._loop, "create_connection", autospec=True, spec_set=True
) as create_connection:
create_connection.return_value = mock.Mock(), mock.Mock()

req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop)

established_connection = await conn.connect(req, [], ClientTimeout())
assert (
conn._loop.create_connection.call_args.kwargs["server_hostname"] == "127.0.0.1"
)
req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop)

established_connection.close()
with closing(await conn.connect(req, [], ClientTimeout())):
assert create_connection.call_args.kwargs["server_hostname"] == "127.0.0.1"


async def test_tcp_connector_server_hostname_override(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)

conn._loop.create_connection = mock.AsyncMock()
conn._loop.create_connection.return_value = mock.Mock(), mock.Mock()
with mock.patch.object(
conn._loop, "create_connection", autospec=True, spec_set=True
) as create_connection:
create_connection.return_value = mock.Mock(), mock.Mock()

req = ClientRequest(
"GET", URL("https://127.0.0.1:443"), loop=loop, server_hostname="localhost"
)

established_connection = await conn.connect(req, [], ClientTimeout())
assert (
conn._loop.create_connection.call_args.kwargs["server_hostname"] == "localhost"
)
req = ClientRequest(
"GET", URL("https://127.0.0.1:443"), loop=loop, server_hostname="localhost"
)

established_connection.close()
with closing(await conn.connect(req, [], ClientTimeout())):
assert create_connection.call_args.kwargs["server_hostname"] == "localhost"


async def test_tcp_connector_multiple_hosts_errors(loop) -> None:
Expand Down

0 comments on commit c21592f

Please sign in to comment.