Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): add per IP address connection timeout #1958

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
140 changes: 126 additions & 14 deletions src/client/connect/http.rs
Expand Up @@ -10,7 +10,7 @@ use futures_util::{TryFutureExt, FutureExt};
use net2::TcpBuilder;
use tokio_net::driver::Handle;
use tokio_net::tcp::TcpStream;
use tokio_timer::Delay;
use tokio_timer::{Delay, Timeout};

use crate::common::{Future, Pin, Poll, task};
use super::{Connect, Connected, Destination};
Expand All @@ -31,6 +31,7 @@ type ConnectFuture = Pin<Box<dyn Future<Output = io::Result<TcpStream>> + Send>>
pub struct HttpConnector<R = GaiResolver> {
enforce_http: bool,
handle: Option<Handle>,
connect_timeout: Option<Duration>,
happy_eyeballs_timeout: Option<Duration>,
keep_alive_timeout: Option<Duration>,
local_address: Option<IpAddr>,
Expand Down Expand Up @@ -99,6 +100,7 @@ impl<R> HttpConnector<R> {
HttpConnector {
enforce_http: true,
handle: None,
connect_timeout: None,
happy_eyeballs_timeout: Some(Duration::from_millis(300)),
keep_alive_timeout: None,
local_address: None,
Expand Down Expand Up @@ -166,6 +168,21 @@ impl<R> HttpConnector<R> {
self.local_address = addr;
}

/// Set timeout for each attempt to connect to an IP address.
///
/// If the hostname resolves to multiple IP addresses then this timeout is
/// applied to each individual connection attempt, ensuring that all the
/// addresses are given equal opportunity to respond.
///
/// If `None`, then no timeout is applied by the connector, making it
/// subject to the timeout imposed by the operating system.
///
/// Default is `None`.
#[inline]
pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
self.connect_timeout = dur;
}

/// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
///
/// If hostname resolves to both IPv4 and IPv6 addresses and connection
Expand Down Expand Up @@ -238,6 +255,7 @@ where
HttpConnecting {
state: State::Lazy(self.resolver.clone(), host.into(), self.local_address),
handle: self.handle.clone(),
connect_timeout: self.connect_timeout,
happy_eyeballs_timeout: self.happy_eyeballs_timeout,
keep_alive_timeout: self.keep_alive_timeout,
nodelay: self.nodelay,
Expand Down Expand Up @@ -293,6 +311,7 @@ where
let fut = HttpConnecting {
state: State::Lazy(self.resolver.clone(), host.into(), self.local_address),
handle: self.handle.clone(),
connect_timeout: self.connect_timeout,
happy_eyeballs_timeout: self.happy_eyeballs_timeout,
keep_alive_timeout: self.keep_alive_timeout,
nodelay: self.nodelay,
Expand Down Expand Up @@ -321,6 +340,7 @@ fn invalid_url<R: Resolve>(err: InvalidUrl, handle: &Option<Handle>) -> HttpConn
keep_alive_timeout: None,
nodelay: false,
port: 0,
connect_timeout: None,
happy_eyeballs_timeout: None,
reuse_address: false,
send_buffer_size: None,
Expand Down Expand Up @@ -355,6 +375,7 @@ impl StdError for InvalidUrl {
pub struct HttpConnecting<R: Resolve = GaiResolver> {
state: State<R>,
handle: Option<Handle>,
connect_timeout: Option<Duration>,
happy_eyeballs_timeout: Option<Duration>,
keep_alive_timeout: Option<Duration>,
nodelay: bool,
Expand Down Expand Up @@ -387,7 +408,12 @@ where
// skip resolving the dns and start connecting right away.
if let Some(addrs) = dns::IpAddrs::try_parse(host, me.port) {
state = State::Connecting(ConnectingTcp::new(
local_addr, addrs, me.happy_eyeballs_timeout, me.reuse_address));
local_addr,
addrs,
me.connect_timeout,
me.happy_eyeballs_timeout,
me.reuse_address,
));
} else {
let name = dns::Name::new(mem::replace(host, String::new()));
state = State::Resolving(resolver.resolve(name), local_addr);
Expand All @@ -401,8 +427,13 @@ where
.collect();
let addrs = dns::IpAddrs::new(addrs);
state = State::Connecting(ConnectingTcp::new(
local_addr, addrs, me.happy_eyeballs_timeout, me.reuse_address));
},
local_addr,
addrs,
me.connect_timeout,
me.happy_eyeballs_timeout,
me.reuse_address,
));
}
State::Connecting(ref mut c) => {
let sock = ready!(c.poll(cx, &me.handle))?;

Expand Down Expand Up @@ -445,13 +476,15 @@ struct ConnectingTcp {
local_addr: Option<IpAddr>,
preferred: ConnectingTcpRemote,
fallback: Option<ConnectingTcpFallback>,
connect_timeout: Option<Duration>,
reuse_address: bool,
}

impl ConnectingTcp {
fn new(
local_addr: Option<IpAddr>,
remote_addrs: dns::IpAddrs,
connect_timeout: Option<Duration>,
fallback_timeout: Option<Duration>,
reuse_address: bool,
) -> ConnectingTcp {
Expand All @@ -462,6 +495,7 @@ impl ConnectingTcp {
local_addr,
preferred: ConnectingTcpRemote::new(preferred_addrs),
fallback: None,
connect_timeout,
reuse_address,
};
}
Expand All @@ -473,13 +507,15 @@ impl ConnectingTcp {
delay: tokio_timer::delay_for(fallback_timeout),
remote: ConnectingTcpRemote::new(fallback_addrs),
}),
connect_timeout,
reuse_address,
}
} else {
ConnectingTcp {
local_addr,
preferred: ConnectingTcpRemote::new(remote_addrs),
fallback: None,
connect_timeout,
reuse_address,
}
}
Expand Down Expand Up @@ -512,6 +548,7 @@ impl ConnectingTcpRemote {
cx: &mut task::Context<'_>,
local_addr: &Option<IpAddr>,
handle: &Option<Handle>,
connect_timeout: Option<Duration>,
reuse_address: bool,
) -> Poll<io::Result<TcpStream>> {
let mut err = None;
Expand All @@ -528,14 +565,20 @@ impl ConnectingTcpRemote {
err = Some(e);
if let Some(addr) = self.addrs.next() {
debug!("connecting to {}", addr);
*current = connect(&addr, local_addr, handle, reuse_address)?;
*current = connect(&addr, local_addr, handle, connect_timeout, reuse_address)?;
continue;
}
}
}
} else if let Some(addr) = self.addrs.next() {
debug!("connecting to {}", addr);
self.current = Some(connect(&addr, local_addr, handle, reuse_address)?);
self.current = Some(connect(
&addr,
local_addr,
handle,
connect_timeout,
reuse_address,
)?);
continue;
}

