Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to specify multiple IP addresses for resolver overrides #1622

Merged
merged 1 commit into from Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/async_impl/client.rs
Expand Up @@ -120,7 +120,7 @@ struct Config {
trust_dns: bool,
error: Option<crate::Error>,
https_only: bool,
dns_overrides: HashMap<String, SocketAddr>,
dns_overrides: HashMap<String, Vec<SocketAddr>>,
}

impl Default for ClientBuilder {
Expand Down Expand Up @@ -1314,16 +1314,30 @@ impl ClientBuilder {
self
}

/// Override DNS resolution for specific domains to particular IP addresses.
/// Override DNS resolution for specific domains to a particular IP address.
///
/// Warning
///
/// Since the DNS protocol has no notion of ports, if you wish to send
/// traffic to a particular port you must include this port in the URL
/// itself, any port in the overridden addr will be ignored and traffic sent
/// to the conventional port for the given scheme (e.g. 80 for http).
pub fn resolve(mut self, domain: &str, addr: SocketAddr) -> ClientBuilder {
self.config.dns_overrides.insert(domain.to_string(), addr);
pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder {
self.resolve_to_addrs(domain, &[addr])
}

/// Override DNS resolution for specific domains to particular IP addresses.
///
/// Warning
///
/// Since the DNS protocol has no notion of ports, if you wish to send
/// traffic to a particular port you must include this port in the URL
/// itself, any port in the overridden addresses will be ignored and traffic sent
/// to the conventional port for the given scheme (e.g. 80 for http).
pub fn resolve_to_addrs(mut self, domain: &str, addrs: &[SocketAddr]) -> ClientBuilder {
self.config
.dns_overrides
.insert(domain.to_string(), addrs.to_vec());
self
}
}
Expand Down
16 changes: 14 additions & 2 deletions src/blocking/client.rs
Expand Up @@ -757,7 +757,7 @@ impl ClientBuilder {
self.with_inner(|inner| inner.https_only(enabled))
}

/// Override DNS resolution for specific domains to particular IP addresses.
/// Override DNS resolution for specific domains to a particular IP address.
///
/// Warning
///
Expand All @@ -766,7 +766,19 @@ impl ClientBuilder {
/// itself, any port in the overridden addr will be ignored and traffic sent
/// to the conventional port for the given scheme (e.g. 80 for http).
pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder {
self.with_inner(|inner| inner.resolve(domain, addr))
self.resolve_to_addrs(domain, &[addr])
}

/// Override DNS resolution for specific domains to particular IP addresses.
///
/// Warning
///
/// Since the DNS protocol has no notion of ports, if you wish to send
/// traffic to a particular port you must include this port in the URL
/// itself, any port in the overridden addresses will be ignored and traffic sent
/// to the conventional port for the given scheme (e.g. 80 for http).
pub fn resolve_to_addrs(self, domain: &str, addrs: &[SocketAddr]) -> ClientBuilder {
self.with_inner(|inner| inner.resolve_to_addrs(domain, addrs))
}

// private
Expand Down
16 changes: 8 additions & 8 deletions src/connect.rs
Expand Up @@ -46,7 +46,7 @@ impl HttpConnector {
Self::Gai(hyper::client::HttpConnector::new())
}

pub(crate) fn new_gai_with_overrides(overrides: HashMap<String, SocketAddr>) -> Self {
pub(crate) fn new_gai_with_overrides(overrides: HashMap<String, Vec<SocketAddr>>) -> Self {
let gai = hyper::client::connect::dns::GaiResolver::new();
let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides);
Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver(
Expand All @@ -64,7 +64,7 @@ impl HttpConnector {

#[cfg(feature = "trust-dns")]
pub(crate) fn new_trust_dns_with_overrides(
overrides: HashMap<String, SocketAddr>,
overrides: HashMap<String, Vec<SocketAddr>>,
) -> crate::Result<HttpConnector> {
TrustDnsResolver::new()
.map(|resolver| DnsResolverWithOverrides::new(resolver, overrides))
Expand Down Expand Up @@ -994,7 +994,7 @@ where
Fut: std::future::Future<Output = Result<FutOutput, FutError>>,
FutOutput: Iterator<Item = SocketAddr>,
{
type Output = Result<itertools::Either<FutOutput, std::iter::Once<SocketAddr>>, FutError>;
type Output = Result<itertools::Either<FutOutput, std::vec::IntoIter<SocketAddr>>, FutError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
Expand All @@ -1010,11 +1010,11 @@ where
Resolver: Clone,
{
dns_resolver: Resolver,
overrides: Arc<HashMap<String, SocketAddr>>,
overrides: Arc<HashMap<String, Vec<SocketAddr>>>,
}

impl<Resolver: Clone> DnsResolverWithOverrides<Resolver> {
fn new(dns_resolver: Resolver, overrides: HashMap<String, SocketAddr>) -> Self {
fn new(dns_resolver: Resolver, overrides: HashMap<String, Vec<SocketAddr>>) -> Self {
DnsResolverWithOverrides {
dns_resolver,
overrides: Arc::new(overrides),
Expand All @@ -1027,12 +1027,12 @@ where
Resolver: Service<Name, Response = Iter> + Clone,
Iter: Iterator<Item = SocketAddr>,
{
type Response = itertools::Either<Iter, std::iter::Once<SocketAddr>>;
type Response = itertools::Either<Iter, std::vec::IntoIter<SocketAddr>>;
type Error = <Resolver as Service<Name>>::Error;
type Future = Either<
WrappedResolverFuture<<Resolver as Service<Name>>::Future>,
futures_util::future::Ready<
Result<itertools::Either<Iter, std::iter::Once<SocketAddr>>, Self::Error>,
Result<itertools::Either<Iter, std::vec::IntoIter<SocketAddr>>, Self::Error>,
>,
>;

Expand All @@ -1044,7 +1044,7 @@ where
match self.overrides.get(name.as_str()) {
Some(dest) => {
let fut = futures_util::future::ready(Ok(itertools::Either::Right(
std::iter::once(dest.to_owned()),
dest.clone().into_iter(),
)));
Either::Right(fut)
}
Expand Down
70 changes: 70 additions & 0 deletions tests/client.rs
Expand Up @@ -190,6 +190,40 @@ async fn overridden_dns_resolution_with_gai() {
assert_eq!("Hello", text);
}

#[tokio::test]
async fn overridden_dns_resolution_with_gai_multiple() {
let _ = env_logger::builder().is_test(true).try_init();
let server = server::http(move |_req| async { http::Response::new("Hello".into()) });

let overridden_domain = "rust-lang.org";
let url = format!(
"http://{}:{}/domain_override",
overridden_domain,
server.addr().port()
);
// the server runs on IPv4 localhost, so provide both IPv4 and IPv6 and let the happy eyeballs
// algorithm decide which address to use.
let client = reqwest::Client::builder()
.resolve_to_addrs(
overridden_domain,
&[
std::net::SocketAddr::new(
std::net::IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
server.addr().port(),
),
server.addr(),
],
)
.build()
.expect("client builder");
let req = client.get(&url);
let res = req.send().await.expect("request");

assert_eq!(res.status(), reqwest::StatusCode::OK);
let text = res.text().await.expect("Failed to get text");
assert_eq!("Hello", text);
}

#[cfg(feature = "trust-dns")]
#[tokio::test]
async fn overridden_dns_resolution_with_trust_dns() {
Expand All @@ -215,6 +249,42 @@ async fn overridden_dns_resolution_with_trust_dns() {
assert_eq!("Hello", text);
}

#[cfg(feature = "trust-dns")]
#[tokio::test]
async fn overridden_dns_resolution_with_trust_dns_multiple() {
let _ = env_logger::builder().is_test(true).try_init();
let server = server::http(move |_req| async { http::Response::new("Hello".into()) });

let overridden_domain = "rust-lang.org";
let url = format!(
"http://{}:{}/domain_override",
overridden_domain,
server.addr().port()
);
// the server runs on IPv4 localhost, so provide both IPv4 and IPv6 and let the happy eyeballs
// algorithm decide which address to use.
let client = reqwest::Client::builder()
.resolve_to_addrs(
overridden_domain,
&[
std::net::SocketAddr::new(
std::net::IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
server.addr().port(),
),
server.addr(),
],
)
.trust_dns(true)
.build()
.expect("client builder");
let req = client.get(&url);
let res = req.send().await.expect("request");

assert_eq!(res.status(), reqwest::StatusCode::OK);
let text = res.text().await.expect("Failed to get text");
assert_eq!("Hello", text);
}

#[cfg(any(feature = "native-tls", feature = "__rustls",))]
#[test]
fn use_preconfigured_tls_with_bogus_backend() {
Expand Down