From 8fcb8077c079a461d14035b78c00572eadb8cc01 Mon Sep 17 00:00:00 2001 From: Hanson Char Date: Wed, 21 Sep 2022 12:13:52 -0700 Subject: [PATCH] Supports configuration of all TCP Keepalive parameters per https://en.wikipedia.org/wiki/Keepalive https://github.com/hyperium/hyper/pull/2991 --- src/server/server.rs | 23 ++++++-- src/server/tcp.rs | 135 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 140 insertions(+), 18 deletions(-) diff --git a/src/server/server.rs b/src/server/server.rs index e3058da4bb..caa4aebd14 100644 --- a/src/server/server.rs +++ b/src/server/server.rs @@ -2,10 +2,12 @@ use std::error::Error as StdError; use std::fmt; #[cfg(feature = "tcp")] use std::net::{SocketAddr, TcpListener as StdTcpListener}; -#[cfg(any(feature = "tcp", feature = "http1"))] + +#[cfg(feature = "tcp")] use std::time::Duration; use pin_project_lite::pin_project; + use tokio::io::{AsyncRead, AsyncWrite}; use tracing::trace; @@ -559,16 +561,27 @@ impl Builder { doc(cfg(all(feature = "tcp", any(feature = "http1", feature = "http2")))) )] impl Builder { - /// Set whether TCP keepalive messages are enabled on accepted connections. + /// Set the duration to remain idle before sending TCP keepalive probes. /// - /// If `None` is specified, keepalive is disabled, otherwise the duration - /// specified will be the time to remain idle before sending TCP keepalive - /// probes. + /// If `None` is specified, keepalive is disabled. pub fn tcp_keepalive(mut self, keepalive: Option) -> Self { self.incoming.set_keepalive(keepalive); self } + /// Set the duration between two successive TCP keepalive retransmissions, + /// if acknowledgement to the previous keepalive transmission is not received. + pub fn tcp_keepalive_interval(mut self, interval: Option) -> Self { + self.incoming.set_keepalive_interval(interval); + self + } + + /// Set the number of retransmissions to be carried out before declaring that remote end is not available. + pub fn tcp_keepalive_retries(mut self, retries: Option) -> Self { + self.incoming.set_keepalive_retries(retries); + self + } + /// Set the value of `TCP_NODELAY` option for accepted connections. pub fn tcp_nodelay(mut self, enabled: bool) -> Self { self.incoming.set_nodelay(enabled); diff --git a/src/server/tcp.rs b/src/server/tcp.rs index 7e70ce3ac3..f475c09b1c 100644 --- a/src/server/tcp.rs +++ b/src/server/tcp.rs @@ -2,6 +2,7 @@ use std::fmt; use std::io; use std::net::{SocketAddr, TcpListener as StdTcpListener}; use std::time::Duration; +use socket2::TcpKeepalive; use tokio::net::TcpListener; use tokio::time::Sleep; @@ -13,13 +14,65 @@ use crate::common::{task, Future, Pin, Poll}; pub use self::addr_stream::AddrStream; use super::accept::Accept; +#[derive(Default, Debug, Clone, Copy)] +struct TcpKeepaliveConfig { + time: Option, + interval: Option, + retries: Option, +} + +impl TcpKeepaliveConfig { + /// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration. + fn into_socket2(self) -> Option { + let mut dirty = false; + let mut ka = TcpKeepalive::new(); + if let Some(time) = self.time { + ka = ka.with_time(time); + dirty = true + } + if let Some(interval) = self.interval { + ka = Self::ka_with_interval(ka, interval, &mut dirty) + }; + if let Some(retries) = self.retries { + ka = Self::ka_with_retries(ka, retries, &mut dirty) + }; + if dirty { + Some(ka) + } else { + None + } + } + + #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))] + fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive { + *dirty = true; + ka.with_interval(interval) + } + + #[cfg(any(target_os = "openbsd", target_os = "redox", target_os = "solaris"))] + fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive { + ka // no-op as keepalive interval is not supported on this platform + } + + #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris", target_os = "windows")))] + fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive { + *dirty = true; + ka.with_retries(retries) + } + + #[cfg(any(target_os = "openbsd", target_os = "redox", target_os = "solaris", target_os = "windows"))] + fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive { + ka // no-op as keepalive retries is not supported on this platform + } +} + /// A stream of connections from binding to an address. #[must_use = "streams do nothing unless polled"] pub struct AddrIncoming { addr: SocketAddr, listener: TcpListener, sleep_on_errors: bool, - tcp_keepalive_timeout: Option, + tcp_keepalive_config: TcpKeepaliveConfig, tcp_nodelay: bool, timeout: Option>>, } @@ -52,7 +105,7 @@ impl AddrIncoming { listener, addr, sleep_on_errors: true, - tcp_keepalive_timeout: None, + tcp_keepalive_config: TcpKeepaliveConfig::default(), tcp_nodelay: false, timeout: None, }) @@ -63,13 +116,24 @@ impl AddrIncoming { self.addr } - /// Set whether TCP keepalive messages are enabled on accepted connections. + /// Set the duration to remain idle before sending TCP keepalive probes. /// - /// If `None` is specified, keepalive is disabled, otherwise the duration - /// specified will be the time to remain idle before sending TCP keepalive - /// probes. - pub fn set_keepalive(&mut self, keepalive: Option) -> &mut Self { - self.tcp_keepalive_timeout = keepalive; + /// If `None` is specified, keepalive is disabled. + pub fn set_keepalive(&mut self, time: Option) -> &mut Self { + self.tcp_keepalive_config.time = time; + self + } + + /// Set the duration between two successive TCP keepalive retransmissions, + /// if acknowledgement to the previous keepalive transmission is not received. + pub fn set_keepalive_interval(&mut self, interval: Option) -> &mut Self { + self.tcp_keepalive_config.interval = interval; + self + } + + /// Set the number of retransmissions to be carried out before declaring that remote end is not available. + pub fn set_keepalive_retries(&mut self, retries: Option) -> &mut Self { + self.tcp_keepalive_config.retries = retries; self } @@ -108,10 +172,9 @@ impl AddrIncoming { loop { match ready!(self.listener.poll_accept(cx)) { Ok((socket, remote_addr)) => { - if let Some(dur) = self.tcp_keepalive_timeout { - let socket = socket2::SockRef::from(&socket); - let conf = socket2::TcpKeepalive::new().with_time(dur); - if let Err(e) = socket.set_tcp_keepalive(&conf) { + if let Some(tcp_keepalive) = &self.tcp_keepalive_config.into_socket2() { + let sock_ref = socket2::SockRef::from(&socket); + if let Err(e) = sock_ref.set_tcp_keepalive(tcp_keepalive) { trace!("error trying to set TCP keepalive: {}", e); } } @@ -188,7 +251,7 @@ impl fmt::Debug for AddrIncoming { f.debug_struct("AddrIncoming") .field("addr", &self.addr) .field("sleep_on_errors", &self.sleep_on_errors) - .field("tcp_keepalive_timeout", &self.tcp_keepalive_timeout) + .field("tcp_keepalive_config", &self.tcp_keepalive_config) .field("tcp_nodelay", &self.tcp_nodelay) .finish() } @@ -316,3 +379,49 @@ mod addr_stream { } } } + +#[cfg(test)] +mod tests { + use std::time::Duration; + use crate::server::tcp::TcpKeepaliveConfig; + + #[test] + fn no_tcp_keepalive_config() { + assert!(TcpKeepaliveConfig::default().into_socket2().is_none()); + } + + #[test] + fn tcp_keepalive_time_config() { + let mut kac = TcpKeepaliveConfig::default(); + kac.time = Some(Duration::from_secs(60)); + if let Some(tcp_keepalive) = kac.into_socket2() { + assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)")); + } else { + panic!("test failed"); + } + } + + #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))] + #[test] + fn tcp_keepalive_interval_config() { + let mut kac = TcpKeepaliveConfig::default(); + kac.interval = Some(Duration::from_secs(1)); + if let Some(tcp_keepalive) = kac.into_socket2() { + assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)")); + } else { + panic!("test failed"); + } + } + + #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris", target_os = "windows")))] + #[test] + fn tcp_keepalive_retries_config() { + let mut kac = TcpKeepaliveConfig::default(); + kac.retries = Some(3); + if let Some(tcp_keepalive) = kac.into_socket2() { + assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)")); + } else { + panic!("test failed"); + } + } +}