Expand All @@ -544,7 +587,13 @@ impl ConnectingTcpRemote {
}
}

fn connect(addr: &SocketAddr, local_addr: &Option<IpAddr>, handle: &Option<Handle>, reuse_address: bool) -> io::Result<ConnectFuture> {
fn connect(
addr: &SocketAddr,
local_addr: &Option<IpAddr>,
handle: &Option<Handle>,
connect_timeout: Option<Duration>,
reuse_address: bool,
) -> io::Result<ConnectFuture> {
let builder = match addr {
&SocketAddr::V4(_) => TcpBuilder::new_v4()?,
&SocketAddr::V6(_) => TcpBuilder::new_v6()?,
Expand Down Expand Up @@ -579,7 +628,12 @@ fn connect(addr: &SocketAddr, local_addr: &Option<IpAddr>, handle: &Option<Handl
let std_tcp = builder.to_tcp_stream()?;

Ok(Box::pin(async move {
TcpStream::connect_std(std_tcp, &addr, &handle).await
let stream = TcpStream::connect_std(std_tcp, &addr, &handle);
if let Some(timeout) = connect_timeout {
Timeout::new(stream, timeout).await?
} else {
stream.await
}
}))

//Ok(Box::pin(TcpStream::connect_std(std_tcp, addr, &handle)))
Expand All @@ -588,14 +642,32 @@ fn connect(addr: &SocketAddr, local_addr: &Option<IpAddr>, handle: &Option<Handl
impl ConnectingTcp {
fn poll(&mut self, cx: &mut task::Context<'_>, handle: &Option<Handle>) -> Poll<io::Result<TcpStream>> {
match self.fallback.take() {
None => self.preferred.poll(cx, &self.local_addr, handle, self.reuse_address),
Some(mut fallback) => match self.preferred.poll(cx, &self.local_addr, handle, self.reuse_address) {
None => self.preferred.poll(
cx,
&self.local_addr,
handle,
self.connect_timeout,
self.reuse_address,
),
Some(mut fallback) => match self.preferred.poll(
cx,
&self.local_addr,
handle,
self.connect_timeout,
self.reuse_address,
) {
Poll::Ready(Ok(stream)) => {
// Preferred successful - drop fallback.
Poll::Ready(Ok(stream))
}
Poll::Pending => match Pin::new(&mut fallback.delay).poll(cx) {
Poll::Ready(()) => match fallback.remote.poll(cx, &self.local_addr, handle, self.reuse_address) {
Poll::Ready(()) => match fallback.remote.poll(
cx,
&self.local_addr,
handle,
self.connect_timeout,
self.reuse_address,
) {
Poll::Ready(Ok(stream)) => {
// Fallback successful - drop current preferred,
// but keep fallback as new preferred.
Expand All @@ -621,7 +693,13 @@ impl ConnectingTcp {
Poll::Ready(Err(_)) => {
// Preferred failed - use fallback as new preferred.
self.preferred = fallback.remote;
self.preferred.poll(cx, &self.local_addr, handle, self.reuse_address)
self.preferred.poll(
cx,
&self.local_addr,
handle,
self.connect_timeout,
self.reuse_address,
)
}
}
}
Expand All @@ -631,6 +709,7 @@ impl ConnectingTcp {
#[cfg(test)]
mod tests {
use std::io;
use std::time::{Duration, Instant};

use tokio::runtime::current_thread::Runtime;
use tokio_net::driver::Handle;
Expand Down Expand Up @@ -689,13 +768,46 @@ mod tests {
});
}

#[test]
fn test_connect_timeout() {
use std::io::ErrorKind;

let mut rt = Runtime::new().unwrap();

// 240.0.0.1 is reserved in IPv4 and is effectively a black hole.
// 100::1 is an official IPv6 black hole but Travis CI has no IPv6.
let uri = "http://240.0.0.1/".parse().unwrap();
let dst = Destination {
uri,
};

let mut connector = HttpConnector::new();
let timeout = Duration::from_millis(1000);
connector.set_connect_timeout(Some(timeout));

let start = Instant::now();
let res = rt.block_on(connector.connect(dst));
let duration = start.elapsed();

match res {
Ok(_) => panic!("Request succeeded but should have timed out"),
Err(error) => if ErrorKind::TimedOut == error.kind() {
// Allow actual duration to be +/- 150ms off.
let allowance = Duration::from_millis(150);
assert!(duration >= timeout - allowance);
assert!(duration <= timeout + allowance);
} else {
panic!("{:?}", error);
},
}
}

#[test]
#[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
fn client_happy_eyeballs() {
use std::future::Future;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
use std::task::Poll;
use std::time::{Duration, Instant};

use tokio::runtime::current_thread::Runtime;

Expand Down Expand Up @@ -763,7 +875,7 @@ mod tests {
}

let addrs = hosts.iter().map(|host| (host.clone(), addr.port()).into()).collect();
let connecting_tcp = ConnectingTcp::new(None, dns::IpAddrs::new(addrs), Some(fallback_timeout), false);
let connecting_tcp = ConnectingTcp::new(None, dns::IpAddrs::new(addrs), None, Some(fallback_timeout), false);
let fut = ConnectingTcpFuture(connecting_tcp);

let start = Instant::now();
Expand Down