Skip to content

Commit

Permalink
feat!: socket2 TCP socket configuration (p.2)
Browse files Browse the repository at this point in the history
With default feature socket2 and the so called crate it is possible to create configurable TCP sockets. Needs MSRV 1.63.

Fixes issue tiny-http#143

There is a the new field socket_config: connection::SocketConfig in ServerConfig.
Call Server::new to create a server with your own config.

The defaults are...
keep_alive: true
no_delay: true
read_timeout: 10s
tcp_keepalive_interval: None
tcp_keepalive_time: 5s
write_timeout: 10s

README.md/Cargo.toml is going back to MSRV 1.60 and describes which feature needs a newer/higher Rust version.
  • Loading branch information
kolbma committed Jan 17, 2024
1 parent 570d7d1 commit 2989a9c
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 48 deletions.
64 changes: 47 additions & 17 deletions src/connection.rs
Expand Up @@ -8,20 +8,44 @@ use std::{
};

/// Unified listener. Either a [`TcpListener`] or [`std::os::unix::net::UnixListener`]
#[allow(missing_debug_implementations)]
pub enum Listener {
#[cfg(feature = "socket2")]
Tcp(TcpListener, SocketConfig),
#[cfg(not(feature = "socket2"))]
Tcp(TcpListener),
#[cfg(unix)]
Unix(unix_net::UnixListener),
}
impl Listener {
pub(crate) fn local_addr(&self) -> std::io::Result<ListenAddr> {
match self {
#[cfg(feature = "socket2")]
Self::Tcp(l, _cfg) => l.local_addr().map(ListenAddr::from),
#[cfg(not(feature = "socket2"))]
Self::Tcp(l) => l.local_addr().map(ListenAddr::from),
#[cfg(unix)]
Self::Unix(l) => l.local_addr().map(ListenAddr::from),
}
}

#[cfg(feature = "socket2")]
pub(crate) fn accept(&self) -> std::io::Result<(Connection, Option<SocketAddr>)> {
use log::error;

match self {
Self::Tcp(l, cfg) => l.accept().map(|(mut conn, addr)| {
if let Err(err) = set_socket_cfg(&mut conn, cfg) {
error!("socket config fail: {err:?}");
}
(Connection::from(conn), Some(addr))
}),
#[cfg(unix)]
Self::Unix(l) => l.accept().map(|(conn, _)| (Connection::from(conn), None)),
}
}

#[cfg(not(feature = "socket2"))]
pub(crate) fn accept(&self) -> std::io::Result<(Connection, Option<SocketAddr>)> {
match self {
Self::Tcp(l) => l
Expand All @@ -32,6 +56,13 @@ impl Listener {
}
}
}
#[cfg(feature = "socket2")]
impl From<(TcpListener, SocketConfig)> for Listener {
fn from((s, cfg): (TcpListener, SocketConfig)) -> Self {
Self::Tcp(s, cfg)
}
}
#[cfg(not(feature = "socket2"))]
impl From<TcpListener> for Listener {
fn from(s: TcpListener) -> Self {
Self::Tcp(s)
Expand All @@ -44,6 +75,19 @@ impl From<unix_net::UnixListener> for Listener {
}
}

#[cfg(feature = "socket2")]
#[inline]
fn set_socket_cfg(socket: &mut TcpStream, config: &SocketConfig) -> Result<(), std::io::Error> {
socket.set_nodelay(config.no_delay)?;
if !config.read_timeout.is_zero() {
socket.set_read_timeout(Some(config.read_timeout))?;
}
if !config.write_timeout.is_zero() {
socket.set_write_timeout(Some(config.write_timeout))?;
}
Ok(())
}

/// Unified connection. Either a [`TcpStream`] or [`std::os::unix::net::UnixStream`].
#[derive(Debug)]
pub(crate) enum Connection {
Expand Down Expand Up @@ -117,27 +161,16 @@ impl From<unix_net::UnixStream> for Connection {

#[derive(Debug, Clone)]
pub enum ConfigListenAddr {
#[cfg(not(feature = "socket2"))]
IP(Vec<SocketAddr>),
#[cfg(feature = "socket2")]
IP((Vec<SocketAddr>, Option<SocketConfig>)),
#[cfg(unix)]
// TODO: use SocketAddr when bind_addr is stabilized
Unix(std::path::PathBuf),
Unix(PathBuf),
}
impl ConfigListenAddr {
#[cfg(not(feature = "socket2"))]
pub fn from_socket_addrs<A: ToSocketAddrs>(addrs: A) -> std::io::Result<Self> {
addrs.to_socket_addrs().map(|it| Self::IP(it.collect()))
}

#[cfg(feature = "socket2")]
pub fn from_socket_addrs<A: ToSocketAddrs>(addrs: A) -> std::io::Result<Self> {
addrs
.to_socket_addrs()
.map(|it| Self::IP((it.collect(), None)))
}

#[cfg(unix)]
pub fn unix_from_path<P: Into<PathBuf>>(path: P) -> Self {
Self::Unix(path.into())
Expand All @@ -147,7 +180,7 @@ impl ConfigListenAddr {
pub(crate) fn bind(&self, config: &SocketConfig) -> std::io::Result<Listener> {
match self {
Self::IP(ip) => {
let addresses = &ip.0;
let addresses = ip;
let mut err = None;
let mut socket =
socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None)?;
Expand Down Expand Up @@ -176,8 +209,6 @@ impl ConfigListenAddr {
}

socket.set_keepalive(config.keep_alive)?;
socket.set_nodelay(config.no_delay)?;
socket.set_read_timeout(Some(config.read_timeout))?;
socket.set_tcp_keepalive(&if let Some(tcp_keepalive_interval) =
config.tcp_keepalive_interval
{
Expand All @@ -187,9 +218,8 @@ impl ConfigListenAddr {
} else {
socket2::TcpKeepalive::new().with_time(config.tcp_keepalive_time)
})?;
socket.set_write_timeout(Some(config.write_timeout))?;

Ok(Listener::Tcp(socket.into()))
Ok(Listener::Tcp(socket.into(), config.clone()))
}
#[cfg(unix)]
Self::Unix(path) => unix_net::UnixListener::bind(path).map(Listener::from),
Expand Down
61 changes: 40 additions & 21 deletions src/lib.rs
Expand Up @@ -87,6 +87,22 @@
//! # let response = tiny_http::Response::from_file(File::open(&Path::new("image.png")).unwrap());
//! let _ = request.respond(response);
//! ```
// #![warn(clippy::pedantic)]
#![warn(
missing_debug_implementations,
// missing_docs,
non_ascii_idents,
rust_2018_compatibility,
trivial_casts,
trivial_numeric_casts,
// unreachable_pub,
unsafe_code,
// unused_crate_dependencies,
unused_extern_crates,
unused_import_braces,
unused_qualifications,
// unused_results
)]
#![forbid(unsafe_code)]
#![deny(rust_2018_idioms)]

Expand All @@ -98,20 +114,20 @@
use zeroize::Zeroizing;

use std::error::Error;
use std::io::Error as IoError;
use std::io::ErrorKind as IoErrorKind;
use std::io::Result as IoResult;
use std::net::{Shutdown, TcpStream, ToSocketAddrs};
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::mpsc;
use std::sync::Arc;
use std::thread;
use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
use std::time::Duration;
use std::{
io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult},
net::{Shutdown, TcpStream, ToSocketAddrs},
};
use std::{
sync::{mpsc, Arc},
thread,
};

use client::ClientConnection;
use connection::Connection;
use util::MessagesQueue;
use util::{MessagesQueue, RefinedTcpStream};

pub use common::{HTTPVersion, Header, HeaderField, Method, StatusCode};
pub use connection::{ConfigListenAddr, ListenAddr, Listener, SocketConfig};
Expand All @@ -134,6 +150,7 @@ mod util;
/// Destroying this object will immediately close the listening socket and the reading
/// part of all the client's connections. Requests that have already been returned by
/// the `recv()` function will not close and the responses will be transferred to the client.
#[allow(missing_debug_implementations)]
pub struct Server {
// should be false as long as the server exists
// when set to true, all the subtasks will close within a few hundreds ms
Expand Down Expand Up @@ -165,14 +182,24 @@ impl From<Request> for Message {

// this trait is to make sure that Server implements Share and Send
#[doc(hidden)]
trait MustBeShareDummy: Sync + Send {}
trait SyncSendT: Sync + Send {}
#[doc(hidden)]
impl MustBeShareDummy for Server {}
impl SyncSendT for Server {}

/// Iterator over received [Request] from [Server]
#[allow(missing_debug_implementations)]
pub struct IncomingRequests<'a> {
server: &'a Server,
}

impl Iterator for IncomingRequests<'_> {
type Item = Request;

fn next(&mut self) -> Option<Request> {
self.server.recv().ok()
}
}

/// Represents the parameters required to create a server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
Expand All @@ -182,7 +209,7 @@ pub struct ServerConfig {
/// Socket configuration with _socket2_ feature
/// See [SocketConfig]
#[cfg(feature = "socket2")]
pub socket_config: connection::SocketConfig,
pub socket_config: SocketConfig,

/// If `Some`, then the server will use SSL to encode the communications.
pub ssl: Option<SslConfig>,
Expand Down Expand Up @@ -335,7 +362,6 @@ impl Server {
while !inside_close_trigger.load(Relaxed) {
let new_client = match server.accept() {
Ok((sock, _)) => {
use util::RefinedTcpStream;
let (read_closable, write_closable) = match ssl {
None => RefinedTcpStream::new(sock),
#[cfg(any(
Expand Down Expand Up @@ -461,13 +487,6 @@ impl Server {
}
}

impl Iterator for IncomingRequests<'_> {
type Item = Request;
fn next(&mut self) -> Option<Request> {
self.server.recv().ok()
}
}

impl Drop for Server {
fn drop(&mut self) {
self.close.store(true, Relaxed);
Expand Down
81 changes: 71 additions & 10 deletions tests/network.rs
Expand Up @@ -187,12 +187,14 @@ fn connection_timeout() -> Result<(), std::io::Error> {
use std::time::{Duration, Instant};
use tiny_http::ServerConfig;

let now = Instant::now();

let (server, mut client) = {
let server = tiny_http::Server::new(ServerConfig {
addr: tiny_http::ConfigListenAddr::from_socket_addrs("0.0.0.0:0")?,
socket_config: tiny_http::SocketConfig {
read_timeout: Duration::from_millis(500),
write_timeout: Duration::from_millis(500),
read_timeout: Duration::from_millis(100),
write_timeout: Duration::from_millis(100),
..tiny_http::SocketConfig::default()
},
ssl: None,
Expand All @@ -203,28 +205,87 @@ fn connection_timeout() -> Result<(), std::io::Error> {
(server, client)
};

thread::spawn(move || {
let rq = server.recv_timeout(Duration::from_secs(300));
assert!(rq.is_ok(), "req fail: {}", rq.unwrap_err());

let rq = rq.unwrap();
assert!(rq.is_some());
let rq = rq.unwrap();

let resp = tiny_http::Response::empty(tiny_http::StatusCode(204));
rq.respond(resp).unwrap();
});

write!(client, "GET / HTTP/1.1\r\n\r\n")?;

let mut content = String::new();
client.read_to_string(&mut content).unwrap();
assert!(content.starts_with("HTTP/1.1 204"));

thread::sleep(Duration::from_millis(200));

let err = write!(
client,
"GET / HTTP/1.1\r\nHost: localhost\r\nTE: chunked\r\nConnection: close\r\n\r\n"
);
assert!(err.is_ok());

let elaps = now.elapsed();
assert!(
elaps > Duration::from_millis(230) && elaps < Duration::from_millis(320),
"elaps: {}",
elaps.as_millis()
);

Ok(())
}

#[test]
#[cfg(feature = "socket2")]
fn connection_timeout_wait_check() -> Result<(), std::io::Error> {
use std::time::{Duration, Instant};
use tiny_http::ServerConfig;

let now = Instant::now();

write!(client, "GET / HTTP/1.").unwrap();
let (server, mut client) = {
let server = tiny_http::Server::new(ServerConfig {
addr: tiny_http::ConfigListenAddr::from_socket_addrs("0.0.0.0:0")?,
socket_config: tiny_http::SocketConfig {
read_timeout: Duration::from_millis(250),
write_timeout: Duration::from_millis(250),
..tiny_http::SocketConfig::default()
},
ssl: None,
})
.unwrap();
let port = server.server_addr().to_ip().unwrap().port();
let client = TcpStream::connect(("127.0.0.1", port)).unwrap();
(server, client)
};

let h = thread::spawn(move || {
let rq = server.recv_timeout(Duration::from_secs(2));
thread::spawn(move || {
let rq = server.recv_timeout(Duration::from_secs(300));
assert!(rq.is_err());
});

// thread::sleep(Duration::from_millis(3000));
// make sure it is waiting longer than server timeouts
thread::sleep(Duration::from_millis(300));

let _ = h.join();
let err = write!(
client,
"GET / HTTP/1.1\r\nHost: localhost\r\nTE: chunked\r\nConnection: close\r\n\r\n"
);
assert!(err.is_ok());

let elaps = now.elapsed();
assert!(
elaps > Duration::from_millis(490) && elaps < Duration::from_millis(540),
elaps > Duration::from_millis(300) && elaps < Duration::from_millis(330),
"elaps: {}",
elaps.as_millis()
);

drop(client);

Ok(())
}

Expand Down

0 comments on commit 2989a9c

Please sign in to comment.