Skip to content

Commit

Permalink
Allow overriding of DNS resolution to specified IP addresses(#561)
Browse files Browse the repository at this point in the history
This change allows users to bypass the selected DNS resolver for
specific domains. The allows, for example, to make calls to a local TLS
server by rerouting a given domain to 127.0.0.1.

The approach I've taken for the design is to wrap the resolver in an
outer service. This leads to a fair amount of boilerplate code mainly to
be able to explain the typing to the compiler. The actual business logic
is very simple for the number of lines involved.

Closes #561
  • Loading branch information
campbellC committed Jun 1, 2021
1 parent bbeb1ed commit 920d721
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 16 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ url = "2.2"
bytes = "1.0"
serde = "1.0"
serde_urlencoded = "0.7"
itertools = "0.10.0"

# Optional deps...

Expand Down
39 changes: 36 additions & 3 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#[cfg(any(feature = "native-tls", feature = "__rustls",))]
use std::any::Any;
use std::convert::TryInto;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use std::{collections::HashMap, convert::TryInto, net::SocketAddr};
use std::{fmt, str};

use bytes::Bytes;
Expand Down Expand Up @@ -107,6 +107,7 @@ struct Config {
trust_dns: bool,
error: Option<crate::Error>,
https_only: bool,
dns_overrides: HashMap<String, SocketAddr>,
}

impl Default for ClientBuilder {
Expand Down Expand Up @@ -164,6 +165,7 @@ impl ClientBuilder {
#[cfg(feature = "cookies")]
cookie_store: None,
https_only: false,
dns_overrides: HashMap::new(),
},
}
}
Expand Down Expand Up @@ -194,9 +196,21 @@ impl ClientBuilder {
}

let http = match config.trust_dns {
false => HttpConnector::new_gai(),
false => {
if config.dns_overrides.is_empty() {
HttpConnector::new_gai()
} else {
HttpConnector::new_gai_with_overrides(config.dns_overrides)
}
}
#[cfg(feature = "trust-dns")]
true => HttpConnector::new_trust_dns()?,
true => {
if config.dns_overrides.is_empty() {
HttpConnector::new_trust_dns()?
} else {
HttpConnector::new_trust_dns_with_overrides(config.dns_overrides)?
}
}
#[cfg(not(feature = "trust-dns"))]
true => unreachable!("trust-dns shouldn't be enabled unless the feature is"),
};
Expand Down Expand Up @@ -1028,6 +1042,21 @@ impl ClientBuilder {
self.config.https_only = enabled;
self
}

/// Override DNS resolution for specific domains to particular IP addresses.
///
/// Warning
///
/// Since the DNS protocol has no notion of ports, if you wish to send
/// traffic to a particular port you must include this port in the URL
/// itself, any port in the overridden addr will be ignored and traffic sent
/// to the conventional port for the given scheme (e.g. 80 for http).
pub fn resolve(mut self, domain: &str, addr: impl Into<SocketAddr>) -> ClientBuilder {
self.config
.dns_overrides
.insert(domain.to_string(), addr.into());
self
}
}

type HyperClient = hyper::Client<Connector, super::body::ImplStream>;
Expand Down Expand Up @@ -1341,6 +1370,10 @@ impl Config {
{
f.field("tls_backend", &self.tls);
}

if !self.dns_overrides.is_empty() {
f.field("dns_overrides", &self.dns_overrides);
}
}
}

Expand Down
155 changes: 142 additions & 13 deletions src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,24 @@ use futures_util::future::Either;
use http::header::HeaderValue;
use http::uri::{Authority, Scheme};
use http::Uri;
use hyper::client::connect::{Connected, Connection};
use hyper::client::connect::{
dns::{GaiResolver, Name},
Connected, Connection,
};
use hyper::service::Service;
#[cfg(feature = "native-tls-crate")]
use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
use std::io::IoSlice;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{collections::HashMap, io};
use std::{future::Future, net::SocketAddr};

