diff --git a/src/lib.rs b/src/lib.rs index a76307985..40b5491c4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -106,6 +106,7 @@ extern crate openssl; use std::error::Error; use std::io::Error as IoError; +use std::io::ErrorKind as IoErrorKind; use std::io::Result as IoResult; use std::net; use std::net::{Shutdown, TcpStream, ToSocketAddrs}; @@ -383,8 +384,9 @@ impl Server { /// Blocks until an HTTP request has been submitted and returns it. pub fn recv(&self) -> IoResult { match self.messages.pop() { - Message::Error(err) => Err(err), - Message::NewRequest(rq) => Ok(rq), + Some(Message::Error(err)) => return Err(err), + Some(Message::NewRequest(rq)) => return Ok(rq), + None => return Err(IoError::new(IoErrorKind::Other, "thread unblocked")), } } @@ -405,6 +407,13 @@ impl Server { None => Ok(None), } } + + /// Unblock thread stuck in recv() or incoming_requests(). + /// If there are several such threads, only one is unblocked. + /// This method allows graceful shutdown of server. + pub fn unblock(&self) { + self.messages.unblock(); + } } impl<'a> Iterator for IncomingRequests<'a> { diff --git a/src/response.rs b/src/response.rs index ee0fdbc43..ac38f4faf 100644 --- a/src/response.rs +++ b/src/response.rs @@ -419,14 +419,11 @@ where } Some(TransferEncoding::Identity) => { - use util::EqualReader; - assert!(data_length.is_some()); let data_length = data_length.unwrap(); if data_length >= 1 { - let (mut equ_reader, _) = EqualReader::new(reader.by_ref(), data_length); - io::copy(&mut equ_reader, &mut writer)?; + io::copy(&mut reader, &mut writer)?; } } diff --git a/src/util/messages_queue.rs b/src/util/messages_queue.rs index dda91c212..9b95b174c 100644 --- a/src/util/messages_queue.rs +++ b/src/util/messages_queue.rs @@ -2,11 +2,16 @@ use std::collections::VecDeque; use std::sync::{Arc, Condvar, Mutex}; use std::time::{Duration, Instant}; +enum Control { + Elem(T), + Unblock, +} + pub struct MessagesQueue where T: Send, { - queue: Mutex>, + queue: Mutex>>, condvar: Condvar, } @@ -24,17 +29,27 @@ where /// Pushes an element to the queue. pub fn push(&self, value: T) { let mut queue = self.queue.lock().unwrap(); - queue.push_back(value); + queue.push_back(Control::Elem(value)); + self.condvar.notify_one(); + } + + /// Unblock one thread stuck in pop loop. + pub fn unblock(&self) { + let mut queue = self.queue.lock().unwrap(); + queue.push_back(Control::Unblock); self.condvar.notify_one(); } /// Pops an element. Blocks until one is available. - pub fn pop(&self) -> T { + /// Returns None in case unblock() was issued. + pub fn pop(&self) -> Option { let mut queue = self.queue.lock().unwrap(); loop { - if let Some(elem) = queue.pop_front() { - return elem; + match queue.pop_front() { + Some(Control::Elem(value)) => return Some(value), + Some(Control::Unblock) => return None, + None => (), } queue = self.condvar.wait(queue).unwrap(); @@ -44,17 +59,23 @@ where /// Tries to pop an element without blocking. pub fn try_pop(&self) -> Option { let mut queue = self.queue.lock().unwrap(); - queue.pop_front() + match queue.pop_front() { + Some(Control::Elem(value)) => Some(value), + Some(Control::Unblock) | None => None, + } } /// Tries to pop an element without blocking /// more than the specified timeout duration + /// or unblock() was issued pub fn pop_timeout(&self, timeout: Duration) -> Option { let mut queue = self.queue.lock().unwrap(); let mut duration = timeout; loop { - if let Some(elem) = queue.pop_front() { - return Some(elem); + match queue.pop_front() { + Some(Control::Elem(value)) => return Some(value), + Some(Control::Unblock) => return None, + None => (), } let now = Instant::now(); let (_queue, result) = self.condvar.wait_timeout(queue, timeout).unwrap(); diff --git a/tests/non-chunked-buffering.rs b/tests/non-chunked-buffering.rs new file mode 100644 index 000000000..4874a843e --- /dev/null +++ b/tests/non-chunked-buffering.rs @@ -0,0 +1,103 @@ +extern crate tiny_http; + +use std::io::{Cursor, Read, Write}; +use std::sync::{ + atomic::{ + AtomicUsize, + Ordering::{AcqRel, Acquire}, + }, + Arc, +}; + +#[allow(dead_code)] +mod support; + +struct MeteredReader { + inner: T, + position: Arc, +} + +impl Read for MeteredReader +where + T: Read, +{ + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + match self.inner.read(buf) { + Ok(read) => { + self.position.fetch_add(read, AcqRel); + Ok(read) + } + e => e, + } + } +} + +type Reader = MeteredReader>; + +fn big_response_reader() -> Reader { + let big_body = "ABCDEFGHIJKLMNOPQRSTUVXYZ".repeat(1024 * 1024 * 16); + MeteredReader { + inner: Cursor::new(big_body), + position: Arc::new(AtomicUsize::new(0)), + } +} + +fn identity_served<'a>(r: &'a mut Reader) -> tiny_http::Response<&'a mut Reader> { + let body_len = r.inner.get_ref().len(); + tiny_http::Response::empty(200) + .with_chunked_threshold(usize::MAX) + .with_data(r, Some(body_len)) +} + +/// Checks that a body-Read:er is not called when the client has disconnected +#[test] +fn responding_to_closed_client() { + let (server, mut stream) = support::new_one_server_one_client(); + write!( + stream, + "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" + ) + .unwrap(); + + let request = server.recv().unwrap(); + + // Client already disconnected + drop(stream); + + let mut reader = big_response_reader(); + request + .respond(identity_served(&mut reader)) + .expect("Successful"); + + assert!(reader.position.load(Acquire) < 1024 * 1024); +} + +/// Checks that a slow client does not cause data to be consumed and buffered from a reader +#[test] +fn responding_to_non_consuming_client() { + let (server, mut stream) = support::new_one_server_one_client(); + write!( + stream, + "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" + ) + .unwrap(); + + let request = server.recv().unwrap(); + + let mut reader = big_response_reader(); + let position = reader.position.clone(); + + // Client still connected, but not reading anything + std::thread::spawn(move || { + request + .respond(identity_served(&mut reader)) + .expect("Successful"); + }); + + std::thread::sleep(std::time::Duration::from_millis(100)); + + // It seems the client TCP socket can buffer quite a lot, so we need to be permissive + assert!(position.load(Acquire) < 8 * 1024 * 1024); + + drop(stream); +} diff --git a/tests/unblock-test.rs b/tests/unblock-test.rs new file mode 100644 index 000000000..001568a48 --- /dev/null +++ b/tests/unblock-test.rs @@ -0,0 +1,34 @@ +extern crate tiny_http; + +use std::sync::Arc; +use std::thread; + +#[test] +fn unblock_server() { + let server = tiny_http::Server::http("0.0.0.0:0").unwrap(); + let s = Arc::new(server); + + let s1 = s.clone(); + thread::spawn(move || s1.unblock()); + + // Without unblock this would hang forever + for _rq in s.incoming_requests() {} +} + +#[test] +fn unblock_threads() { + let server = tiny_http::Server::http("0.0.0.0:0").unwrap(); + let s = Arc::new(server); + + let s1 = s.clone(); + let s2 = s.clone(); + let h1 = thread::spawn(move || for _rq in s1.incoming_requests() {}); + let h2 = thread::spawn(move || for _rq in s2.incoming_requests() {}); + + // Graceful shutdown; removing even one of the + // unblock calls prevents termination + s.unblock(); + s.unblock(); + h1.join().unwrap(); + h2.join().unwrap(); +}