Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support for custom DNS resolution #1653

Merged
merged 1 commit into from Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 30 additions & 17 deletions src/async_impl/client.rs
Expand Up @@ -13,7 +13,7 @@ use http::header::{
};
use http::uri::Scheme;
use http::Uri;
use hyper::client::ResponseFuture;
use hyper::client::{HttpConnector, ResponseFuture};
#[cfg(feature = "native-tls-crate")]
use native_tls_crate::TlsConnector;
use pin_project_lite::pin_project;
Expand All @@ -28,9 +28,12 @@ use super::decoder::Accepts;
use super::request::{Request, RequestBuilder};
use super::response::Response;
use super::Body;
use crate::connect::{Connector, HttpConnector};
use crate::connect::Connector;
#[cfg(feature = "cookies")]
use crate::cookie;
#[cfg(feature = "trust-dns")]
use crate::dns::trust_dns::TrustDnsResolver;
use crate::dns::{gai::GaiResolver, DnsResolverWithOverrides, DynResolver, Resolve};
use crate::error;
use crate::into_url::{expect_uri, try_uri};
use crate::redirect::{self, remove_sensitive_headers};
Expand Down Expand Up @@ -121,6 +124,7 @@ struct Config {
error: Option<crate::Error>,
https_only: bool,
dns_overrides: HashMap<String, Vec<SocketAddr>>,
dns_resolver: Option<Arc<dyn Resolve>>,
}

impl Default for ClientBuilder {
Expand Down Expand Up @@ -188,6 +192,7 @@ impl ClientBuilder {
cookie_store: None,
https_only: false,
dns_overrides: HashMap::new(),
dns_resolver: None,
},
}
}
Expand Down Expand Up @@ -217,25 +222,23 @@ impl ClientBuilder {
headers.get(USER_AGENT).cloned()
}

let http = match config.trust_dns {
false => {
if config.dns_overrides.is_empty() {
HttpConnector::new_gai()
} else {
HttpConnector::new_gai_with_overrides(config.dns_overrides)
}
}
let mut resolver: Arc<dyn Resolve> = match config.trust_dns {
false => Arc::new(GaiResolver::new()),
#[cfg(feature = "trust-dns")]
true => {
if config.dns_overrides.is_empty() {
HttpConnector::new_trust_dns()?
} else {
HttpConnector::new_trust_dns_with_overrides(config.dns_overrides)?
}
}
true => Arc::new(TrustDnsResolver::new().map_err(crate::error::builder)?),
#[cfg(not(feature = "trust-dns"))]
true => unreachable!("trust-dns shouldn't be enabled unless the feature is"),
};
if let Some(dns_resolver) = config.dns_resolver {
resolver = dns_resolver;
}
if !config.dns_overrides.is_empty() {
resolver = Arc::new(DnsResolverWithOverrides::new(
resolver,
config.dns_overrides,
));
}
let http = HttpConnector::new_with_resolver(DynResolver::new(resolver));

#[cfg(feature = "__tls")]
match config.tls {
Expand Down Expand Up @@ -1340,6 +1343,16 @@ impl ClientBuilder {
.insert(domain.to_string(), addrs.to_vec());
self
}

/// Override the DNS resolver implementation.
///
/// Pass an `Arc` wrapping a trait object implementing `Resolve`.
/// Overrides for specific names passed to `resolve` and `resolve_to_addrs` will
/// still be applied on top of this resolver.
seanmonstar marked this conversation as resolved.
Show resolved Hide resolved
pub fn dns_resolver<R: Resolve + 'static>(mut self, resolver: Arc<R>) -> ClientBuilder {
self.config.dns_resolver = Some(resolver as _);
self
}
}

type HyperClient = hyper::Client<Connector, super::body::ImplStream>;
Expand Down
238 changes: 5 additions & 233 deletions src/connect.rs
@@ -1,162 +1,31 @@
use futures_util::future::Either;
#[cfg(feature = "__tls")]
use http::header::HeaderValue;
use http::uri::{Authority, Scheme};
use http::Uri;
use hyper::client::connect::{
dns::{GaiResolver, Name},
Connected, Connection,
};
use hyper::client::connect::{Connected, Connection};
use hyper::service::Service;
#[cfg(feature = "native-tls-crate")]
use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use pin_project_lite::pin_project;
use std::io::IoSlice;
use std::future::Future;
use std::io::{self, IoSlice};
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{collections::HashMap, io};
use std::{future::Future, net::SocketAddr};