#[cfg(feature = "default-tls")]
use self::native_tls_conn::NativeTlsConn;
Expand All @@ -31,22 +34,44 @@ use crate::proxy::{Proxy, ProxyScheme};
#[derive(Clone)]
pub(crate) enum HttpConnector {
Gai(hyper::client::HttpConnector),
GaiWithDnsOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>>),
#[cfg(feature = "trust-dns")]
TrustDns(hyper::client::HttpConnector<TrustDnsResolver>),
#[cfg(feature = "trust-dns")]
TrustDnsWithOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>>),
}

impl HttpConnector {
pub(crate) fn new_gai() -> Self {
Self::Gai(hyper::client::HttpConnector::new())
}

pub(crate) fn new_gai_with_overrides(overrides: HashMap<String, SocketAddr>) -> Self {
let gai = hyper::client::connect::dns::GaiResolver::new();
let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides);
Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver(
overridden_resolver,
))
}

#[cfg(feature = "trust-dns")]
pub(crate) fn new_trust_dns() -> crate::Result<HttpConnector> {
TrustDnsResolver::new()
.map(hyper::client::HttpConnector::new_with_resolver)
.map(Self::TrustDns)
.map_err(crate::error::builder)
}

#[cfg(feature = "trust-dns")]
pub(crate) fn new_trust_dns_with_overrides(
overrides: HashMap<String, SocketAddr>,
) -> crate::Result<HttpConnector> {
TrustDnsResolver::new()
.map(|resolver| DnsResolverWithOverrides::new(resolver, overrides))
.map(hyper::client::HttpConnector::new_with_resolver)
.map(Self::TrustDnsWithOverrides)
.map_err(crate::error::builder)
}
}

macro_rules! impl_http_connector {
Expand All @@ -57,8 +82,11 @@ macro_rules! impl_http_connector {
fn $name(&mut self, $($par_name: $par_type),*)$( -> $return)? {
match self {
Self::Gai(resolver) => resolver.$name($($par_name),*),
Self::GaiWithDnsOverrides(resolver) => resolver.$name($($par_name),*),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => resolver.$name($($par_name),*),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => resolver.$name($($par_name),*),
}
}
)+
Expand All @@ -77,29 +105,55 @@ impl Service<Uri> for HttpConnector {
type Response = <hyper::client::HttpConnector as Service<Uri>>::Response;
type Error = <hyper::client::HttpConnector as Service<Uri>>::Error;
#[cfg(feature = "trust-dns")]
type Future = Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector<TrustDnsResolver> as Service<Uri>>::Future,
>;
type Future =
Either<
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
Uri,
>>::Future,
>,
Either<
<hyper::client::HttpConnector<TrustDnsResolver> as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>> as Service<Uri>>::Future
>
>;
#[cfg(not(feature = "trust-dns"))]
type Future = Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector as Service<Uri>>::Future,
>;
type Future =
Either<
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
Uri,
>>::Future,
>,
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector as Service<Uri>>::Future,
>,
>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self {
Self::Gai(resolver) => resolver.poll_ready(cx),
Self::GaiWithDnsOverrides(resolver) => resolver.poll_ready(cx),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => resolver.poll_ready(cx),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => resolver.poll_ready(cx),
}
}

fn call(&mut self, dst: Uri) -> Self::Future {
match self {
Self::Gai(resolver) => Either::Left(resolver.call(dst)),
Self::Gai(resolver) => Either::Left(Either::Left(resolver.call(dst))),
Self::GaiWithDnsOverrides(resolver) => Either::Left(Either::Right(resolver.call(dst))),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => Either::Right(resolver.call(dst)),
Self::TrustDns(resolver) => Either::Right(Either::Left(resolver.call(dst))),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => {
Either::Right(Either::Right(resolver.call(dst)))
}
}
}
}
Expand Down Expand Up @@ -889,6 +943,81 @@ mod socks {
}
}

pin_project! {
pub(crate) struct WrappedResolverFuture<Fut> {
#[pin]
fut: Fut,
}
}

