Skip to content

Commit

Permalink
Allow overriding of DNS resolution to specified IP addresses(#561) (#…
Browse files Browse the repository at this point in the history
…1277)

This change allows users to bypass the selected DNS resolver for
specific domains. The allows, for example, to make calls to a local TLS
server by rerouting a given domain to 127.0.0.1.

The approach I've taken for the design is to wrap the resolver in an
outer service. This leads to a fair amount of boilerplate code mainly to
be able to explain the typing to the compiler. The actual business logic
is very simple for the number of lines involved.

Closes #561
  • Loading branch information
campbellC committed Jun 16, 2021
1 parent c4388fc commit 8e5af45
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 16 deletions.
37 changes: 34 additions & 3 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#[cfg(any(feature = "native-tls", feature = "__rustls",))]
use std::any::Any;
use std::convert::TryInto;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use std::{collections::HashMap, convert::TryInto, net::SocketAddr};
use std::{fmt, str};

use bytes::Bytes;
Expand Down Expand Up @@ -107,6 +107,7 @@ struct Config {
trust_dns: bool,
error: Option<crate::Error>,
https_only: bool,
dns_overrides: HashMap<String, SocketAddr>,
}

impl Default for ClientBuilder {
Expand Down Expand Up @@ -164,6 +165,7 @@ impl ClientBuilder {
#[cfg(feature = "cookies")]
cookie_store: None,
https_only: false,
dns_overrides: HashMap::new(),
},
}
}
Expand Down Expand Up @@ -194,9 +196,21 @@ impl ClientBuilder {
}

let http = match config.trust_dns {
false => HttpConnector::new_gai(),
false => {
if config.dns_overrides.is_empty() {
HttpConnector::new_gai()
} else {
HttpConnector::new_gai_with_overrides(config.dns_overrides)
}
}
#[cfg(feature = "trust-dns")]
true => HttpConnector::new_trust_dns()?,
true => {
if config.dns_overrides.is_empty() {
HttpConnector::new_trust_dns()?
} else {
HttpConnector::new_trust_dns_with_overrides(config.dns_overrides)?
}
}
#[cfg(not(feature = "trust-dns"))]
true => unreachable!("trust-dns shouldn't be enabled unless the feature is"),
};
Expand Down Expand Up @@ -1037,6 +1051,19 @@ impl ClientBuilder {
self.config.https_only = enabled;
self
}

/// Override DNS resolution for specific domains to particular IP addresses.
///
/// Warning
///
/// Since the DNS protocol has no notion of ports, if you wish to send
/// traffic to a particular port you must include this port in the URL
/// itself, any port in the overridden addr will be ignored and traffic sent
/// to the conventional port for the given scheme (e.g. 80 for http).
pub fn resolve(mut self, domain: &str, addr: SocketAddr) -> ClientBuilder {
self.config.dns_overrides.insert(domain.to_string(), addr);
self
}
}

type HyperClient = hyper::Client<Connector, super::body::ImplStream>;
Expand Down Expand Up @@ -1350,6 +1377,10 @@ impl Config {
{
f.field("tls_backend", &self.tls);
}

if !self.dns_overrides.is_empty() {
f.field("dns_overrides", &self.dns_overrides);
}
}
}

Expand Down
177 changes: 164 additions & 13 deletions src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,24 @@ use futures_util::future::Either;
use http::header::HeaderValue;
use http::uri::{Authority, Scheme};
use http::Uri;
use hyper::client::connect::{Connected, Connection};
use hyper::client::connect::{
dns::{GaiResolver, Name},
Connected, Connection,
};
use hyper::service::Service;
#[cfg(feature = "native-tls-crate")]
use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
use std::io::IoSlice;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{collections::HashMap, io};
use std::{future::Future, net::SocketAddr};

#[cfg(feature = "default-tls")]
use self::native_tls_conn::NativeTlsConn;
Expand All @@ -31,22 +34,44 @@ use crate::proxy::{Proxy, ProxyScheme};
#[derive(Clone)]
pub(crate) enum HttpConnector {
Gai(hyper::client::HttpConnector),
GaiWithDnsOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>>),
#[cfg(feature = "trust-dns")]
TrustDns(hyper::client::HttpConnector<TrustDnsResolver>),
#[cfg(feature = "trust-dns")]
TrustDnsWithOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>>),
}

