Skip to content

Commit

Permalink
Add support for unbuffered Response
Browse files Browse the repository at this point in the history
  • Loading branch information
nightmared committed Aug 25, 2022
1 parent f0fce7e commit ba24384
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 36 deletions.
12 changes: 10 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use ascii::AsciiString;

use std::io::Error as IoError;
use std::io::Result as IoResult;
use std::io::Write;
use std::io::{BufReader, BufWriter, ErrorKind, Read};

use std::net::SocketAddr;
Expand All @@ -24,7 +25,7 @@ pub struct ClientConnection {

// sequence of Writers to the stream, to avoid writing response #2 before
// response #1
sink: SequentialWriterBuilder<BufWriter<RefinedTcpStream>>,
sink: SequentialWriterBuilder<Box<dyn Write + Send + Sync>>,

// Reader to read the next header from
next_header_source: SequentialReader<BufReader<RefinedTcpStream>>,
Expand All @@ -48,9 +49,12 @@ enum ReadError {

impl ClientConnection {
/// Creates a new `ClientConnection` that takes ownership of the `TcpStream`.
/// The buffered parameter allows for disabling buffering, which will
/// flush the writer after every single write.
pub fn new(
write_socket: RefinedTcpStream,
mut read_socket: RefinedTcpStream,
buffered: bool,
) -> ClientConnection {
let remote_addr = read_socket.peer_addr();
let secure = read_socket.secure();
Expand All @@ -60,7 +64,11 @@ impl ClientConnection {

ClientConnection {
source,
sink: SequentialWriterBuilder::new(BufWriter::with_capacity(1024, write_socket)),
sink: SequentialWriterBuilder::new(if buffered {
Box::new(BufWriter::with_capacity(1024, write_socket))
} else {
Box::new(write_socket)
}),
remote_addr,
next_header_source: first_header,
no_more_requests: false,
Expand Down
15 changes: 13 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ pub struct ServerConfig {

/// If `Some`, then the server will use SSL to encode the communications.
pub ssl: Option<SslConfig>,

/// If `false`, flush messages after every write
pub buffered_writer: bool,
}

/// Configuration of the server for SSL.
Expand All @@ -198,6 +201,7 @@ impl Server {
Server::new(ServerConfig {
addr: ConfigListenAddr::from_socket_addrs(addr)?,
ssl: None,
buffered_writer: true,
})
}

Expand All @@ -214,6 +218,7 @@ impl Server {
Server::new(ServerConfig {
addr: ConfigListenAddr::from_socket_addrs(addr)?,
ssl: Some(config),
buffered_writer: true,
})
}

Expand All @@ -226,13 +231,14 @@ impl Server {
Server::new(ServerConfig {
addr: ConfigListenAddr::unix_from_path(path),
ssl: None,
buffered_writer: true,
})
}

/// Builds a new server that listens on the specified address.
pub fn new(config: ServerConfig) -> Result<Server, Box<dyn Error + Send + Sync + 'static>> {
let listener = config.addr.bind()?;
Self::from_listener(listener, config.ssl)
Self::from_listener(listener, config.ssl, config.buffered_writer)
}

/// Builds a new server using the specified TCP listener.
Expand All @@ -242,6 +248,7 @@ impl Server {
pub fn from_listener<L: Into<Listener>>(
listener: L,
ssl_config: Option<SslConfig>,
buffered_writer: bool,
) -> Result<Server, Box<dyn Error + Send + Sync + 'static>> {
let listener = listener.into();
// building the "close" variable
Expand Down Expand Up @@ -311,7 +318,11 @@ impl Server {
Some(ref _ssl) => unreachable!(),
};

Ok(ClientConnection::new(write_closable, read_closable))
Ok(ClientConnection::new(
write_closable,
read_closable,
buffered_writer,
))
}
Err(e) => Err(e),
};
Expand Down
25 changes: 24 additions & 1 deletion src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub struct Response<R> {
headers: Vec<Header>,
data_length: Option<usize>,
chunked_threshold: Option<usize>,
buffered: bool,
}

/// A `Response` without a template parameter.
Expand Down Expand Up @@ -115,6 +116,7 @@ fn choose_transfer_encoding(
entity_length: &Option<usize>,
has_additional_headers: bool,
chunked_threshold: usize,
buffered: bool,
) -> TransferEncoding {
use crate::util;

Expand Down Expand Up @@ -165,6 +167,10 @@ fn choose_transfer_encoding(
return user_request;
}

if !buffered {
return TransferEncoding::Chunked;
}

// if we have additional headers, using chunked
if has_additional_headers {
return TransferEncoding::Chunked;
Expand Down Expand Up @@ -205,6 +211,7 @@ where
headers: Vec::with_capacity(16),
data_length,
chunked_threshold: None,
buffered: true,
};

for h in headers {
Expand All @@ -230,6 +237,14 @@ where
self
}

/// Define if the output should be buffered. If the output in unbuffered,
/// every write to the socket will be followed by a flush, and the output
/// will be chunked.
pub fn with_buffered(mut self, buffered: bool) -> Self {
self.buffered = buffered;
self
}

/// Convert the response into the underlying `Read` type.
///
/// This is mainly useful for testing as it must consume the `Response`.
Expand Down Expand Up @@ -318,6 +333,7 @@ where
status_code: self.status_code,
data_length,
chunked_threshold: self.chunked_threshold,
buffered: true,
}
}

Expand Down Expand Up @@ -345,6 +361,7 @@ where
&self.data_length,
false, /* TODO */
self.chunked_threshold(),
self.buffered,
));

// add `Date` if not in the headers
Expand Down Expand Up @@ -432,7 +449,11 @@ where
Some(TransferEncoding::Chunked) => {
use chunked_transfer::Encoder;

let mut writer = Encoder::new(writer);
let mut writer = if !self.buffered {
Encoder::with_flush_after_write(writer)
} else {
Encoder::new(writer)
};
io::copy(&mut reader, &mut writer)?;
}

Expand Down Expand Up @@ -480,6 +501,7 @@ where
headers: self.headers,
data_length: self.data_length,
chunked_threshold: self.chunked_threshold,
buffered: self.buffered,
}
}
}
Expand Down Expand Up @@ -568,6 +590,7 @@ impl Clone for Response<io::Empty> {
headers: self.headers.clone(),
data_length: self.data_length,
chunked_threshold: self.chunked_threshold,
buffered: self.buffered,
}
}
}
36 changes: 12 additions & 24 deletions tests/promptness.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,17 @@
extern crate tiny_http;

