Skip to content

Commit

Permalink
support for custom DNS resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
irrelevelephant committed Oct 14, 2022
1 parent 110c3ae commit bae9f42
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 274 deletions.
41 changes: 28 additions & 13 deletions src/async_impl/client.rs
Expand Up @@ -13,7 +13,7 @@ use http::header::{
};
use http::uri::Scheme;
use http::Uri;
use hyper::client::ResponseFuture;
use hyper::client::{HttpConnector, ResponseFuture};
#[cfg(feature = "native-tls-crate")]
use native_tls_crate::TlsConnector;
use pin_project_lite::pin_project;
Expand All @@ -28,7 +28,7 @@ use super::decoder::Accepts;
use super::request::{Request, RequestBuilder};
use super::response::Response;
use super::Body;
use crate::connect::{Connector, HttpConnector};
use crate::connect::Connector;
#[cfg(feature = "cookies")]
use crate::cookie;
use crate::error;
Expand All @@ -41,6 +41,9 @@ use crate::Certificate;
#[cfg(any(feature = "native-tls", feature = "__rustls"))]
use crate::Identity;
use crate::{IntoUrl, Method, Proxy, StatusCode, Url};
use crate::dns::{Resolve, DynResolver, DnsResolverWithOverrides, gai::GaiResolver};
#[cfg(feature = "trust-dns")]
use crate::dns::trust_dns::TrustDnsResolver;

/// An asynchronous `Client` to make Requests with.
///
Expand Down Expand Up @@ -121,6 +124,7 @@ struct Config {
error: Option<crate::Error>,
https_only: bool,
dns_overrides: HashMap<String, Vec<SocketAddr>>,
dns_resolver: Option<Arc<dyn Resolve>>,
}

impl Default for ClientBuilder {
Expand Down Expand Up @@ -188,6 +192,7 @@ impl ClientBuilder {
cookie_store: None,
https_only: false,
dns_overrides: HashMap::new(),
dns_resolver: None,
},
}
}
Expand Down Expand Up @@ -217,25 +222,25 @@ impl ClientBuilder {
headers.get(USER_AGENT).cloned()
}

let http = match config.trust_dns {
let mut resolver: Arc<dyn Resolve> = match config.trust_dns {
false => {
if config.dns_overrides.is_empty() {
HttpConnector::new_gai()
} else {
HttpConnector::new_gai_with_overrides(config.dns_overrides)
}
Arc::new(GaiResolver::new())
}
#[cfg(feature = "trust-dns")]
true => {
if config.dns_overrides.is_empty() {
HttpConnector::new_trust_dns()?
} else {
HttpConnector::new_trust_dns_with_overrides(config.dns_overrides)?
}
Arc::new(TrustDnsResolver::new()
.map_err(crate::error::builder)?)
}
#[cfg(not(feature = "trust-dns"))]
true => unreachable!("trust-dns shouldn't be enabled unless the feature is"),
};
if let Some(dns_resolver) = config.dns_resolver {
resolver = dns_resolver;
}
if !config.dns_overrides.is_empty() {
resolver = Arc::new(DnsResolverWithOverrides::new(resolver, config.dns_overrides));
}
let http = HttpConnector::new_with_resolver(DynResolver::new(resolver));

#[cfg(feature = "__tls")]
match config.tls {
Expand Down Expand Up @@ -1340,6 +1345,16 @@ impl ClientBuilder {
.insert(domain.to_string(), addrs.to_vec());
self
}

/// Override the DNS resolver implementation.
///
/// Pass an `Arc` wrapping a trait object implementing `Resolve`.
/// Overrides for specific names passed to `resolve` and `resolve_to_addrs` will
/// still be applied on top of this resolver.
pub fn dns_resolver(mut self, resolver: Arc<dyn Resolve>) -> ClientBuilder {
self.config.dns_resolver = Some(resolver);
self
}
}

type HyperClient = hyper::Client<Connector, super::body::ImplStream>;
Expand Down
234 changes: 4 additions & 230 deletions src/connect.rs
@@ -1,10 +1,8 @@
use futures_util::future::Either;
#[cfg(feature = "__tls")]
use http::header::HeaderValue;
use http::uri::{Authority, Scheme};
use http::Uri;
use hyper::client::connect::{
dns::{GaiResolver, Name},
Connected, Connection,
};
use hyper::service::Service;
Expand All @@ -13,150 +11,23 @@ use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use pin_project_lite::pin_project;
use std::io::IoSlice;
use std::io::{self, IoSlice};
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{collections::HashMap, io};
use std::{future::Future, net::SocketAddr};
use std::future::Future;

#[cfg(feature = "default-tls")]
use self::native_tls_conn::NativeTlsConn;
#[cfg(feature = "__rustls")]
use self::rustls_tls_conn::RustlsTlsConn;
#[cfg(feature = "trust-dns")]
use crate::dns::TrustDnsResolver;
use crate::dns::DynResolver;
use crate::error::BoxError;
use crate::proxy::{Proxy, ProxyScheme};

#[derive(Clone)]
pub(crate) enum HttpConnector {
Gai(hyper::client::HttpConnector),
GaiWithDnsOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>>),
#[cfg(feature = "trust-dns")]
TrustDns(hyper::client::HttpConnector<TrustDnsResolver>),
#[cfg(feature = "trust-dns")]
TrustDnsWithOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>>),
}