#[cfg(feature = "default-tls")]
use self::native_tls_conn::NativeTlsConn;
#[cfg(feature = "__rustls")]
use self::rustls_tls_conn::RustlsTlsConn;
#[cfg(feature = "trust-dns")]
use crate::dns::TrustDnsResolver;
use crate::dns::DynResolver;
use crate::error::BoxError;
use crate::proxy::{Proxy, ProxyScheme};

#[derive(Clone)]
seanmonstar marked this conversation as resolved.
Show resolved Hide resolved
pub(crate) enum HttpConnector {
Gai(hyper::client::HttpConnector),
GaiWithDnsOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>>),
#[cfg(feature = "trust-dns")]
TrustDns(hyper::client::HttpConnector<TrustDnsResolver>),
#[cfg(feature = "trust-dns")]
TrustDnsWithOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>>),
}

impl HttpConnector {
pub(crate) fn new_gai() -> Self {
Self::Gai(hyper::client::HttpConnector::new())
}

pub(crate) fn new_gai_with_overrides(overrides: HashMap<String, Vec<SocketAddr>>) -> Self {
let gai = hyper::client::connect::dns::GaiResolver::new();
let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides);
Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver(
overridden_resolver,
))
}

#[cfg(feature = "trust-dns")]
pub(crate) fn new_trust_dns() -> crate::Result<HttpConnector> {
TrustDnsResolver::new()
.map(hyper::client::HttpConnector::new_with_resolver)
.map(Self::TrustDns)
.map_err(crate::error::builder)
}

#[cfg(feature = "trust-dns")]
pub(crate) fn new_trust_dns_with_overrides(
overrides: HashMap<String, Vec<SocketAddr>>,
) -> crate::Result<HttpConnector> {
TrustDnsResolver::new()
.map(|resolver| DnsResolverWithOverrides::new(resolver, overrides))
.map(hyper::client::HttpConnector::new_with_resolver)
.map(Self::TrustDnsWithOverrides)
.map_err(crate::error::builder)
}
}

macro_rules! impl_http_connector {
($(fn $name:ident(&mut self, $($par_name:ident: $par_type:ty),*)$( -> $return:ty)?;)+) => {
#[allow(dead_code)]
impl HttpConnector {
$(
fn $name(&mut self, $($par_name: $par_type),*)$( -> $return)? {
match self {
Self::Gai(resolver) => resolver.$name($($par_name),*),
Self::GaiWithDnsOverrides(resolver) => resolver.$name($($par_name),*),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => resolver.$name($($par_name),*),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => resolver.$name($($par_name),*),
}
}
)+
}
};
}

impl_http_connector! {
fn set_local_address(&mut self, addr: Option<IpAddr>);
fn enforce_http(&mut self, is_enforced: bool);
fn set_nodelay(&mut self, nodelay: bool);
fn set_keepalive(&mut self, dur: Option<Duration>);
}

impl Service<Uri> for HttpConnector {
type Response = <hyper::client::HttpConnector as Service<Uri>>::Response;
type Error = <hyper::client::HttpConnector as Service<Uri>>::Error;
#[cfg(feature = "trust-dns")]
type Future =
Either<
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
Uri,
>>::Future,
>,
Either<
<hyper::client::HttpConnector<TrustDnsResolver> as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>> as Service<Uri>>::Future
>
>;
#[cfg(not(feature = "trust-dns"))]
type Future =
Either<
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
Uri,
>>::Future,
>,
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector as Service<Uri>>::Future,
>,
>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self {
Self::Gai(resolver) => resolver.poll_ready(cx),
Self::GaiWithDnsOverrides(resolver) => resolver.poll_ready(cx),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => resolver.poll_ready(cx),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => resolver.poll_ready(cx),
}
}