impl<Fut, FutOutput, FutError> std::future::Future for WrappedResolverFuture<Fut>
where
Fut: std::future::Future<Output = Result<FutOutput, FutError>>,
FutOutput: Iterator<Item = SocketAddr>,
{
type Output = Result<itertools::Either<FutOutput, std::iter::Once<SocketAddr>>, FutError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
this.fut
.poll(cx)
.map(|result| result.map(itertools::Either::Left))
}
}

#[derive(Clone)]
pub(crate) struct DnsResolverWithOverrides<Resolver>
where
Resolver: Clone,
{
dns_resolver: Resolver,
overrides: HashMap<String, SocketAddr>,
}

impl<Resolver: Clone> DnsResolverWithOverrides<Resolver> {
fn new(dns_resolver: Resolver, overrides: HashMap<String, SocketAddr>) -> Self {
DnsResolverWithOverrides {
dns_resolver,
overrides,
}
}
}

impl<Resolver, Iter> Service<Name> for DnsResolverWithOverrides<Resolver>
where
Resolver: Service<Name, Response = Iter> + Clone,
Iter: Iterator<Item = SocketAddr>,
{
type Response = itertools::Either<Iter, std::iter::Once<SocketAddr>>;
type Error = <Resolver as Service<Name>>::Error;
type Future = Either<
WrappedResolverFuture<<Resolver as Service<Name>>::Future>,
std::future::Ready<
Result<itertools::Either<Iter, std::iter::Once<SocketAddr>>, Self::Error>,
>,
>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.dns_resolver.poll_ready(cx)
}

fn call(&mut self, name: Name) -> Self::Future {
match self.overrides.get(name.as_str()) {
Some(dest) => {
let fut = std::future::ready(Ok(itertools::Either::Right(std::iter::once(
dest.to_owned(),
))));
Either::Right(fut)
}
None => {
let resolver_fut = self.dns_resolver.call(name);
let y = WrappedResolverFuture { fut: resolver_fut };
Either::Left(y)
}
}
}
}

mod verbose {
use hyper::client::connect::{Connected, Connection};
use std::fmt;
Expand Down
48 changes: 48 additions & 0 deletions tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,54 @@ async fn body_pipe_response() {
assert_eq!(res2.status(), reqwest::StatusCode::OK);
}

#[tokio::test]
async fn overridden_dns_resolution_with_gai() {
let _ = env_logger::builder().is_test(true).try_init();
let server = server::http(move |_req| async { http::Response::new("Hello".into()) });

let overridden_domain = "rust-lang.org";
let url = format!(
"http://{}:{}/domain_override",
overridden_domain,
server.addr().port()
);
let client = reqwest::Client::builder()
.resolve(overridden_domain, server.addr())
.build()
.expect("client builder");
let req = client.get(&url);
let res = req.send().await.expect("request");

assert_eq!(res.status(), reqwest::StatusCode::OK);
let text = res.text().await.expect("Failed to get text");
assert_eq!("Hello", text);
}

#[cfg(feature = "trust-dns")]
#[tokio::test]
async fn overridden_dns_resolution_with_trust_dns() {
let _ = env_logger::builder().is_test(true).try_init();
let server = server::http(move |_req| async { http::Response::new("Hello".into()) });

let overridden_domain = "rust-lang.org";
let url = format!(
"http://{}:{}/domain_override",
overridden_domain,
server.addr().port()
);
let client = reqwest::Client::builder()
.resolve(overridden_domain, server.addr())
.trust_dns(true)
.build()
.expect("client builder");
let req = client.get(&url);
let res = req.send().await.expect("request");

assert_eq!(res.status(), reqwest::StatusCode::OK);
let text = res.text().await.expect("Failed to get text");
assert_eq!("Hello", text);
}

#[cfg(any(feature = "native-tls", feature = "__rustls",))]
#[test]
fn use_preconfigured_tls_with_bogus_backend() {
Expand Down

0 comments on commit 920d721

Please sign in to comment.