impl HttpConnector {
pub(crate) fn new_gai() -> Self {
Self::Gai(hyper::client::HttpConnector::new())
}

pub(crate) fn new_gai_with_overrides(overrides: HashMap<String, SocketAddr>) -> Self {
let gai = hyper::client::connect::dns::GaiResolver::new();
let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides);
Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver(
overridden_resolver,
))
}

#[cfg(feature = "trust-dns")]
pub(crate) fn new_trust_dns() -> crate::Result<HttpConnector> {
TrustDnsResolver::new()
.map(hyper::client::HttpConnector::new_with_resolver)
.map(Self::TrustDns)
.map_err(crate::error::builder)
}

#[cfg(feature = "trust-dns")]
pub(crate) fn new_trust_dns_with_overrides(
overrides: HashMap<String, SocketAddr>,
) -> crate::Result<HttpConnector> {
TrustDnsResolver::new()
.map(|resolver| DnsResolverWithOverrides::new(resolver, overrides))
.map(hyper::client::HttpConnector::new_with_resolver)
.map(Self::TrustDnsWithOverrides)
.map_err(crate::error::builder)
}
}

macro_rules! impl_http_connector {
Expand All @@ -57,8 +82,11 @@ macro_rules! impl_http_connector {
fn $name(&mut self, $($par_name: $par_type),*)$( -> $return)? {
match self {
Self::Gai(resolver) => resolver.$name($($par_name),*),
Self::GaiWithDnsOverrides(resolver) => resolver.$name($($par_name),*),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => resolver.$name($($par_name),*),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => resolver.$name($($par_name),*),
}
}
)+
Expand All @@ -77,29 +105,55 @@ impl Service<Uri> for HttpConnector {
type Response = <hyper::client::HttpConnector as Service<Uri>>::Response;
type Error = <hyper::client::HttpConnector as Service<Uri>>::Error;
#[cfg(feature = "trust-dns")]
type Future = Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector<TrustDnsResolver> as Service<Uri>>::Future,
>;
type Future =
Either<
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
Uri,
>>::Future,
>,
Either<
<hyper::client::HttpConnector<TrustDnsResolver> as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>> as Service<Uri>>::Future
>
>;
#[cfg(not(feature = "trust-dns"))]
type Future = Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector as Service<Uri>>::Future,
>;
type Future =
Either<
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
Uri,
>>::Future,
>,
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector as Service<Uri>>::Future,
>,
>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self {
Self::Gai(resolver) => resolver.poll_ready(cx),
Self::GaiWithDnsOverrides(resolver) => resolver.poll_ready(cx),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => resolver.poll_ready(cx),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => resolver.poll_ready(cx),
}
}

fn call(&mut self, dst: Uri) -> Self::Future {
match self {
Self::Gai(resolver) => Either::Left(resolver.call(dst)),
Self::Gai(resolver) => Either::Left(Either::Left(resolver.call(dst))),
Self::GaiWithDnsOverrides(resolver) => Either::Left(Either::Right(resolver.call(dst))),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => Either::Right(resolver.call(dst)),
Self::TrustDns(resolver) => Either::Right(Either::Left(resolver.call(dst))),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => {
Either::Right(Either::Right(resolver.call(dst)))
}
}
}
}
Expand Down Expand Up @@ -908,6 +962,103 @@ mod socks {
}
}

pub(crate) mod itertools {
pub(crate) enum Either<A, B> {
Left(A),
Right(B),
}

impl<A, B> Iterator for Either<A, B>
where
A: Iterator,
B: Iterator<Item = <A as Iterator>::Item>,
{
type Item = <A as Iterator>::Item;

fn next(&mut self) -> Option<Self::Item> {
match self {
Either::Left(a) => a.next(),
Either::Right(b) => b.next(),
}
}
}
}