use std::io::{copy, Read, Write};
use std::net::{Shutdown, TcpStream};
use std::net::Shutdown;
use std::ops::Deref;
use std::sync::mpsc::channel;
use std::sync::Arc;
use std::thread::{sleep, spawn};
use std::thread::spawn;
use std::time::Duration;
use tiny_http::{Response, Server};
use tiny_http::Response;

/// Stream that produces bytes very slowly
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct SlowByteSrc {
val: u8,
len: usize,
}
impl<'b> Read for SlowByteSrc {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
sleep(Duration::from_millis(100));
let l = self.len.min(buf.len()).min(1000);
for v in buf[..l].iter_mut() {
*v = self.val;
}
self.len -= l;
Ok(l)
}
}
#[allow(dead_code)]
mod support;
use support::{new_one_server_one_client, SlowByteSrc};

/// crude impl of http `Transfer-Encoding: chunked`
fn encode_chunked(data: &mut dyn Read, output: &mut dyn Write) {
Expand All @@ -42,6 +28,8 @@ fn encode_chunked(data: &mut dyn Read, output: &mut dyn Write) {
}

mod prompt_pipelining {
use std::time::Duration;

use super::*;

/// Check that pipelined requests on the same connection are received promptly.
Expand All @@ -52,12 +40,12 @@ mod prompt_pipelining {
req_writer: impl FnOnce(&mut dyn Write) + Send + 'static,
) {
let resp_body = SlowByteSrc {
sleep_time: Duration::from_millis(100),
val: 42,
len: 1000_000,
}; // very slow response body

let server = Server::http("0.0.0.0:0").unwrap();
let mut client = TcpStream::connect(server.server_addr().to_ip().unwrap()).unwrap();
let (server, mut client) = new_one_server_one_client();
let (svr_send, svr_rcv) = channel();

spawn(move || {
Expand Down Expand Up @@ -142,8 +130,7 @@ mod prompt_responses {
timeout: Duration,
req_writer: impl FnOnce(&mut dyn Write) + Send + 'static,
) {
let server = Server::http("0.0.0.0:0").unwrap();
let client = TcpStream::connect(server.server_addr().to_ip().unwrap()).unwrap();
let (server, client) = new_one_server_one_client();

spawn(move || loop {
// server attempts to respond immediately
Expand All @@ -164,6 +151,7 @@ mod prompt_responses {
}

static SLOW_BODY: SlowByteSrc = SlowByteSrc {
sleep_time: Duration::from_millis(100),
val: 65,
len: 1000_000,
};
Expand Down
91 changes: 88 additions & 3 deletions tests/simple-test.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
extern crate tiny_http;

use std::io::{Read, Write};
use std::{
io::{Read, Write},
time::Duration,
};

#[allow(dead_code)]
mod support;
use chunked_transfer::Decoder;
use support::{new_one_server_one_client, new_one_server_one_client_unbuffered, SlowByteSrc};
use time::Instant;

#[test]
fn basic_handling() {
let (server, mut stream) = support::new_one_server_one_client();
let (server, mut stream) = new_one_server_one_client();
write!(
stream,
"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
Expand All @@ -16,7 +22,7 @@ fn basic_handling() {

let request = server.recv().unwrap();
assert!(*request.method() == tiny_http::Method::Get);
//assert!(request.url() == "/");
assert!(request.url() == "/");
request
.respond(tiny_http::Response::from_string("hello world".to_owned()))
.unwrap();
Expand All @@ -27,3 +33,82 @@ fn basic_handling() {
stream.read_to_string(&mut content).unwrap();
assert!(content.ends_with("hello world"));
}

#[test]
fn unbuffered() {
let (server, mut stream) = new_one_server_one_client_unbuffered();
write!(
stream,
"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
)
.unwrap();

let request = server.recv().unwrap();
let start = Instant::now();
std::thread::spawn(|| {
request
.respond(
tiny_http::Response::new(
tiny_http::StatusCode(200),
Vec::new(),
SlowByteSrc {
sleep_time: Duration::from_millis(50),
val: 65,
// extreme length (100GB): we ensure that the data
// is streamed on demand instead of assembled in memory
len: 100_000_000_000,
},
None,
None,
)
.with_buffered(false),
)
.unwrap()
});

assert!(server.try_recv().unwrap().is_none());

let mut buf = [0; 64 * 1024];
let mut read = 0;

loop {
let nb_read = stream.read(&mut buf[read..]).unwrap();

// we should receive some data, but only a small amount because of the lack
// of buffering
assert!(nb_read > 0);
assert!(nb_read < buf.len());

read += nb_read;

// ensure that we receive the data quickly after the SlowByteSrc reader
// started feeeding the server
let elapsed = start.elapsed();
if elapsed > Duration::from_millis(75) {
break;
}
}

let res = String::from_utf8(buf[..read].to_vec()).expect("Invalid UTF8 characters");
assert!(res.contains("Transfer-Encoding: chunked"));

let chunked_index = res
.find("\r\n\r\n")
.expect("Could not find the start of the chunked messages");
let mut chunked_data = res[chunked_index + 4..].to_string();
// emit the "end of stream" message
chunked_data.push_str("0\r\n\r\n");

// verify that we only received '\x65' characters
let mut decoder = Decoder::new(chunked_data.as_bytes());
let mut decoded = String::new();
decoder
.read_to_string(&mut decoded)
.expect("Invalid (non-chunked?) data");
let bytes = decoded.as_bytes();
let mut expected_vec = Vec::with_capacity(bytes.len());
for _ in 0..bytes.len() {
expected_vec.push(65);
}
assert!(bytes == expected_vec.as_slice());
}

0 comments on commit ba24384

Please sign in to comment.