Skip to content

Commit

Permalink
feat(client): change DNS Resolver to resolve to SocketAddrs (hyperium…
Browse files Browse the repository at this point in the history
…#2346)

The DNS resolver part of `HttpConnector` used to require resolving to
`IpAddr`s, and this changes it so that they resolve to `SocketAddr`s.
The main benefit here is allowing for resolvers to set the IPv6 zone ID
when resolving, but it also just more closely matches
`std::net::ToSocketAddrs`.

Closes hyperium#1937

BREAKING CHANGE: Custom resolvers used with `HttpConnector` must change
  to resolving to an iterator of `SocketAddr`s instead of `IpAddr`s.
  • Loading branch information
seanmonstar authored and Benxiang Ge committed Jul 26, 2021
1 parent 01343df commit 0446c12
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 39 deletions.
64 changes: 32 additions & 32 deletions src/client/connect/dns.rs
Expand Up @@ -9,21 +9,21 @@
//! # Resolvers are `Service`s
//!
//! A resolver is just a
//! `Service<Name, Response = impl Iterator<Item = IpAddr>>`.
//! `Service<Name, Response = impl Iterator<Item = SocketAddr>>`.
//!
//! A simple resolver that ignores the name and always returns a specific
//! address:
//!
//! ```rust,ignore
//! use std::{convert::Infallible, iter, net::IpAddr};
//! use std::{convert::Infallible, iter, net::SocketAddr};
//!
//! let resolver = tower::service_fn(|_name| async {
//! Ok::<_, Infallible>(iter::once(IpAddr::from([127, 0, 0, 1])))
//! Ok::<_, Infallible>(iter::once(SocketAddr::from(([127, 0, 0, 1], 8080))))
//! });
//! ```
use std::error::Error;
use std::future::Future;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
use std::pin::Pin;
use std::str::FromStr;
use std::task::{self, Poll};
Expand All @@ -48,12 +48,12 @@ pub struct GaiResolver {

/// An iterator of IP addresses returned from `getaddrinfo`.
pub struct GaiAddrs {
inner: IpAddrs,
inner: SocketAddrs,
}

/// A future to resolve a name returned by `GaiResolver`.
pub struct GaiFuture {
inner: JoinHandle<Result<IpAddrs, io::Error>>,
inner: JoinHandle<Result<SocketAddrs, io::Error>>,
}

impl Name {
Expand Down Expand Up @@ -121,7 +121,7 @@ impl Service<Name> for GaiResolver {
debug!("resolving host={:?}", name.host);
(&*name.host, 0)
.to_socket_addrs()
.map(|i| IpAddrs { iter: i })
.map(|i| SocketAddrs { iter: i })
});

GaiFuture { inner: blocking }
Expand Down Expand Up @@ -159,10 +159,10 @@ impl fmt::Debug for GaiFuture {
}

impl Iterator for GaiAddrs {
type Item = IpAddr;
type Item = SocketAddr;

fn next(&mut self) -> Option<Self::Item> {
self.inner.next().map(|sa| sa.ip())
self.inner.next()
}
}

Expand All @@ -172,47 +172,47 @@ impl fmt::Debug for GaiAddrs {
}
}

pub(super) struct IpAddrs {
pub(super) struct SocketAddrs {
iter: vec::IntoIter<SocketAddr>,
}

impl IpAddrs {
impl SocketAddrs {
pub(super) fn new(addrs: Vec<SocketAddr>) -> Self {
IpAddrs {
SocketAddrs {
iter: addrs.into_iter(),
}
}

pub(super) fn try_parse(host: &str, port: u16) -> Option<IpAddrs> {
pub(super) fn try_parse(host: &str, port: u16) -> Option<SocketAddrs> {
if let Ok(addr) = host.parse::<Ipv4Addr>() {
let addr = SocketAddrV4::new(addr, port);
return Some(IpAddrs {
return Some(SocketAddrs {
iter: vec![SocketAddr::V4(addr)].into_iter(),
});
}
let host = host.trim_start_matches('[').trim_end_matches(']');
if let Ok(addr) = host.parse::<Ipv6Addr>() {
let addr = SocketAddrV6::new(addr, port, 0, 0);
return Some(IpAddrs {
return Some(SocketAddrs {
iter: vec![SocketAddr::V6(addr)].into_iter(),
});
}
None
}

#[inline]
fn filter(self, predicate: impl FnMut(&SocketAddr) -> bool) -> IpAddrs {
IpAddrs::new(self.iter.filter(predicate).collect())
fn filter(self, predicate: impl FnMut(&SocketAddr) -> bool) -> SocketAddrs {
SocketAddrs::new(self.iter.filter(predicate).collect())
}

pub(super) fn split_by_preference(
self,
local_addr_ipv4: Option<Ipv4Addr>,
local_addr_ipv6: Option<Ipv6Addr>,
) -> (IpAddrs, IpAddrs) {
) -> (SocketAddrs, SocketAddrs) {
match (local_addr_ipv4, local_addr_ipv6) {
(Some(_), None) => (self.filter(SocketAddr::is_ipv4), IpAddrs::new(vec![])),
(None, Some(_)) => (self.filter(SocketAddr::is_ipv6), IpAddrs::new(vec![])),
(Some(_), None) => (self.filter(SocketAddr::is_ipv4), SocketAddrs::new(vec![])),
(None, Some(_)) => (self.filter(SocketAddr::is_ipv6), SocketAddrs::new(vec![])),
_ => {
let preferring_v6 = self
.iter
Expand All @@ -225,7 +225,7 @@ impl IpAddrs {
.iter
.partition::<Vec<_>, _>(|addr| addr.is_ipv6() == preferring_v6);

(IpAddrs::new(preferred), IpAddrs::new(fallback))
(SocketAddrs::new(preferred), SocketAddrs::new(fallback))
}
}
}
Expand All @@ -239,7 +239,7 @@ impl IpAddrs {
}
}

impl Iterator for IpAddrs {
impl Iterator for SocketAddrs {
type Item = SocketAddr;
#[inline]
fn next(&mut self) -> Option<SocketAddr> {
Expand Down Expand Up @@ -312,13 +312,13 @@ impl Future for TokioThreadpoolGaiFuture {
*/

mod sealed {
use super::{IpAddr, Name};
use super::{SocketAddr, Name};
use crate::common::{task, Future, Poll};
use tower_service::Service;

// "Trait alias" for `Service<Name, Response = Addrs>`
pub trait Resolve {
type Addrs: Iterator<Item = IpAddr>;
type Addrs: Iterator<Item = SocketAddr>;
type Error: Into<Box<dyn std::error::Error + Send + Sync>>;
type Future: Future<Output = Result<Self::Addrs, Self::Error>>;

Expand All @@ -329,7 +329,7 @@ mod sealed {
impl<S> Resolve for S
where
S: Service<Name>,
S::Response: Iterator<Item = IpAddr>,
S::Response: Iterator<Item = SocketAddr>,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Addrs = S::Response;
Expand Down Expand Up @@ -366,42 +366,42 @@ mod tests {
let v4_addr = (ip_v4, 80).into();
let v6_addr = (ip_v6, 80).into();

let (mut preferred, mut fallback) = IpAddrs {
let (mut preferred, mut fallback) = SocketAddrs {
iter: vec![v4_addr, v6_addr].into_iter(),
}
.split_by_preference(None, None);
assert!(preferred.next().unwrap().is_ipv4());
assert!(fallback.next().unwrap().is_ipv6());

let (mut preferred, mut fallback) = IpAddrs {
let (mut preferred, mut fallback) = SocketAddrs {
iter: vec![v6_addr, v4_addr].into_iter(),
}
.split_by_preference(None, None);
assert!(preferred.next().unwrap().is_ipv6());
assert!(fallback.next().unwrap().is_ipv4());

let (mut preferred, mut fallback) = IpAddrs {
let (mut preferred, mut fallback) = SocketAddrs {
iter: vec![v4_addr, v6_addr].into_iter(),
}
.split_by_preference(Some(ip_v4), Some(ip_v6));
assert!(preferred.next().unwrap().is_ipv4());
assert!(fallback.next().unwrap().is_ipv6());

let (mut preferred, mut fallback) = IpAddrs {
let (mut preferred, mut fallback) = SocketAddrs {
iter: vec![v6_addr, v4_addr].into_iter(),
}
.split_by_preference(Some(ip_v4), Some(ip_v6));
assert!(preferred.next().unwrap().is_ipv6());
assert!(fallback.next().unwrap().is_ipv4());

let (mut preferred, fallback) = IpAddrs {
let (mut preferred, fallback) = SocketAddrs {
iter: vec![v4_addr, v6_addr].into_iter(),
}
.split_by_preference(Some(ip_v4), None);
assert!(preferred.next().unwrap().is_ipv4());
assert!(fallback.is_empty());

let (mut preferred, fallback) = IpAddrs {
let (mut preferred, fallback) = SocketAddrs {
iter: vec![v4_addr, v6_addr].into_iter(),
}
.split_by_preference(None, Some(ip_v6));
Expand All @@ -422,7 +422,7 @@ mod tests {
let dst = ::http::Uri::from_static("http://[::1]:8080/");

let mut addrs =
IpAddrs::try_parse(dst.host().expect("host"), dst.port_u16().expect("port"))
SocketAddrs::try_parse(dst.host().expect("host"), dst.port_u16().expect("port"))
.expect("try_parse");

let expected = "[::1]:8080".parse::<SocketAddr>().expect("expected");
Expand Down
17 changes: 10 additions & 7 deletions src/client/connect/http.rs
Expand Up @@ -321,14 +321,17 @@ where

// If the host is already an IP addr (v4 or v6),
// skip resolving the dns and start connecting right away.
let addrs = if let Some(addrs) = dns::IpAddrs::try_parse(host, port) {
let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
addrs
} else {
let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
.await
.map_err(ConnectError::dns)?;
let addrs = addrs.map(|addr| SocketAddr::new(addr, port)).collect();
dns::IpAddrs::new(addrs)
let addrs = addrs.map(|mut addr| {
addr.set_port(port);
addr
}).collect();
dns::SocketAddrs::new(addrs)
};

let c = ConnectingTcp::new(addrs, config);
Expand Down Expand Up @@ -457,7 +460,7 @@ struct ConnectingTcp<'a> {
}

impl<'a> ConnectingTcp<'a> {
fn new(remote_addrs: dns::IpAddrs, config: &'a Config) -> Self {
fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
let (preferred_addrs, fallback_addrs) = remote_addrs
.split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
Expand Down Expand Up @@ -493,12 +496,12 @@ struct ConnectingTcpFallback {
}

struct ConnectingTcpRemote {
addrs: dns::IpAddrs,
addrs: dns::SocketAddrs,
connect_timeout: Option<Duration>,
}

impl ConnectingTcpRemote {
fn new(addrs: dns::IpAddrs, connect_timeout: Option<Duration>) -> Self {
fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32));

Self {
Expand Down Expand Up @@ -920,7 +923,7 @@ mod tests {
send_buffer_size: None,
recv_buffer_size: None,
};
let connecting_tcp = ConnectingTcp::new(dns::IpAddrs::new(addrs), &cfg);
let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
let start = Instant::now();
Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
})
Expand Down

0 comments on commit 0446c12

Please sign in to comment.