Skip to content

Commit

Permalink
Replace h3 parameter with http_version. (#1068)
Browse files Browse the repository at this point in the history
This allows more flexibility; clients can specify which http version
they want, or use the default.
  • Loading branch information
bwelling committed Mar 21, 2024
1 parent 23a6cd2 commit 3238267
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 22 deletions.
14 changes: 9 additions & 5 deletions dns/asyncquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
BadResponse,
NoDOH,
NoDOQ,
HTTPVersion,
UDPMode,
_check_status,
_compute_times,
Expand Down Expand Up @@ -533,7 +534,7 @@ async def https(
bootstrap_address: Optional[str] = None,
resolver: Optional["dns.asyncresolver.Resolver"] = None,
family: int = socket.AF_UNSPEC,
h3: bool = False,
http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
Expand All @@ -559,7 +560,7 @@ async def https(
else:
url = where

if h3:
if http_version == HTTPVersion.H3 or (http_version == HTTPVersion.DEFAULT and not have_doh):
if bootstrap_address is None:
parsed = urllib.parse.urlparse(url)
resolver = _maybe_get_resolver(resolver)
Expand Down Expand Up @@ -595,6 +596,9 @@ async def https(
transport = None
headers = {"accept": "application/dns-message"}

h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)

backend = dns.asyncbackend.get_default_backend()

if source is None:
Expand All @@ -605,8 +609,8 @@ async def https(
local_port = source_port
transport = backend.get_transport_class()(
local_address=local_address,
http1=True,
http2=True,
http1=h1,
http2=h2,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
Expand All @@ -618,7 +622,7 @@ async def https(
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
else:
cm = httpx.AsyncClient(
http1=True, http2=True, verify=verify, transport=transport
http1=h1, http2=h2, verify=verify, transport=transport
)

async with cm as the_client:
Expand Down
8 changes: 4 additions & 4 deletions dns/nameserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,14 @@ def __init__(
bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True,
want_get: bool = False,
h3: bool = False,
http_version: dns.query.HTTPVersion = dns.query.HTTPVersion.DEFAULT,
):
super().__init__()
self.url = url
self.bootstrap_address = bootstrap_address
self.verify = verify
self.want_get = want_get
self.h3 = h3
self.http_version = http_version

def kind(self):
return "DoH"
Expand Down Expand Up @@ -216,7 +216,7 @@ def query(
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
h3=self.h3,
http_version=self.http_version,
)

async def async_query(
Expand All @@ -241,7 +241,7 @@ async def async_query(
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
h3=self.h3,
http_version=self.http_version,
)


Expand Down
31 changes: 25 additions & 6 deletions dns/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,22 @@ def _maybe_get_resolver(
return resolver


class HTTPVersion(enum.IntEnum):
"""Which version of HTTP should be used?
DEFAULT will select the first version from the list [2, 1.1, 3] that
is available.
"""

DEFAULT = 0
HTTP_1 = 1
H1 = 1
HTTP_2 = 2
H2 = 2
HTTP_3 = 3
H3 = 3


def https(
q: dns.message.Message,
where: str,
Expand All @@ -367,7 +383,7 @@ def https(
verify: Union[bool, str] = True,
resolver: Optional["dns.resolver.Resolver"] = None,
family: int = socket.AF_UNSPEC,
h3: bool = False,
http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
Expand Down Expand Up @@ -417,7 +433,7 @@ def https(
*family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
and AAAA records will be retrieved.
*h3*, a ``bool``. If ``True``, use HTTP/3 otherwise use HTTP/2 or HTTP/1.1.
*http_version*, a ``dns.query.HTTPVersion``, indicating which HTTP version to use.
Returns a ``dns.message.Message``.
"""
Expand All @@ -433,7 +449,7 @@ def https(
else:
url = where

if h3:
if http_version == HTTPVersion.H3 or (http_version == HTTPVersion.DEFAULT and not have_doh):
if bootstrap_address is None:
parsed = urllib.parse.urlparse(url)
resolver = _maybe_get_resolver(resolver)
Expand Down Expand Up @@ -469,6 +485,9 @@ def https(
transport = None
headers = {"accept": "application/dns-message"}

h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)

# set source port and source address

if the_source is None:
Expand All @@ -479,8 +498,8 @@ def https(
local_port = the_source[1]
transport = _HTTPTransport(
local_address=local_address,
http1=True,
http2=True,
http1=h1,
http2=h2,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
Expand All @@ -491,7 +510,7 @@ def https(
if session:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
else:
cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport)
with cm as session:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
Expand Down
6 changes: 3 additions & 3 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ async def run():
post=False,
timeout=4,
family=family,
h3=True,
http_version=dns.asyncquery.HTTPVersion.H3,
)
self.assertTrue(q.is_response(r))

Expand All @@ -587,7 +587,7 @@ async def run():
post=True,
timeout=4,
family=family,
h3=True,
http_version=dns.asyncquery.HTTPVersion.H3,
)
self.assertTrue(q.is_response(r))

Expand All @@ -603,7 +603,7 @@ async def run():
nameserver_ip,
post=False,
timeout=4,
h3=True,
http_version=dns.asyncquery.HTTPVersion.H3,
)
self.assertTrue(q.is_response(r))

Expand Down
8 changes: 4 additions & 4 deletions tests/test_doh.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def testDoH3GetRequest(self):
post=False,
timeout=4,
family=family,
h3=True,
http_version=dns.query.HTTPVersion.H3,
)
self.assertTrue(q.is_response(r))

Expand All @@ -216,7 +216,7 @@ def testDoH3PostRequest(self):
post=True,
timeout=4,
family=family,
h3=True,
http_version=dns.query.HTTPVersion.H3,
)
self.assertTrue(q.is_response(r))

Expand All @@ -233,7 +233,7 @@ def test_build_url_from_ip(self):
nameserver_ip,
post=False,
timeout=4,
h3=True,
http_version=dns.query.HTTPVersion.H3,
)
self.assertTrue(q.is_response(r))
if resolver_v6_addresses:
Expand All @@ -244,7 +244,7 @@ def test_build_url_from_ip(self):
nameserver_ip,
post=False,
timeout=4,
h3=True,
http_version=dns.query.HTTPVersion.H3,
)
self.assertTrue(q.is_response(r))

Expand Down

0 comments on commit 3238267

Please sign in to comment.