Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for unbuffered writes #229

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/client.rs
Expand Up @@ -8,9 +8,10 @@ use std::net::SocketAddr;
use std::str::FromStr;

use crate::common::{HTTPVersion, Method};
use crate::util::MaybeBufferedWriter;
use crate::util::RefinedTcpStream;
use crate::util::{SequentialReader, SequentialReaderBuilder, SequentialWriterBuilder};
use crate::Request;
use crate::{BufferingMode, Request};

/// A ClientConnection is an object that will store a socket to a client
/// and return Request objects.
Expand All @@ -24,7 +25,7 @@ pub struct ClientConnection {

// sequence of Writers to the stream, to avoid writing response #2 before
// response #1
sink: SequentialWriterBuilder<BufWriter<RefinedTcpStream>>,
sink: SequentialWriterBuilder<MaybeBufferedWriter<RefinedTcpStream>>,

// Reader to read the next header from
next_header_source: SequentialReader<BufReader<RefinedTcpStream>>,
Expand All @@ -48,9 +49,12 @@ enum ReadError {

impl ClientConnection {
/// Creates a new `ClientConnection` that takes ownership of the `TcpStream`.
/// The `buffering_control` parameter allows for selecting the desired
/// buffering mode.
pub fn new(
write_socket: RefinedTcpStream,
mut read_socket: RefinedTcpStream,
buffering_mode: BufferingMode,
) -> ClientConnection {
let remote_addr = read_socket.peer_addr();
let secure = read_socket.secure();
Expand All @@ -60,7 +64,11 @@ impl ClientConnection {

ClientConnection {
source,
sink: SequentialWriterBuilder::new(BufWriter::with_capacity(1024, write_socket)),
sink: SequentialWriterBuilder::new(if let BufferingMode::Buffered = buffering_mode {
MaybeBufferedWriter::Buffered(BufWriter::with_capacity(1024, write_socket))
} else {
MaybeBufferedWriter::Unbuffered(write_socket)
}),
remote_addr,
next_header_source: first_header,
no_more_requests: false,
Expand Down
57 changes: 53 additions & 4 deletions src/lib.rs
Expand Up @@ -170,6 +170,44 @@
server: &'a Server,
}

/// Control buffering in a data stream.
#[derive(Debug, Copy, Clone)]
pub enum BufferingMode {
/// Buffer the messages. This is the default.
Buffered,
/// Do not buffer the messages, this prevent caching but will probably incur a performance cost.
Unbuffered,
}

impl Default for BufferingMode {

Check warning on line 182 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Lint & Format (default)

this `impl` can be derived

Check warning on line 182 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Lint & Format (ssl-openssl)

this `impl` can be derived

Check warning on line 182 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Lint & Format (ssl-rustls)

this `impl` can be derived
fn default() -> Self {
BufferingMode::Buffered
}
}

/// Advanced server settings.
/// In order to retain the ability to add options later while preserving the
/// API, this is an "opaque" struct, to be manipulated through the `new` and
/// `with_*` configuration methods.
#[derive(Default, Debug, Clone)]
pub struct ServerConfigAdvanced {
/// Control buffering on the server->client path. The default value is
/// `BufferingMode::Buffered`.
writer_buffering: BufferingMode,
}

impl ServerConfigAdvanced {
pub fn new() -> Self {
Self::default()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I'm not certain providing a default impl of ServerConfigAdvanced is a great idea, maybe I should drop that and explicitly instantiate the configuration in the new method.
Because the Default impl is public-facing, so if we commit to that, we will need to maintain a Default impl on that structure in the future.
That said, we do have to provide default values anyway (wether in the default impl or in the new method), so I don't feel strongly about this either way. But if the idea is to provide some opaque structure with a builder pattern, I guess the less we expose API-wide, the better?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if providing a Defaul implementation is any more or less public than an opaque new(), in that we would need to preserve any behavioural expectations anyway - I would like a builder pattern configuration mechanism anyway, but it doesn't need to be at the expense of Default.

}

/// Change the buffering mode of the writer returned with each message.
pub fn with_writer_buffering_mode(mut self, mode: BufferingMode) -> Self {
self.writer_buffering = mode;
self
}
}

/// Represents the parameters required to create a server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
Expand All @@ -178,6 +216,9 @@

/// If `Some`, then the server will use SSL to encode the communications.
pub ssl: Option<SslConfig>,

/// Advanced server settings.
pub advanced: ServerConfigAdvanced,
}

/// Configuration of the server for SSL.
Expand All @@ -199,6 +240,7 @@
Server::new(ServerConfig {
addr: ConfigListenAddr::from_socket_addrs(addr)?,
ssl: None,
advanced: ServerConfigAdvanced::new(),
})
}

Expand All @@ -215,6 +257,7 @@
Server::new(ServerConfig {
addr: ConfigListenAddr::from_socket_addrs(addr)?,
ssl: Some(config),
advanced: ServerConfigAdvanced::new(),
})
}

Expand All @@ -227,13 +270,14 @@
Server::new(ServerConfig {
addr: ConfigListenAddr::unix_from_path(path),
ssl: None,
advanced: ServerConfigAdvanced::new(),
})
}

/// Builds a new server that listens on the specified address.
pub fn new(config: ServerConfig) -> Result<Server, Box<dyn Error + Send + Sync + 'static>> {
let listener = config.addr.bind()?;
Self::from_listener(listener, config.ssl)
Self::from_listener(listener, config)
}