pin_project! {
pub(crate) struct WrappedResolverFuture<Fut> {
#[pin]
fut: Fut,
}
}

impl<Fut, FutOutput, FutError> std::future::Future for WrappedResolverFuture<Fut>
where
Fut: std::future::Future<Output = Result<FutOutput, FutError>>,
FutOutput: Iterator<Item = SocketAddr>,
{
type Output = Result<itertools::Either<FutOutput, std::iter::Once<SocketAddr>>, FutError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
this.fut
.poll(cx)
.map(|result| result.map(itertools::Either::Left))
}
}

#[derive(Clone)]
pub(crate) struct DnsResolverWithOverrides<Resolver>
where
Resolver: Clone,
{
dns_resolver: Resolver,
overrides: Arc<HashMap<String, SocketAddr>>,
}

impl<Resolver: Clone> DnsResolverWithOverrides<Resolver> {
fn new(dns_resolver: Resolver, overrides: HashMap<String, SocketAddr>) -> Self {
DnsResolverWithOverrides {
dns_resolver,
overrides: Arc::new(overrides),
}
}
}

impl<Resolver, Iter> Service<Name> for DnsResolverWithOverrides<Resolver>
where
Resolver: Service<Name, Response = Iter> + Clone,
Iter: Iterator<Item = SocketAddr>,
{
type Response = itertools::Either<Iter, std::iter::Once<SocketAddr>>;
type Error = <Resolver as Service<Name>>::Error;
type Future = Either<
WrappedResolverFuture<<Resolver as Service<Name>>::Future>,
futures_util::future::Ready<
Result<itertools::Either<Iter, std::iter::Once<SocketAddr>>, Self::Error>,
>,
>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.dns_resolver.poll_ready(cx)
}

fn call(&mut self, name: Name) -> Self::Future {
match self.overrides.get(name.as_str()) {
Some(dest) => {
let fut = futures_util::future::ready(Ok(itertools::Either::Right(
std::iter::once(dest.to_owned()),
)));
Either::Right(fut)
}
None => {
let resolver_fut = self.dns_resolver.call(name);
let y = WrappedResolverFuture { fut: resolver_fut };
Either::Left(y)
}
}
}
}

mod verbose {
use hyper::client::connect::{Connected, Connection};
use std::fmt;
Expand Down
48 changes: 48 additions & 0 deletions tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,54 @@ async fn body_pipe_response() {
assert_eq!(res2.status(), reqwest::StatusCode::OK);
}

#[tokio::test]
async fn overridden_dns_resolution_with_gai() {
let _ = env_logger::builder().is_test(true).try_init();
let server = server::http(move |_req| async { http::Response::new("Hello".into()) });

let overridden_domain = "rust-lang.org";
let url = format!(
"http://{}:{}/domain_override",
overridden_domain,
server.addr().port()
);
let client = reqwest::Client::builder()
.resolve(overridden_domain, server.addr())
.build()
.expect("client builder");
let req = client.get(&url);
let res = req.send().await.expect("request");

assert_eq!(res.status(), reqwest::StatusCode::OK);
let text = res.text().await.expect("Failed to get text");
assert_eq!("Hello", text);
}

#[cfg(feature = "trust-dns")]
#[tokio::test]
async fn overridden_dns_resolution_with_trust_dns() {
let _ = env_logger::builder().is_test(true).try_init();
let server = server::http(move |_req| async { http::Response::new("Hello".into()) });

let overridden_domain = "rust-lang.org";
let url = format!(
"http://{}:{}/domain_override",
overridden_domain,
server.addr().port()
);
let client = reqwest::Client::builder()
.resolve(overridden_domain, server.addr())
.trust_dns(true)
.build()
.expect("client builder");
let req = client.get(&url);
let res = req.send().await.expect("request");

assert_eq!(res.status(), reqwest::StatusCode::OK);
let text = res.text().await.expect("Failed to get text");
assert_eq!("Hello", text);
}

#[cfg(any(feature = "native-tls", feature = "__rustls",))]
#[test]
fn use_preconfigured_tls_with_bogus_backend() {
Expand Down

0 comments on commit 8e5af45

Please sign in to comment.