Skip to content

Commit

Permalink
Avoid creating a task to do DNS resolution if there is no throttle (#…
Browse files Browse the repository at this point in the history
…8163)

Co-authored-by: Sviatoslav Sydorenko (Святослав Сидоренко) <sviat@redhat.com>
  • Loading branch information
bdraco and webknjaz committed Feb 20, 2024
1 parent 895fd00 commit 006fbe0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 14 deletions.
5 changes: 5 additions & 0 deletions CHANGES/8163.bugfix.rst
@@ -0,0 +1,5 @@
Improved the DNS resolution performance on cache hit
-- by :user:`bdraco`.

This is achieved by avoiding an :mod:`asyncio` task creation
in this case.
50 changes: 36 additions & 14 deletions aiohttp/connector.py
Expand Up @@ -814,6 +814,7 @@ def clear_dns_cache(
async def _resolve_host(
self, host: str, port: int, traces: Optional[List["Trace"]] = None
) -> List[Dict[str, Any]]:
"""Resolve host and return list of addresses."""
if is_ip_address(host):
return [
{
Expand All @@ -840,8 +841,7 @@ async def _resolve_host(
return res

key = (host, port)

if (key in self._cached_hosts) and (not self._cached_hosts.expired(key)):
if key in self._cached_hosts and not self._cached_hosts.expired(key):
# get result early, before any await (#4014)
result = self._cached_hosts.next_addrs(key)

Expand All @@ -850,6 +850,39 @@ async def _resolve_host(
await trace.send_dns_cache_hit(host)
return result

#
# If multiple connectors are resolving the same host, we wait
# for the first one to resolve and then use the result for all of them.
# We use a throttle event to ensure that we only resolve the host once
# and then use the result for all the waiters.
#
# In this case we need to create a task to ensure that we can shield
# the task from cancellation as cancelling this lookup should not cancel
# the underlying lookup or else the cancel event will get broadcast to
# all the waiters across all connections.
#
resolved_host_task = asyncio.create_task(
self._resolve_host_with_throttle(key, host, port, traces)
)
try:
return await asyncio.shield(resolved_host_task)
except asyncio.CancelledError:

def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
with suppress(Exception, asyncio.CancelledError):
fut.result()

resolved_host_task.add_done_callback(drop_exception)
raise

async def _resolve_host_with_throttle(
self,
key: Tuple[str, int],
host: str,
port: int,
traces: Optional[List["Trace"]],
) -> List[Dict[str, Any]]:
"""Resolve host with a dns events throttle."""
if key in self._throttle_dns_events:
# get event early, before any await (#4014)
event = self._throttle_dns_events[key]
Expand Down Expand Up @@ -1136,22 +1169,11 @@ async def _create_direct_connection(
host = host.rstrip(".") + "."
port = req.port
assert port is not None
host_resolved = asyncio.ensure_future(
self._resolve_host(host, port, traces=traces), loop=self._loop
)
try:
# Cancelling this lookup should not cancel the underlying lookup
# or else the cancel event will get broadcast to all the waiters
# across all connections.
hosts = await asyncio.shield(host_resolved)
except asyncio.CancelledError:

def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
with suppress(Exception, asyncio.CancelledError):
fut.result()

host_resolved.add_done_callback(drop_exception)
raise
hosts = await self._resolve_host(host, port, traces=traces)
except OSError as exc:
if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
raise
Expand Down
6 changes: 6 additions & 0 deletions tests/test_connector.py
Expand Up @@ -1021,6 +1021,7 @@ async def test_tcp_connector_dns_throttle_requests(
loop.create_task(conn._resolve_host("localhost", 8080))
loop.create_task(conn._resolve_host("localhost", 8080))
await asyncio.sleep(0)
await asyncio.sleep(0)
m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0)


Expand All @@ -1032,6 +1033,9 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread(loop: Any) -
r1 = loop.create_task(conn._resolve_host("localhost", 8080))
r2 = loop.create_task(conn._resolve_host("localhost", 8080))
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
assert r1.exception() == e
assert r2.exception() == e

Expand All @@ -1045,6 +1049,7 @@ async def test_tcp_connector_dns_throttle_requests_cancelled_when_close(
loop.create_task(conn._resolve_host("localhost", 8080))
f = loop.create_task(conn._resolve_host("localhost", 8080))

await asyncio.sleep(0)
await asyncio.sleep(0)
await conn.close()

Expand Down Expand Up @@ -1212,6 +1217,7 @@ async def test_tcp_connector_dns_tracing_throttle_requests(
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
await asyncio.sleep(0)
await asyncio.sleep(0)
on_dns_cache_hit.assert_called_once_with(
session, trace_config_ctx, aiohttp.TraceDnsCacheHitParams("localhost")
)
Expand Down

0 comments on commit 006fbe0

Please sign in to comment.