/// Builds a new server using the specified TCP listener.
Expand All @@ -242,7 +286,7 @@
/// such as from systemd. For other cases, you probably want the `new()` function.
pub fn from_listener<L: Into<Listener>>(
listener: L,
ssl_config: Option<SslConfig>,
config: ServerConfig,
) -> Result<Server, Box<dyn Error + Send + Sync + 'static>> {
let listener = listener.into();
// building the "close" variable
Expand All @@ -265,7 +309,7 @@
#[cfg(any(feature = "ssl-openssl", feature = "ssl-rustls"))]
type SslContext = crate::ssl::SslContextImpl;
let ssl: Option<SslContext> = {
match ssl_config {
match config.ssl {
#[cfg(any(feature = "ssl-openssl", feature = "ssl-rustls"))]
Some(config) => Some(SslContext::from_pem(
config.certificate,
Expand All @@ -286,6 +330,7 @@

let inside_close_trigger = close_trigger.clone();
let inside_messages = messages.clone();
let writer_buffering = config.advanced.writer_buffering;
thread::spawn(move || {
// a tasks pool is used to dispatch the connections into threads
let tasks_pool = util::TaskPool::new();
Expand All @@ -312,7 +357,11 @@
Some(ref _ssl) => unreachable!(),
};

Ok(ClientConnection::new(write_closable, read_closable))
Ok(ClientConnection::new(
write_closable,
read_closable,
writer_buffering,
))
}
Err(e) => Err(e),
};
Expand Down
28 changes: 27 additions & 1 deletion src/response.rs
@@ -1,4 +1,5 @@
use crate::common::{HTTPVersion, Header, StatusCode};
use crate::BufferingMode;
use httpdate::HttpDate;
use std::cmp::Ordering;
use std::sync::mpsc::Receiver;
Expand Down Expand Up @@ -43,6 +44,7 @@ pub struct Response<R> {
headers: Vec<Header>,
data_length: Option<usize>,
chunked_threshold: Option<usize>,
buffering: BufferingMode,
}

/// A `Response` without a template parameter.
Expand Down Expand Up @@ -116,6 +118,7 @@ fn choose_transfer_encoding(
entity_length: &Option<usize>,
has_additional_headers: bool,
chunked_threshold: usize,
buffering: BufferingMode,
) -> TransferEncoding {
use crate::util;

Expand Down Expand Up @@ -166,6 +169,12 @@ fn choose_transfer_encoding(
return user_request;
}

// unbuffered messages must use chunked transfer encoding, to send messages as they
// are produced
if let BufferingMode::Unbuffered = buffering {
return TransferEncoding::Chunked;
}

// if we have additional headers, using chunked
if has_additional_headers {
return TransferEncoding::Chunked;
Expand Down Expand Up @@ -206,6 +215,7 @@ where
headers: Vec::with_capacity(16),
data_length,
chunked_threshold: None,
buffering: BufferingMode::Buffered,
};

for h in headers {
Expand All @@ -231,6 +241,14 @@ where
self
}

/// Define if the output should be buffered. If the output in unbuffered,
/// every write to the socket will be followed by a flush, and the output
/// will be chunked.
pub fn with_buffering(mut self, buffering: BufferingMode) -> Self {
self.buffering = buffering;
self
}

/// Convert the response into the underlying `Read` type.
///
/// This is mainly useful for testing as it must consume the `Response`.
Expand Down Expand Up @@ -319,6 +337,7 @@ where
status_code: self.status_code,
data_length,
chunked_threshold: self.chunked_threshold,
buffering: BufferingMode::Buffered,
}
}

Expand Down Expand Up @@ -346,6 +365,7 @@ where
&self.data_length,
false, /* TODO */
self.chunked_threshold(),
self.buffering,
));

// add `Date` if not in the headers
Expand Down Expand Up @@ -433,7 +453,11 @@ where
Some(TransferEncoding::Chunked) => {
use chunked_transfer::Encoder;

let mut writer = Encoder::new(writer);
let mut writer = if let BufferingMode::Unbuffered = self.buffering {
Encoder::with_flush_after_write(writer)
} else {
Encoder::new(writer)
};
io::copy(&mut reader, &mut writer)?;
}

Expand Down Expand Up @@ -481,6 +505,7 @@ where
headers: self.headers,
data_length: self.data_length,
chunked_threshold: self.chunked_threshold,
buffering: self.buffering,
}
}
}
Expand Down Expand Up @@ -569,6 +594,7 @@ impl Clone for Response<io::Empty> {
headers: self.headers.clone(),
data_length: self.data_length,
chunked_threshold: self.chunked_threshold,
buffering: self.buffering,
}
}
}
29 changes: 29 additions & 0 deletions src/util/buffering_wrapper.rs
@@ -0,0 +1,29 @@
use std::io::{BufWriter, Result as IoResult, Write};