impl HttpConnector {
pub(crate) fn new_gai() -> Self {
Self::Gai(hyper::client::HttpConnector::new())
}

pub(crate) fn new_gai_with_overrides(overrides: HashMap<String, Vec<SocketAddr>>) -> Self {
let gai = hyper::client::connect::dns::GaiResolver::new();
let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides);
Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver(
overridden_resolver,
))
}

#[cfg(feature = "trust-dns")]
pub(crate) fn new_trust_dns() -> crate::Result<HttpConnector> {
TrustDnsResolver::new()
.map(hyper::client::HttpConnector::new_with_resolver)
.map(Self::TrustDns)
.map_err(crate::error::builder)
}

#[cfg(feature = "trust-dns")]
pub(crate) fn new_trust_dns_with_overrides(
overrides: HashMap<String, Vec<SocketAddr>>,
) -> crate::Result<HttpConnector> {
TrustDnsResolver::new()
.map(|resolver| DnsResolverWithOverrides::new(resolver, overrides))
.map(hyper::client::HttpConnector::new_with_resolver)
.map(Self::TrustDnsWithOverrides)
.map_err(crate::error::builder)
}
}

macro_rules! impl_http_connector {
($(fn $name:ident(&mut self, $($par_name:ident: $par_type:ty),*)$( -> $return:ty)?;)+) => {
#[allow(dead_code)]
impl HttpConnector {
$(
fn $name(&mut self, $($par_name: $par_type),*)$( -> $return)? {
match self {
Self::Gai(resolver) => resolver.$name($($par_name),*),
Self::GaiWithDnsOverrides(resolver) => resolver.$name($($par_name),*),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => resolver.$name($($par_name),*),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => resolver.$name($($par_name),*),
}
}
)+
}
};
}

impl_http_connector! {
fn set_local_address(&mut self, addr: Option<IpAddr>);
fn enforce_http(&mut self, is_enforced: bool);
fn set_nodelay(&mut self, nodelay: bool);
fn set_keepalive(&mut self, dur: Option<Duration>);
}

impl Service<Uri> for HttpConnector {
type Response = <hyper::client::HttpConnector as Service<Uri>>::Response;
type Error = <hyper::client::HttpConnector as Service<Uri>>::Error;
#[cfg(feature = "trust-dns")]
type Future =
Either<
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
Uri,
>>::Future,
>,
Either<
<hyper::client::HttpConnector<TrustDnsResolver> as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>> as Service<Uri>>::Future
>
>;
#[cfg(not(feature = "trust-dns"))]
type Future =
Either<
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
Uri,
>>::Future,
>,
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector as Service<Uri>>::Future,
>,
>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self {
Self::Gai(resolver) => resolver.poll_ready(cx),
Self::GaiWithDnsOverrides(resolver) => resolver.poll_ready(cx),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => resolver.poll_ready(cx),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => resolver.poll_ready(cx),
}
}