fn call(&mut self, dst: Uri) -> Self::Future {
match self {
Self::Gai(resolver) => Either::Left(Either::Left(resolver.call(dst))),
Self::GaiWithDnsOverrides(resolver) => Either::Left(Either::Right(resolver.call(dst))),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => Either::Right(Either::Left(resolver.call(dst))),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => {
Either::Right(Either::Right(resolver.call(dst)))
}
}
}
}
pub(crate) type HttpConnector = hyper::client::HttpConnector<DynResolver>;

#[derive(Clone)]
pub(crate) struct Connector {
Expand Down Expand Up @@ -960,103 +829,6 @@ mod socks {
}
}

pub(crate) mod itertools {
pub(crate) enum Either<A, B> {
Left(A),
Right(B),
}

impl<A, B> Iterator for Either<A, B>
where
A: Iterator,
B: Iterator<Item = <A as Iterator>::Item>,
{
type Item = <A as Iterator>::Item;

fn next(&mut self) -> Option<Self::Item> {
match self {
Either::Left(a) => a.next(),
Either::Right(b) => b.next(),
}
}
}
}

pin_project! {
pub(crate) struct WrappedResolverFuture<Fut> {
#[pin]
fut: Fut,
}
}

impl<Fut, FutOutput, FutError> std::future::Future for WrappedResolverFuture<Fut>
where
Fut: std::future::Future<Output = Result<FutOutput, FutError>>,
FutOutput: Iterator<Item = SocketAddr>,
{
type Output = Result<itertools::Either<FutOutput, std::vec::IntoIter<SocketAddr>>, FutError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
this.fut
.poll(cx)
.map(|result| result.map(itertools::Either::Left))
}
}

#[derive(Clone)]
pub(crate) struct DnsResolverWithOverrides<Resolver>
where
Resolver: Clone,
{
dns_resolver: Resolver,
overrides: Arc<HashMap<String, Vec<SocketAddr>>>,
}

impl<Resolver: Clone> DnsResolverWithOverrides<Resolver> {
fn new(dns_resolver: Resolver, overrides: HashMap<String, Vec<SocketAddr>>) -> Self {
DnsResolverWithOverrides {
dns_resolver,
overrides: Arc::new(overrides),
}
}
}

impl<Resolver, Iter> Service<Name> for DnsResolverWithOverrides<Resolver>
where
Resolver: Service<Name, Response = Iter> + Clone,
Iter: Iterator<Item = SocketAddr>,
{
type Response = itertools::Either<Iter, std::vec::IntoIter<SocketAddr>>;
type Error = <Resolver as Service<Name>>::Error;
type Future = Either<
WrappedResolverFuture<<Resolver as Service<Name>>::Future>,
futures_util::future::Ready<
Result<itertools::Either<Iter, std::vec::IntoIter<SocketAddr>>, Self::Error>,
>,
>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.dns_resolver.poll_ready(cx)
}

fn call(&mut self, name: Name) -> Self::Future {
match self.overrides.get(name.as_str()) {
Some(dest) => {
let fut = futures_util::future::ready(Ok(itertools::Either::Right(
dest.clone().into_iter(),
)));
Either::Right(fut)
}
None => {
let resolver_fut = self.dns_resolver.call(name);
let y = WrappedResolverFuture { fut: resolver_fut };
Either::Left(y)
}
}
}
}

mod verbose {
use hyper::client::connect::{Connected, Connection};
use std::fmt;
Expand Down
32 changes: 32 additions & 0 deletions src/dns/gai.rs
@@ -0,0 +1,32 @@
use futures_util::future::FutureExt;
use hyper::client::connect::dns::{GaiResolver as HyperGaiResolver, Name};
use hyper::service::Service;

use crate::dns::{Addrs, Resolve, Resolving};
use crate::error::BoxError;

#[derive(Debug)]
pub struct GaiResolver(HyperGaiResolver);

impl GaiResolver {
pub fn new() -> Self {
Self(HyperGaiResolver::new())
}
}

impl Default for GaiResolver {
fn default() -> Self {
GaiResolver::new()
}
}

impl Resolve for GaiResolver {
fn resolve(&self, name: Name) -> Resolving {
let this = &mut self.0.clone();
Box::pin(Service::<Name>::call(this, name).map(|result| {
result
.map(|addrs| -> Addrs { Box::new(addrs) })
.map_err(|err| -> BoxError { Box::new(err) })
}))
}
}