pub enum MaybeBufferedWriter<W: Write> {
Buffered(BufWriter<W>),
Unbuffered(W),
}

impl<W: Write> Write for MaybeBufferedWriter<W> {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match self {
MaybeBufferedWriter::Buffered(w) => w.write(buf),
MaybeBufferedWriter::Unbuffered(w) => w.write(buf),
}
}

fn write_all(&mut self, buf: &[u8]) -> IoResult<()> {
match self {
MaybeBufferedWriter::Buffered(w) => w.write_all(buf),
MaybeBufferedWriter::Unbuffered(w) => w.write_all(buf),
}
}

fn flush(&mut self) -> IoResult<()> {
match self {
MaybeBufferedWriter::Buffered(w) => w.flush(),
MaybeBufferedWriter::Unbuffered(w) => w.flush(),
}
}
}
2 changes: 2 additions & 0 deletions src/util/mod.rs
@@ -1,3 +1,4 @@
pub use self::buffering_wrapper::MaybeBufferedWriter;
pub use self::custom_stream::CustomStream;
pub use self::equal_reader::EqualReader;
pub use self::fused_reader::FusedReader;
Expand All @@ -9,6 +10,7 @@ pub use self::task_pool::TaskPool;

use std::str::FromStr;

mod buffering_wrapper;
mod custom_stream;
mod equal_reader;
mod fused_reader;
Expand Down
36 changes: 12 additions & 24 deletions tests/promptness.rs
@@ -1,31 +1,17 @@
extern crate tiny_http;

use std::io::{copy, Read, Write};
use std::net::{Shutdown, TcpStream};
use std::net::Shutdown;
use std::ops::Deref;
use std::sync::mpsc::channel;
use std::sync::Arc;
use std::thread::{sleep, spawn};
use std::thread::spawn;
use std::time::Duration;
use tiny_http::{Response, Server};
use tiny_http::Response;

/// Stream that produces bytes very slowly
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct SlowByteSrc {
val: u8,
len: usize,
}
impl<'b> Read for SlowByteSrc {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
sleep(Duration::from_millis(100));
let l = self.len.min(buf.len()).min(1000);
for v in buf[..l].iter_mut() {
*v = self.val;
}
self.len -= l;
Ok(l)
}
}
#[allow(dead_code)]
mod support;
use support::{new_one_server_one_client, SlowByteSrc};

/// crude impl of http `Transfer-Encoding: chunked`
fn encode_chunked(data: &mut dyn Read, output: &mut dyn Write) {
Expand All @@ -42,6 +28,8 @@ fn encode_chunked(data: &mut dyn Read, output: &mut dyn Write) {
}

mod prompt_pipelining {
use std::time::Duration;

use super::*;

/// Check that pipelined requests on the same connection are received promptly.
Expand All @@ -52,12 +40,12 @@ mod prompt_pipelining {
req_writer: impl FnOnce(&mut dyn Write) + Send + 'static,
) {
let resp_body = SlowByteSrc {
sleep_time: Duration::from_millis(100),
val: 42,
len: 1000_000,
}; // very slow response body

let server = Server::http("0.0.0.0:0").unwrap();
let mut client = TcpStream::connect(server.server_addr().to_ip().unwrap()).unwrap();
let (server, mut client) = new_one_server_one_client();
let (svr_send, svr_rcv) = channel();

spawn(move || {
Expand Down Expand Up @@ -142,8 +130,7 @@ mod prompt_responses {
timeout: Duration,
req_writer: impl FnOnce(&mut dyn Write) + Send + 'static,
) {
let server = Server::http("0.0.0.0:0").unwrap();
let client = TcpStream::connect(server.server_addr().to_ip().unwrap()).unwrap();
let (server, client) = new_one_server_one_client();

spawn(move || loop {
// server attempts to respond immediately
Expand All @@ -164,6 +151,7 @@ mod prompt_responses {
}

static SLOW_BODY: SlowByteSrc = SlowByteSrc {
sleep_time: Duration::from_millis(100),
val: 65,
len: 1000_000,
};
Expand Down