fn call(&mut self, dst: Uri) -> Self::Future {
match self {
Self::Gai(resolver) => Either::Left(Either::Left(resolver.call(dst))),
Self::GaiWithDnsOverrides(resolver) => Either::Left(Either::Right(resolver.call(dst))),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => Either::Right(Either::Left(resolver.call(dst))),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => {
Either::Right(Either::Right(resolver.call(dst)))
}
}
}
}
pub(crate) type HttpConnector = hyper::client::HttpConnector<DynResolver>;

#[derive(Clone)]
pub(crate) struct Connector {
Expand Down Expand Up @@ -960,103 +831,6 @@ mod socks {
}
}

pub(crate) mod itertools {
pub(crate) enum Either<A, B> {
Left(A),
Right(B),
}

impl<A, B> Iterator for Either<A, B>
where
A: Iterator,
B: Iterator<Item = <A as Iterator>::Item>,
{
type Item = <A as Iterator>::Item;

fn next(&mut self) -> Option<Self::Item> {
match self {
Either::Left(a) => a.next(),
Either::Right(b) => b.next(),
}
}
}
}

pin_project! {
pub(crate) struct WrappedResolverFuture<Fut> {
#[pin]
fut: Fut,
}
}

impl<Fut, FutOutput, FutError> std::future::Future for WrappedResolverFuture<Fut>
where
Fut: std::future::Future<Output = Result<FutOutput, FutError>>,
FutOutput: Iterator<Item = SocketAddr>,
{
type Output = Result<itertools::Either<FutOutput, std::vec::IntoIter<SocketAddr>>, FutError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
this.fut
.poll(cx)
.map(|result| result.map(itertools::Either::Left))
}
}

#[derive(Clone)]
pub(crate) struct DnsResolverWithOverrides<Resolver>
where
Resolver: Clone,
{
dns_resolver: Resolver,
overrides: Arc<HashMap<String, Vec<SocketAddr>>>,
}

impl<Resolver: Clone> DnsResolverWithOverrides<Resolver> {
fn new(dns_resolver: Resolver, overrides: HashMap<String, Vec<SocketAddr>>) -> Self {
DnsResolverWithOverrides {
dns_resolver,
overrides: Arc::new(overrides),
}
}
}

impl<Resolver, Iter> Service<Name> for DnsResolverWithOverrides<Resolver>
where
Resolver: Service<Name, Response = Iter> + Clone,
Iter: Iterator<Item = SocketAddr>,
{
type Response = itertools::Either<Iter, std::vec::IntoIter<SocketAddr>>;
type Error = <Resolver as Service<Name>>::Error;
type Future = Either<
WrappedResolverFuture<<Resolver as Service<Name>>::Future>,
futures_util::future::Ready<
Result<itertools::Either<Iter, std::vec::IntoIter<SocketAddr>>, Self::Error>,
>,
>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.dns_resolver.poll_ready(cx)
}

fn call(&mut self, name: Name) -> Self::Future {
match self.overrides.get(name.as_str()) {
Some(dest) => {
let fut = futures_util::future::ready(Ok(itertools::Either::Right(
dest.clone().into_iter(),
)));
Either::Right(fut)
}
None => {
let resolver_fut = self.dns_resolver.call(name);
let y = WrappedResolverFuture { fut: resolver_fut };
Either::Left(y)
}
}
}
}

mod verbose {
use hyper::client::connect::{Connected, Connection};
use std::fmt;
Expand Down
30 changes: 30 additions & 0 deletions src/dns/gai.rs
@@ -0,0 +1,30 @@
use futures_util::future::FutureExt;
use hyper::client::connect::dns::{GaiResolver as HyperGaiResolver, Name};
use hyper::service::Service;

use crate::BoxError;
use crate::dns::{Addrs, Resolve, Resolving};

#[derive(Debug)]
pub struct GaiResolver(HyperGaiResolver);

impl GaiResolver {
pub fn new() -> Self {
Self(HyperGaiResolver::new())
}
}

impl Default for GaiResolver {
fn default() -> Self {
GaiResolver::new()
}
}

impl Resolve for GaiResolver {
fn resolve(&self, name: Name) -> Resolving {
let this = &mut self.0.clone();
Box::pin(Service::<Name>::call(this, name)
.map(|result| result.map(|addrs| -> Addrs { Box::new(addrs) })
.map_err(|err| -> BoxError { Box::new(err) })))
}
}

0 comments on commit bae9f42

Please sign in to comment.