Skip to content

Commit

Permalink
try resolver (#245)
Browse files Browse the repository at this point in the history
* try resolver

* some fixes

* some docs

* more coverage

* mocking dns for test
  • Loading branch information
sonic182 committed Apr 28, 2021
1 parent 2af1325 commit aafe467
Show file tree
Hide file tree
Showing 9 changed files with 326 additions and 59 deletions.
6 changes: 3 additions & 3 deletions aiosonic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ async def _do_request(urlparsed: ParseResult,
timeouts: Optional[Timeouts],
http2: bool = False) -> HttpResponse:
"""Something."""
async with (await connector.acquire(urlparsed)) as connection:
timeouts = timeouts or connector.timeouts
await connection.connect(urlparsed, verify, ssl, timeouts, http2)
timeouts = timeouts or connector.timeouts
args = urlparsed, verify, ssl, timeouts, http2
async with (await connector.acquire(*args)) as connection:
to_send = headers_data(connection=connection)

if connection.h2conn:
Expand Down
26 changes: 13 additions & 13 deletions aiosonic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import ssl
from ssl import SSLContext
from asyncio import wait_for, open_connection
from asyncio import open_connection
from asyncio import StreamReader
from asyncio import StreamWriter
from typing import Dict
Expand All @@ -13,14 +13,11 @@
import h2.events

# from concurrent import futures (unused)
from aiosonic.exceptions import ConnectTimeout
from aiosonic.exceptions import HttpParsingError
from aiosonic.exceptions import TimeoutException
from aiosonic.timeout import Timeouts
from aiosonic.connectors import TCPConnector
from aiosonic.http2 import Http2Handler

from aiosonic.types import ParamsType
from aiosonic.types import ParsedBodyType


Expand All @@ -41,20 +38,17 @@ def __init__(self, connector: TCPConnector) -> None:

async def connect(self,
urlparsed: ParseResult,
dns_info: dict,
verify: bool,
ssl_context: SSLContext,
timeouts: Timeouts,
http2: bool = False) -> None:
"""Connet with timeout."""
try:
await wait_for(self._connect(
urlparsed, verify, ssl_context, http2
), timeout=timeouts.sock_connect)
except TimeoutException:
raise ConnectTimeout()
await self._connect(
urlparsed, verify, ssl_context, dns_info, http2
)

async def _connect(self, urlparsed: ParseResult, verify: bool,
ssl_context: SSLContext, http2: bool) -> None:
ssl_context: SSLContext, dns_info, http2: bool) -> None:
"""Get reader and writer."""
if not urlparsed.hostname:
raise HttpParsingError('missing hostname')
Expand All @@ -71,6 +65,9 @@ async def _connect(self, urlparsed: ParseResult, verify: bool,
def is_closing():
return True # noqa

dns_info_copy = dns_info.copy()
dns_info_copy['server_hostname'] = dns_info_copy.pop('hostname')

if not (self.key and key == self.key and not is_closing()):
self.close()

Expand All @@ -83,10 +80,13 @@ def is_closing():
if not verify:
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
else:
del dns_info_copy['server_hostname']
port = urlparsed.port or (443
if urlparsed.scheme == 'https' else 80)
dns_info_copy['port'] = port
self.reader, self.writer = await open_connection(
urlparsed.hostname, port, ssl=ssl_context)
**dns_info_copy, ssl=ssl_context)

self.temp_key = key
await self._connection_made()
Expand Down
55 changes: 47 additions & 8 deletions aiosonic/connectors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Connector stuffs."""

from asyncio import wait_for
import random
from asyncio import sleep as asyncio_sleep
from asyncio import wait_for
from ssl import SSLContext
from typing import Coroutine
from urllib.parse import ParseResult
Expand All @@ -10,10 +10,12 @@
from hyperframe.frame import SettingsFrame

# from concurrent import futures (unused)
from aiosonic.exceptions import ConnectionPoolAcquireTimeout
from aiosonic.exceptions import TimeoutException
from aiosonic.exceptions import (ConnectionPoolAcquireTimeout, ConnectTimeout,
HttpParsingError, TimeoutException)
from aiosonic.pools import SmartPool
from aiosonic.resolver import DefaultResolver
from aiosonic.timeout import Timeouts
from aiosonic.utils import ExpirableCache


class TCPConnector:
Expand All @@ -26,33 +28,62 @@ class TCPConnector:
* **timeouts**: global timeouts to use for connections with this connector. default: :class:`aiosonic.timeout.Timeouts` instance with default args.
* **connection_cls**: connection class to be used. default: :class:`aiosonic.connection.Connection`
* **pool_cls**: pool class to be used. default: :class:`aiosonic.pools.SmartPool`
* **resolver**: resolver to be used. default: :class:`aiosonic.resolver.DefaultResolver`
* **ttl_dns_cache**: ttl in milliseconds for dns cache. default: `10000` 10 seconds
* **use_dns_cache**: Flag to indicate usage of dns cache. default: `True`
"""

