From 42211a82d14b33779a75070b9f5871a16a86b110 Mon Sep 17 00:00:00 2001 From: Guillaume Koenig Date: Sat, 26 Jan 2019 18:55:56 +0100 Subject: [PATCH 1/2] Add unblock method for graceful shutdown --- src/lib.rs | 13 +++++++++++-- src/util/messages_queue.rs | 37 +++++++++++++++++++++++++++++-------- tests/unblock-test.rs | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 10 deletions(-) create mode 100644 tests/unblock-test.rs diff --git a/src/lib.rs b/src/lib.rs index a76307985..40b5491c4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -106,6 +106,7 @@ extern crate openssl; use std::error::Error; use std::io::Error as IoError; +use std::io::ErrorKind as IoErrorKind; use std::io::Result as IoResult; use std::net; use std::net::{Shutdown, TcpStream, ToSocketAddrs}; @@ -383,8 +384,9 @@ impl Server { /// Blocks until an HTTP request has been submitted and returns it. pub fn recv(&self) -> IoResult { match self.messages.pop() { - Message::Error(err) => Err(err), - Message::NewRequest(rq) => Ok(rq), + Some(Message::Error(err)) => return Err(err), + Some(Message::NewRequest(rq)) => return Ok(rq), + None => return Err(IoError::new(IoErrorKind::Other, "thread unblocked")), } } @@ -405,6 +407,13 @@ impl Server { None => Ok(None), } } + + /// Unblock thread stuck in recv() or incoming_requests(). + /// If there are several such threads, only one is unblocked. + /// This method allows graceful shutdown of server. + pub fn unblock(&self) { + self.messages.unblock(); + } } impl<'a> Iterator for IncomingRequests<'a> { diff --git a/src/util/messages_queue.rs b/src/util/messages_queue.rs index dda91c212..9b95b174c 100644 --- a/src/util/messages_queue.rs +++ b/src/util/messages_queue.rs @@ -2,11 +2,16 @@ use std::collections::VecDeque; use std::sync::{Arc, Condvar, Mutex}; use std::time::{Duration, Instant}; +enum Control { + Elem(T), + Unblock, +} + pub struct MessagesQueue where T: Send, { - queue: Mutex>, + queue: Mutex>>, condvar: Condvar, } @@ -24,17 +29,27 @@ where /// Pushes an element to the queue. pub fn push(&self, value: T) { let mut queue = self.queue.lock().unwrap(); - queue.push_back(value); + queue.push_back(Control::Elem(value)); + self.condvar.notify_one(); + } + + /// Unblock one thread stuck in pop loop. + pub fn unblock(&self) { + let mut queue = self.queue.lock().unwrap(); + queue.push_back(Control::Unblock); self.condvar.notify_one(); } /// Pops an element. Blocks until one is available. - pub fn pop(&self) -> T { + /// Returns None in case unblock() was issued. + pub fn pop(&self) -> Option { let mut queue = self.queue.lock().unwrap(); loop { - if let Some(elem) = queue.pop_front() { - return elem; + match queue.pop_front() { + Some(Control::Elem(value)) => return Some(value), + Some(Control::Unblock) => return None, + None => (), } queue = self.condvar.wait(queue).unwrap(); @@ -44,17 +59,23 @@ where /// Tries to pop an element without blocking. pub fn try_pop(&self) -> Option { let mut queue = self.queue.lock().unwrap(); - queue.pop_front() + match queue.pop_front() { + Some(Control::Elem(value)) => Some(value), + Some(Control::Unblock) | None => None, + } } /// Tries to pop an element without blocking /// more than the specified timeout duration + /// or unblock() was issued pub fn pop_timeout(&self, timeout: Duration) -> Option { let mut queue = self.queue.lock().unwrap(); let mut duration = timeout; loop { - if let Some(elem) = queue.pop_front() { - return Some(elem); + match queue.pop_front() { + Some(Control::Elem(value)) => return Some(value), + Some(Control::Unblock) => return None, + None => (), } let now = Instant::now(); let (_queue, result) = self.condvar.wait_timeout(queue, timeout).unwrap(); diff --git a/tests/unblock-test.rs b/tests/unblock-test.rs new file mode 100644 index 000000000..001568a48 --- /dev/null +++ b/tests/unblock-test.rs @@ -0,0 +1,34 @@ +extern crate tiny_http; + +use std::sync::Arc; +use std::thread; + +#[test] +fn unblock_server() { + let server = tiny_http::Server::http("0.0.0.0:0").unwrap(); + let s = Arc::new(server); + + let s1 = s.clone(); + thread::spawn(move || s1.unblock()); + + // Without unblock this would hang forever + for _rq in s.incoming_requests() {} +} + +#[test] +fn unblock_threads() { + let server = tiny_http::Server::http("0.0.0.0:0").unwrap(); + let s = Arc::new(server); + + let s1 = s.clone(); + let s2 = s.clone(); + let h1 = thread::spawn(move || for _rq in s1.incoming_requests() {}); + let h2 = thread::spawn(move || for _rq in s2.incoming_requests() {}); + + // Graceful shutdown; removing even one of the + // unblock calls prevents termination + s.unblock(); + s.unblock(); + h1.join().unwrap(); + h2.join().unwrap(); +} From 9aa589e6e771c0a3ebe0914bc3a2a25fe3d42d98 Mon Sep 17 00:00:00 2001 From: Ulrik Date: Sat, 19 Sep 2020 19:06:13 +0200 Subject: [PATCH 2/2] response: Drop the use of EqualReader for TransferEncoding::Identity It's purpose is unclear, and it causes the entire reader to be consumed, even when client has disconnected and won't get the content. If the application needs to flush the reader for some side-effect, that can still be achieved by the application itself. --- src/response.rs | 5 +- tests/non-chunked-buffering.rs | 103 +++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 4 deletions(-) create mode 100644 tests/non-chunked-buffering.rs diff --git a/src/response.rs b/src/response.rs index 4eaf29a3e..1f80d78ce 100644 --- a/src/response.rs +++ b/src/response.rs @@ -419,14 +419,11 @@ where } Some(TransferEncoding::Identity) => { - use util::EqualReader; - assert!(data_length.is_some()); let data_length = data_length.unwrap(); if data_length >= 1 { - let (mut equ_reader, _) = EqualReader::new(reader.by_ref(), data_length); - io::copy(&mut equ_reader, &mut writer)?; + io::copy(&mut reader, &mut writer)?; } } diff --git a/tests/non-chunked-buffering.rs b/tests/non-chunked-buffering.rs new file mode 100644 index 000000000..4874a843e --- /dev/null +++ b/tests/non-chunked-buffering.rs @@ -0,0 +1,103 @@ +extern crate tiny_http; + +use std::io::{Cursor, Read, Write}; +use std::sync::{ + atomic::{ + AtomicUsize, + Ordering::{AcqRel, Acquire}, + }, + Arc, +}; + +#[allow(dead_code)] +mod support; + +struct MeteredReader { + inner: T, + position: Arc, +} + +impl Read for MeteredReader +where + T: Read, +{ + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + match self.inner.read(buf) { + Ok(read) => { + self.position.fetch_add(read, AcqRel); + Ok(read) + } + e => e, + } + } +} + +type Reader = MeteredReader>; + +fn big_response_reader() -> Reader { + let big_body = "ABCDEFGHIJKLMNOPQRSTUVXYZ".repeat(1024 * 1024 * 16); + MeteredReader { + inner: Cursor::new(big_body), + position: Arc::new(AtomicUsize::new(0)), + } +} + +fn identity_served<'a>(r: &'a mut Reader) -> tiny_http::Response<&'a mut Reader> { + let body_len = r.inner.get_ref().len(); + tiny_http::Response::empty(200) + .with_chunked_threshold(usize::MAX) + .with_data(r, Some(body_len)) +} + +/// Checks that a body-Read:er is not called when the client has disconnected +#[test] +fn responding_to_closed_client() { + let (server, mut stream) = support::new_one_server_one_client(); + write!( + stream, + "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" + ) + .unwrap(); + + let request = server.recv().unwrap(); + + // Client already disconnected + drop(stream); + + let mut reader = big_response_reader(); + request + .respond(identity_served(&mut reader)) + .expect("Successful"); + + assert!(reader.position.load(Acquire) < 1024 * 1024); +} + +/// Checks that a slow client does not cause data to be consumed and buffered from a reader +#[test] +fn responding_to_non_consuming_client() { + let (server, mut stream) = support::new_one_server_one_client(); + write!( + stream, + "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" + ) + .unwrap(); + + let request = server.recv().unwrap(); + + let mut reader = big_response_reader(); + let position = reader.position.clone(); + + // Client still connected, but not reading anything + std::thread::spawn(move || { + request + .respond(identity_served(&mut reader)) + .expect("Successful"); + }); + + std::thread::sleep(std::time::Duration::from_millis(100)); + + // It seems the client TCP socket can buffer quite a lot, so we need to be permissive + assert!(position.load(Acquire) < 8 * 1024 * 1024); + + drop(stream); +}