From 4968af384930c6c1156743d323c6c491f3b7f69b Mon Sep 17 00:00:00 2001 From: Simon THOBY Date: Thu, 25 Aug 2022 15:57:29 +0200 Subject: [PATCH] Add support for unbuffered Response This is achieved through two different means: - The ability to disable buffering inside `Response` objects with the `Response::with_buffering` method. Selecting `BufferingMode::Unbuffered` will force the transfer encoding to be `TransferEncoding::Chunked` and will configure the chunks encoder to flush to its underlying writer on every write. - To get "instantaneous" write, disabling buffering in the chunks encoder is not enough, as the underlying writer returned when calling `Server::recv()` (`ClientConnection.sink`) is in fact a `BufWriter` wrapping the "real" output. The `writer_buffering` parameter in `ServerConfig.advanced` can alter the server behavior to omit the BufWriter when writing to the TcpStream. The cost of that abstraction is that `ClientConnection.sink` now boxes the writer to be able to choose between `BufWriter` and `RefinedTcpStream` dynamically, which means there is now one additional pointer dereference. However, this pointer is then stored in an `Arc<Mutex<_>>`, and locking/unlocking the mutex is probably more expensive than dereferencing the pointer. This will probably decrease performance significantly when sending big files, which is why these two subfeatures are disabled by default, and must be opted into (by calling the `with_buffering` method for the first, and by instantiating the server with the `with_writer_buffering_mode` method for the second). 
--- src/client.rs | 14 ++++-- src/lib.rs | 57 ++++++++++++++++++++-- src/response.rs | 28 ++++++++++- src/util/buffering_wrapper.rs | 29 +++++++++++ src/util/mod.rs | 2 + tests/promptness.rs | 36 +++++--------- tests/simple-test.rs | 90 +++++++++++++++++++++++++++++++++-- tests/support/mod.rs | 51 ++++++++++++++++++-- 8 files changed, 268 insertions(+), 39 deletions(-) create mode 100644 src/util/buffering_wrapper.rs diff --git a/src/client.rs b/src/client.rs index 68bbd469..02d9f812 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,9 +8,10 @@ use std::net::SocketAddr; use std::str::FromStr; use crate::common::{HTTPVersion, Method}; +use crate::util::MaybeBufferedWriter; use crate::util::RefinedTcpStream; use crate::util::{SequentialReader, SequentialReaderBuilder, SequentialWriterBuilder}; -use crate::Request; +use crate::{BufferingMode, Request}; /// A ClientConnection is an object that will store a socket to a client /// and return Request objects. @@ -24,7 +25,7 @@ pub struct ClientConnection { // sequence of Writers to the stream, to avoid writing response #2 before // response #1 - sink: SequentialWriterBuilder>, + sink: SequentialWriterBuilder>, // Reader to read the next header from next_header_source: SequentialReader>, @@ -48,9 +49,12 @@ enum ReadError { impl ClientConnection { /// Creates a new `ClientConnection` that takes ownership of the `TcpStream`. + /// The `buffering_control` parameter allows for selecting the desired + /// buffering mode. 
pub fn new( write_socket: RefinedTcpStream, mut read_socket: RefinedTcpStream, + buffering_mode: BufferingMode, ) -> ClientConnection { let remote_addr = read_socket.peer_addr(); let secure = read_socket.secure(); @@ -60,7 +64,11 @@ impl ClientConnection { ClientConnection { source, - sink: SequentialWriterBuilder::new(BufWriter::with_capacity(1024, write_socket)), + sink: SequentialWriterBuilder::new(if let BufferingMode::Buffered = buffering_mode { + MaybeBufferedWriter::Buffered(BufWriter::with_capacity(1024, write_socket)) + } else { + MaybeBufferedWriter::Unbuffered(write_socket) + }), remote_addr, next_header_source: first_header, no_more_requests: false, diff --git a/src/lib.rs b/src/lib.rs index a04116da..a6b35ce6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -170,6 +170,44 @@ pub struct IncomingRequests<'a> { server: &'a Server, } +/// Control buffering in a data stream. +#[derive(Debug, Copy, Clone)] +pub enum BufferingMode { + /// Buffer the messages. This is the default. + Buffered, + /// Do not buffer the messages, this prevent caching but will probably incur a performance cost. + Unbuffered, +} + +impl Default for BufferingMode { + fn default() -> Self { + BufferingMode::Buffered + } +} + +/// Advanced server settings. +/// In order to retain the ability to add options later while preserving the +/// API, this is an "opaque" struct, to be manipulated through the `new` and +/// `with_*` configuration methods. +#[derive(Default, Debug, Clone)] +pub struct ServerConfigAdvanced { + /// Control buffering on the server->client path. The default value is + /// `BufferingMode::Buffered`. + writer_buffering: BufferingMode, +} + +impl ServerConfigAdvanced { + pub fn new() -> Self { + Self::default() + } + + /// Change the buffering mode of the writer returned with each message. + pub fn with_writer_buffering_mode(mut self, mode: BufferingMode) -> Self { + self.writer_buffering = mode; + self + } +} + /// Represents the parameters required to create a server. 
#[derive(Debug, Clone)] pub struct ServerConfig { @@ -178,6 +216,9 @@ pub struct ServerConfig { /// If `Some`, then the server will use SSL to encode the communications. pub ssl: Option, + + /// Advanced server settings. + pub advanced: ServerConfigAdvanced, } /// Configuration of the server for SSL. @@ -199,6 +240,7 @@ impl Server { Server::new(ServerConfig { addr: ConfigListenAddr::from_socket_addrs(addr)?, ssl: None, + advanced: ServerConfigAdvanced::new(), }) } @@ -215,6 +257,7 @@ impl Server { Server::new(ServerConfig { addr: ConfigListenAddr::from_socket_addrs(addr)?, ssl: Some(config), + advanced: ServerConfigAdvanced::new(), }) } @@ -227,13 +270,14 @@ impl Server { Server::new(ServerConfig { addr: ConfigListenAddr::unix_from_path(path), ssl: None, + advanced: ServerConfigAdvanced::new(), }) } /// Builds a new server that listens on the specified address. pub fn new(config: ServerConfig) -> Result> { let listener = config.addr.bind()?; - Self::from_listener(listener, config.ssl) + Self::from_listener(listener, config) } /// Builds a new server using the specified TCP listener. @@ -242,7 +286,7 @@ impl Server { /// such as from systemd. For other cases, you probably want the `new()` function. 
pub fn from_listener>( listener: L, - ssl_config: Option, + config: ServerConfig, ) -> Result> { let listener = listener.into(); // building the "close" variable @@ -265,7 +309,7 @@ impl Server { #[cfg(any(feature = "ssl-openssl", feature = "ssl-rustls"))] type SslContext = crate::ssl::SslContextImpl; let ssl: Option = { - match ssl_config { + match config.ssl { #[cfg(any(feature = "ssl-openssl", feature = "ssl-rustls"))] Some(config) => Some(SslContext::from_pem( config.certificate, @@ -286,6 +330,7 @@ impl Server { let inside_close_trigger = close_trigger.clone(); let inside_messages = messages.clone(); + let writer_buffering = config.advanced.writer_buffering; thread::spawn(move || { // a tasks pool is used to dispatch the connections into threads let tasks_pool = util::TaskPool::new(); @@ -312,7 +357,11 @@ impl Server { Some(ref _ssl) => unreachable!(), }; - Ok(ClientConnection::new(write_closable, read_closable)) + Ok(ClientConnection::new( + write_closable, + read_closable, + writer_buffering, + )) } Err(e) => Err(e), }; diff --git a/src/response.rs b/src/response.rs index aaedf5c7..3cda2341 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,4 +1,5 @@ use crate::common::{HTTPVersion, Header, StatusCode}; +use crate::BufferingMode; use httpdate::HttpDate; use std::cmp::Ordering; use std::sync::mpsc::Receiver; @@ -43,6 +44,7 @@ pub struct Response { headers: Vec
, data_length: Option, chunked_threshold: Option, + buffering: BufferingMode, } /// A `Response` without a template parameter. @@ -116,6 +118,7 @@ fn choose_transfer_encoding( entity_length: &Option, has_additional_headers: bool, chunked_threshold: usize, + buffering: BufferingMode, ) -> TransferEncoding { use crate::util; @@ -166,6 +169,12 @@ fn choose_transfer_encoding( return user_request; } + // unbuffered messages must use chunked transfer encoding, to send messages as they + // are produced + if let BufferingMode::Unbuffered = buffering { + return TransferEncoding::Chunked; + } + // if we have additional headers, using chunked if has_additional_headers { return TransferEncoding::Chunked; @@ -206,6 +215,7 @@ where headers: Vec::with_capacity(16), data_length, chunked_threshold: None, + buffering: BufferingMode::Buffered, }; for h in headers { @@ -231,6 +241,14 @@ where self } + /// Define if the output should be buffered. If the output in unbuffered, + /// every write to the socket will be followed by a flush, and the output + /// will be chunked. + pub fn with_buffering(mut self, buffering: BufferingMode) -> Self { + self.buffering = buffering; + self + } + /// Convert the response into the underlying `Read` type. /// /// This is mainly useful for testing as it must consume the `Response`. 
@@ -319,6 +337,7 @@ where status_code: self.status_code, data_length, chunked_threshold: self.chunked_threshold, + buffering: BufferingMode::Buffered, } } @@ -346,6 +365,7 @@ where &self.data_length, false, /* TODO */ self.chunked_threshold(), + self.buffering, )); // add `Date` if not in the headers @@ -433,7 +453,11 @@ where Some(TransferEncoding::Chunked) => { use chunked_transfer::Encoder; - let mut writer = Encoder::new(writer); + let mut writer = if let BufferingMode::Unbuffered = self.buffering { + Encoder::with_flush_after_write(writer) + } else { + Encoder::new(writer) + }; io::copy(&mut reader, &mut writer)?; } @@ -481,6 +505,7 @@ where headers: self.headers, data_length: self.data_length, chunked_threshold: self.chunked_threshold, + buffering: self.buffering, } } } @@ -569,6 +594,7 @@ impl Clone for Response { headers: self.headers.clone(), data_length: self.data_length, chunked_threshold: self.chunked_threshold, + buffering: self.buffering, } } } diff --git a/src/util/buffering_wrapper.rs b/src/util/buffering_wrapper.rs new file mode 100644 index 00000000..45eb6229 --- /dev/null +++ b/src/util/buffering_wrapper.rs @@ -0,0 +1,29 @@ +use std::io::{BufWriter, Result as IoResult, Write}; + +pub enum MaybeBufferedWriter { + Buffered(BufWriter), + Unbuffered(W), +} + +impl Write for MaybeBufferedWriter { + fn write(&mut self, buf: &[u8]) -> IoResult { + match self { + MaybeBufferedWriter::Buffered(w) => w.write(buf), + MaybeBufferedWriter::Unbuffered(w) => w.write(buf), + } + } + + fn write_all(&mut self, buf: &[u8]) -> IoResult<()> { + match self { + MaybeBufferedWriter::Buffered(w) => w.write_all(buf), + MaybeBufferedWriter::Unbuffered(w) => w.write_all(buf), + } + } + + fn flush(&mut self) -> IoResult<()> { + match self { + MaybeBufferedWriter::Buffered(w) => w.flush(), + MaybeBufferedWriter::Unbuffered(w) => w.flush(), + } + } +} diff --git a/src/util/mod.rs b/src/util/mod.rs index 4fb2aca5..0a486e5b 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ 
-1,3 +1,4 @@ +pub use self::buffering_wrapper::MaybeBufferedWriter; pub use self::custom_stream::CustomStream; pub use self::equal_reader::EqualReader; pub use self::fused_reader::FusedReader; @@ -9,6 +10,7 @@ pub use self::task_pool::TaskPool; use std::str::FromStr; +mod buffering_wrapper; mod custom_stream; mod equal_reader; mod fused_reader; diff --git a/tests/promptness.rs b/tests/promptness.rs index a621d3e5..dc0845d6 100644 --- a/tests/promptness.rs +++ b/tests/promptness.rs @@ -1,31 +1,17 @@ extern crate tiny_http; use std::io::{copy, Read, Write}; -use std::net::{Shutdown, TcpStream}; +use std::net::Shutdown; use std::ops::Deref; use std::sync::mpsc::channel; use std::sync::Arc; -use std::thread::{sleep, spawn}; +use std::thread::spawn; use std::time::Duration; -use tiny_http::{Response, Server}; +use tiny_http::Response; -/// Stream that produces bytes very slowly -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -struct SlowByteSrc { - val: u8, - len: usize, -} -impl<'b> Read for SlowByteSrc { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - sleep(Duration::from_millis(100)); - let l = self.len.min(buf.len()).min(1000); - for v in buf[..l].iter_mut() { - *v = self.val; - } - self.len -= l; - Ok(l) - } -} +#[allow(dead_code)] +mod support; +use support::{new_one_server_one_client, SlowByteSrc}; /// crude impl of http `Transfer-Encoding: chunked` fn encode_chunked(data: &mut dyn Read, output: &mut dyn Write) { @@ -42,6 +28,8 @@ fn encode_chunked(data: &mut dyn Read, output: &mut dyn Write) { } mod prompt_pipelining { + use std::time::Duration; + use super::*; /// Check that pipelined requests on the same connection are received promptly. 
@@ -52,12 +40,12 @@ mod prompt_pipelining { req_writer: impl FnOnce(&mut dyn Write) + Send + 'static, ) { let resp_body = SlowByteSrc { + sleep_time: Duration::from_millis(100), val: 42, len: 1000_000, }; // very slow response body - let server = Server::http("0.0.0.0:0").unwrap(); - let mut client = TcpStream::connect(server.server_addr().to_ip().unwrap()).unwrap(); + let (server, mut client) = new_one_server_one_client(); let (svr_send, svr_rcv) = channel(); spawn(move || { @@ -142,8 +130,7 @@ mod prompt_responses { timeout: Duration, req_writer: impl FnOnce(&mut dyn Write) + Send + 'static, ) { - let server = Server::http("0.0.0.0:0").unwrap(); - let client = TcpStream::connect(server.server_addr().to_ip().unwrap()).unwrap(); + let (server, client) = new_one_server_one_client(); spawn(move || loop { // server attempts to respond immediately @@ -164,6 +151,7 @@ mod prompt_responses { } static SLOW_BODY: SlowByteSrc = SlowByteSrc { + sleep_time: Duration::from_millis(100), val: 65, len: 1000_000, }; diff --git a/tests/simple-test.rs b/tests/simple-test.rs index 4375109d..1e6a0dff 100644 --- a/tests/simple-test.rs +++ b/tests/simple-test.rs @@ -1,13 +1,18 @@ extern crate tiny_http; -use std::io::{Read, Write}; +use std::{ + io::{Read, Write}, + time::{Duration, Instant}, +}; #[allow(dead_code)] mod support; +use chunked_transfer::Decoder; +use support::{new_one_server_one_client, new_one_server_one_client_unbuffered, SlowByteSrc}; #[test] fn basic_handling() { - let (server, mut stream) = support::new_one_server_one_client(); + let (server, mut stream) = new_one_server_one_client(); write!( stream, "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" @@ -16,7 +21,7 @@ fn basic_handling() { let request = server.recv().unwrap(); assert!(*request.method() == tiny_http::Method::Get); - //assert!(request.url() == "/"); + assert!(request.url() == "/"); request .respond(tiny_http::Response::from_string("hello world".to_owned())) .unwrap(); @@ -27,3 +32,82 @@ 
fn basic_handling() { stream.read_to_string(&mut content).unwrap(); assert!(content.ends_with("hello world")); } + +#[test] +fn unbuffered() { + let (server, mut stream) = new_one_server_one_client_unbuffered(); + write!( + stream, + "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" + ) + .unwrap(); + + let request = server.recv().unwrap(); + let start = Instant::now(); + std::thread::spawn(|| { + request + .respond( + tiny_http::Response::new( + tiny_http::StatusCode(200), + Vec::new(), + SlowByteSrc { + sleep_time: Duration::from_millis(50), + val: 65, + // extreme length (100GB): we ensure that the data + // is streamed on demand instead of assembled in memory + len: 100_000_000_000, + }, + None, + None, + ) + .with_buffering(tiny_http::BufferingMode::Unbuffered), + ) + .unwrap() + }); + + assert!(server.try_recv().unwrap().is_none()); + + let mut buf = [0; 64 * 1024]; + let mut read = 0; + + loop { + let nb_read = stream.read(&mut buf[read..]).unwrap(); + + // we should receive some data, but only a small amount because of the lack + // of buffering + assert!(nb_read > 0); + assert!(nb_read < buf.len()); + + read += nb_read; + + // ensure that we receive the data quickly after the SlowByteSrc reader + // started feeeding the server + let elapsed = start.elapsed(); + if elapsed > Duration::from_millis(75) { + break; + } + } + + let res = String::from_utf8(buf[..read].to_vec()).expect("Invalid UTF8 characters"); + assert!(res.contains("Transfer-Encoding: chunked")); + + let chunked_index = res + .find("\r\n\r\n") + .expect("Could not find the start of the chunked messages"); + let mut chunked_data = res[chunked_index + 4..].to_string(); + // emit the "end of stream" message + chunked_data.push_str("0\r\n\r\n"); + + // verify that we only received '\x65' characters + let mut decoder = Decoder::new(chunked_data.as_bytes()); + let mut decoded = String::new(); + decoder + .read_to_string(&mut decoded) + .expect("Invalid (non-chunked?) 
data"); + let bytes = decoded.as_bytes(); + let mut expected_vec = Vec::with_capacity(bytes.len()); + for _ in 0..bytes.len() { + expected_vec.push(65); + } + assert!(bytes == expected_vec.as_slice()); +} diff --git a/tests/support/mod.rs b/tests/support/mod.rs index 7a4dc587..675d31e2 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -1,15 +1,38 @@ -use std::net::TcpStream; -use std::thread; +use std::thread::{self, sleep}; use std::time::Duration; +use std::{io::Read, net::TcpStream}; + +use tiny_http::{ConfigListenAddr, ServerConfig, ServerConfigAdvanced}; /// Creates a server and a client connected to the server. -pub fn new_one_server_one_client() -> (tiny_http::Server, TcpStream) { - let server = tiny_http::Server::http("0.0.0.0:0").unwrap(); +pub fn new_one_server_one_client_custom_config( + config: ServerConfig, +) -> (tiny_http::Server, TcpStream) { + let server = tiny_http::Server::new(config).unwrap(); let port = server.server_addr().to_ip().unwrap().port(); let client = TcpStream::connect(("127.0.0.1", port)).unwrap(); (server, client) } +/// Creates a server and a client connected to the server. +pub fn new_one_server_one_client() -> (tiny_http::Server, TcpStream) { + new_one_server_one_client_custom_config(ServerConfig { + addr: ConfigListenAddr::from_socket_addrs("0.0.0.0:0").unwrap(), + ssl: None, + advanced: ServerConfigAdvanced::new(), + }) +} + +/// Creates a server and a client connected to the server, with an unbuffered writer. +pub fn new_one_server_one_client_unbuffered() -> (tiny_http::Server, TcpStream) { + new_one_server_one_client_custom_config(ServerConfig { + addr: ConfigListenAddr::from_socket_addrs("0.0.0.0:0").unwrap(), + ssl: None, + advanced: ServerConfigAdvanced::new() + .with_writer_buffering_mode(tiny_http::BufferingMode::Unbuffered), + }) +} + /// Creates a "hello world" server with a client connected to the server. /// /// The server will automatically close after 3 seconds. 
@@ -38,3 +61,23 @@ pub fn new_client_to_hello_world_server() -> TcpStream { client } + +/// Stream that produces bytes very slowly +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct SlowByteSrc { + pub sleep_time: Duration, + pub val: u8, + pub len: usize, +} + +impl<'b> Read for SlowByteSrc { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + sleep(self.sleep_time); + let l = self.len.min(buf.len()).min(1000); + for v in buf[..l].iter_mut() { + *v = self.val; + } + self.len -= l; + Ok(l) + } +}