def __init__(self,
pool_size: int = 25,
timeouts: Timeouts = None,
connection_cls=None,
pool_cls=None):
pool_cls=None,
resolver=None,
ttl_dns_cache=10000,
use_dns_cache=True):
from aiosonic.connection import Connection # avoid circular dependency
self.pool_size = pool_size
connection_cls = connection_cls or Connection
pool_cls = pool_cls or SmartPool
self.pool = pool_cls(self, pool_size, connection_cls)
self.timeouts = timeouts or Timeouts()
self.resolver = resolver or DefaultResolver()
self.use_dns_cache = use_dns_cache
if self.use_dns_cache:
self.cache = ExpirableCache(512, ttl_dns_cache)

async def acquire(self, urlparsed: ParseResult):
async def acquire(self, urlparsed: ParseResult, verify, ssl, timeouts, http2):
"""Acquire connection."""
if not urlparsed.hostname:
raise HttpParsingError('missing hostname')

# Faster without timeout
if not self.timeouts.pool_acquire:
return await self.pool.acquire(urlparsed)
conn = await self.pool.acquire(urlparsed)
return await self.after_acquire(
urlparsed, conn, verify, ssl, timeouts, http2)

try:
return await wait_for(self.pool.acquire(urlparsed),
conn = await wait_for(self.pool.acquire(urlparsed),
self.timeouts.pool_acquire)
return await self.after_acquire(
urlparsed, conn, verify, ssl, timeouts, http2)
except TimeoutException:
raise ConnectionPoolAcquireTimeout()

async def after_acquire(self, urlparsed, conn, verify, ssl, timeouts, http2):
dns_info = await self.__resolve_dns(
urlparsed.hostname, urlparsed.port)

try:
await wait_for(conn.connect(
urlparsed, dns_info, verify, ssl, http2
), timeout=timeouts.sock_connect)
except TimeoutException:
raise ConnectTimeout()
return conn

async def release(self, conn):
"""Release connection."""
res = self.pool.release(conn)
Expand All @@ -69,3 +100,11 @@ async def wait_free_pool(self):
async def cleanup(self):
"""Cleanup connector connections."""
await self.pool.cleanup()

async def __resolve_dns(self, host: str, port: int):
key = f'{host}-{port}'
dns_data = self.cache.get(key)
if not dns_data:
dns_data = await self.resolver.resolve(host, port)
self.cache.set(key, dns_data)
return random.choice(dns_data)
129 changes: 129 additions & 0 deletions aiosonic/resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# copied from aiohttp

import asyncio
import socket
import sys
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type, Union

__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")

try:
import aiodns

# aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname')
except ImportError: # pragma: no cover
aiodns = None

aiodns_default = False


def get_loop():
if sys.version_info >= (3, 7):
return asyncio.get_running_loop()
else:
return asyncio.get_event_loop()


class AbstractResolver(ABC):
"""Abstract DNS resolver."""

@abstractmethod
async def resolve(self, host: str, port: int, family: int) -> List[Dict[str, Any]]:
"""Return IP address for given hostname"""

@abstractmethod
async def close(self) -> None:
"""Release resolver"""


class ThreadedResolver(AbstractResolver):
"""Use Executor for synchronous getaddrinfo() calls, which defaults to
concurrent.futures.ThreadPoolExecutor.
"""

def __init__(self) -> None:
self._loop = get_loop()

async def resolve(
self, hostname: str, port: int = 0, family: int = socket.AF_INET
) -> List[Dict[str, Any]]:
infos = await self._loop.getaddrinfo(
hostname,
port,
type=socket.SOCK_STREAM,
family=family,
flags=socket.AI_ADDRCONFIG,
)

hosts = []
for family, _, proto, _, address in infos:
if family == socket.AF_INET6 and address[3]: # type: ignore[misc]
# This is essential for link-local IPv6 addresses.
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
# getnameinfo() unconditionally, but performance makes sense.
host, _port = socket.getnameinfo(
address, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
)
port = int(_port)
else:
host, port = address[:2]
hosts.append(
{
"hostname": hostname,
"host": host,
"port": port,
"family": family,
"proto": proto,
"flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
}
)

return hosts

async def close(self) -> None:
pass


class AsyncResolver(AbstractResolver):
"""Use the `aiodns` package to make asynchronous DNS lookups"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
if aiodns is None:
raise RuntimeError("Resolver requires aiodns library")

self._loop = get_loop()
self._resolver = aiodns.DNSResolver(*args, loop=self._loop, **kwargs)

async def resolve(
self, host: str, port: int = 0, family: int = socket.AF_INET
) -> List[Dict[str, Any]]:
try:
resp = await self._resolver.gethostbyname(host, family)
except aiodns.error.DNSError as exc:
msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
raise OSError(msg) from exc
hosts = []
for address in resp.addresses:
hosts.append(
{
"hostname": host,
"host": address,
"port": port,
"family": family,
"proto": 0,
"flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
}
)

if not hosts:
raise OSError("DNS lookup failed")

return hosts

async def close(self) -> None:
self._resolver.cancel()


_DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]
DefaultResolver: _DefaultType = AsyncResolver if aiodns_default else ThreadedResolver

0 comments on commit aafe467

Please sign in to comment.