diff --git a/.cargo/audit.toml b/.cargo/audit.toml new file mode 100644 index 00000000000..4fd083d9481 --- /dev/null +++ b/.cargo/audit.toml @@ -0,0 +1,7 @@ +# See https://github.com/rustsec/rustsec/blob/59e1d2ad0b9cbc6892c26de233d4925074b4b97b/cargo-audit/audit.toml.example for example. + +[advisories] +ignore = [ + # https://github.com/tokio-rs/tokio/issues/4177 + "RUSTSEC-2020-0159", +] diff --git a/examples/Cargo.toml b/examples/Cargo.toml index ff979cfb9ff..1d155a2b17b 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -20,7 +20,7 @@ serde = "1.0" serde_derive = "1.0" serde_json = "1.0" httparse = "1.0" -time = "0.1" +httpdate = "1.0" once_cell = "1.5.2" rand = "0.8.3" diff --git a/examples/tinyhttp.rs b/examples/tinyhttp.rs index e86305f367e..fa0bc6695e2 100644 --- a/examples/tinyhttp.rs +++ b/examples/tinyhttp.rs @@ -221,8 +221,9 @@ mod date { use std::cell::RefCell; use std::fmt::{self, Write}; use std::str; + use std::time::SystemTime; - use time::{self, Duration}; + use httpdate::HttpDate; pub struct Now(()); @@ -252,22 +253,26 @@ mod date { struct LastRenderedNow { bytes: [u8; 128], amt: usize, - next_update: time::Timespec, + unix_date: u64, } thread_local!(static LAST: RefCell = RefCell::new(LastRenderedNow { bytes: [0; 128], amt: 0, - next_update: time::Timespec::new(0, 0), + unix_date: 0, })); impl fmt::Display for Now { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { LAST.with(|cache| { let mut cache = cache.borrow_mut(); - let now = time::get_time(); - if now >= cache.next_update { - cache.update(now); + let now = SystemTime::now(); + let now_unix = now + .duration_since(SystemTime::UNIX_EPOCH) + .map(|since_epoch| since_epoch.as_secs()) + .unwrap_or(0); + if cache.unix_date != now_unix { + cache.update(now, now_unix); } f.write_str(cache.buffer()) }) @@ -279,11 +284,10 @@ mod date { str::from_utf8(&self.bytes[..self.amt]).unwrap() } - fn update(&mut self, now: time::Timespec) { + fn update(&mut self, now: SystemTime, now_unix: u64) { self.amt = 0; - write!(LocalBuffer(self), "{}", time::at(now).rfc822()).unwrap(); - self.next_update = now + Duration::seconds(1); - self.next_update.nsec = 0; + self.unix_date = now_unix; + write!(LocalBuffer(self), "{}", HttpDate::from(now)).unwrap(); } } diff --git a/tests-build/tests/fail/macros_core_no_default.stderr b/tests-build/tests/fail/macros_core_no_default.stderr index 6b3f8fa6c28..676acc8dbe3 100644 --- a/tests-build/tests/fail/macros_core_no_default.stderr +++ b/tests-build/tests/fail/macros_core_no_default.stderr @@ -4,4 +4,4 @@ error: The default runtime flavor is `multi_thread`, but the `rt-multi-thread` f 3 | #[tokio::main] | ^^^^^^^^^^^^^^ | - = note: this error originates in an attribute macro (in Nightly builds, run with -Z macro-backtrace for more info) + = note: this error originates in the attribute macro `tokio::main` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tests-build/tests/fail/macros_type_mismatch.rs b/tests-build/tests/fail/macros_type_mismatch.rs index 086244c9fd3..0a5b9c4c727 100644 --- a/tests-build/tests/fail/macros_type_mismatch.rs +++ b/tests-build/tests/fail/macros_type_mismatch.rs @@ -7,12 +7,6 @@ async fn missing_semicolon_or_return_type() { #[tokio::main] async fn missing_return_type() { - /* TODO(taiki-e): one of help messages still wrong - help: consider using a semicolon here - | - 16 | return Ok(());; - | - */ return Ok(()); } @@ -21,9 +15,9 @@ async fn extra_semicolon() -> Result<(), ()> { /* TODO(taiki-e): help message still wrong help: try 
using a variant of the expected enum | - 29 | Ok(Ok(());) + 23 | Ok(Ok(());) | - 29 | Err(Ok(());) + 23 | Err(Ok(());) | */ Ok(()); diff --git a/tests-build/tests/fail/macros_type_mismatch.stderr b/tests-build/tests/fail/macros_type_mismatch.stderr index 4d573181523..a8fa99bc63b 100644 --- a/tests-build/tests/fail/macros_type_mismatch.stderr +++ b/tests-build/tests/fail/macros_type_mismatch.stderr @@ -9,43 +9,37 @@ error[E0308]: mismatched types help: consider using a semicolon here | 5 | Ok(()); - | ^ + | + help: try adding a return type | 4 | async fn missing_semicolon_or_return_type() -> Result<(), _> { - | ^^^^^^^^^^^^^^^^ + | ++++++++++++++++ error[E0308]: mismatched types - --> $DIR/macros_type_mismatch.rs:16:5 + --> $DIR/macros_type_mismatch.rs:10:5 | -16 | return Ok(()); +9 | async fn missing_return_type() { + | - help: try adding a return type: `-> Result<(), _>` +10 | return Ok(()); | ^^^^^^^^^^^^^^ expected `()`, found enum `Result` | = note: expected unit type `()` found enum `Result<(), _>` -help: consider using a semicolon here - | -16 | return Ok(());; - | ^ -help: try adding a return type - | -9 | async fn missing_return_type() -> Result<(), _> { - | ^^^^^^^^^^^^^^^^ error[E0308]: mismatched types - --> $DIR/macros_type_mismatch.rs:29:5 + --> $DIR/macros_type_mismatch.rs:23:5 | -20 | async fn extra_semicolon() -> Result<(), ()> { +14 | async fn extra_semicolon() -> Result<(), ()> { | -------------- expected `Result<(), ()>` because of return type ... -29 | Ok(()); +23 | Ok(()); | ^^^^^^^ expected enum `Result`, found `()` | = note: expected enum `Result<(), ()>` found unit type `()` help: try using a variant of the expected enum | -29 | Ok(Ok(());) +23 | Ok(Ok(());) | -29 | Err(Ok(());) +23 | Err(Ok(());) | diff --git a/tests-build/tests/macros.rs b/tests-build/tests/macros.rs index d2330cb6e18..0a180dfb74f 100644 --- a/tests-build/tests/macros.rs +++ b/tests-build/tests/macros.rs @@ -5,6 +5,12 @@ fn compile_fail_full() { #[cfg(feature = "full")] t.pass("tests/pass/forward_args_and_output.rs"); + #[cfg(feature = "full")] + t.pass("tests/pass/macros_main_return.rs"); + + #[cfg(feature = "full")] + t.pass("tests/pass/macros_main_loop.rs"); + #[cfg(feature = "full")] t.compile_fail("tests/fail/macros_invalid_input.rs"); diff --git a/tests-build/tests/macros_clippy.rs b/tests-build/tests/macros_clippy.rs new file mode 100644 index 00000000000..0f3f4bb0b8b --- /dev/null +++ b/tests-build/tests/macros_clippy.rs @@ -0,0 +1,7 @@ +#[cfg(feature = "full")] +#[tokio::test] +async fn test_with_semicolon_without_return_type() { + #![deny(clippy::semicolon_if_nothing_returned)] + + dbg!(0); +} diff --git a/tests-build/tests/pass/macros_main_loop.rs b/tests-build/tests/pass/macros_main_loop.rs new file mode 100644 index 00000000000..d7d51982c36 --- /dev/null +++ b/tests-build/tests/pass/macros_main_loop.rs @@ -0,0 +1,14 @@ +use tests_build::tokio; + +#[tokio::main] +async fn main() -> Result<(), ()> { + loop { + if !never() { + return Ok(()); + } + } +} + +fn never() -> bool { + std::time::Instant::now() > std::time::Instant::now() +} diff --git a/tests-build/tests/pass/macros_main_return.rs b/tests-build/tests/pass/macros_main_return.rs new file mode 100644 index 00000000000..d4d34ec26d3 --- /dev/null +++ b/tests-build/tests/pass/macros_main_return.rs @@ -0,0 +1,6 @@ +use tests_build::tokio; + +#[tokio::main] +async fn main() -> Result<(), ()> { + return Ok(()); +} diff --git a/tokio-macros/src/entry.rs b/tokio-macros/src/entry.rs index ddc19585d7b..01f8ee4c1eb 100644 --- 
a/tokio-macros/src/entry.rs +++ b/tokio-macros/src/entry.rs @@ -1,6 +1,10 @@ use proc_macro::TokenStream; use proc_macro2::Span; use quote::{quote, quote_spanned, ToTokens}; +use syn::parse::Parser; + +// syn::AttributeArgs does not implement syn::Parse +type AttributeArgs = syn::punctuated::Punctuated; #[derive(Clone, Copy, PartialEq)] enum RuntimeFlavor { @@ -27,6 +31,13 @@ struct FinalConfig { start_paused: Option, } +/// Config used in case of the attribute not being able to build a valid config +const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig { + flavor: RuntimeFlavor::CurrentThread, + worker_threads: None, + start_paused: None, +}; + struct Configuration { rt_multi_thread_available: bool, default_flavor: RuntimeFlavor, @@ -184,13 +195,13 @@ fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result Result { - if input.sig.asyncness.take().is_none() { +) -> Result { + if input.sig.asyncness.is_none() { let msg = "the `async` keyword is missing from the function declaration"; return Err(syn::Error::new_spanned(input.sig.fn_token, msg)); } @@ -201,12 +212,15 @@ fn parse_knobs( for arg in args { match arg { syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue)) => { - let ident = namevalue.path.get_ident(); - if ident.is_none() { - let msg = "Must have specified ident"; - return Err(syn::Error::new_spanned(namevalue, msg)); - } - match ident.unwrap().to_string().to_lowercase().as_str() { + let ident = namevalue + .path + .get_ident() + .ok_or_else(|| { + syn::Error::new_spanned(&namevalue, "Must have specified ident") + })? + .to_string() + .to_lowercase(); + match ident.as_str() { "worker_threads" => { config.set_worker_threads( namevalue.lit.clone(), @@ -239,12 +253,11 @@ fn parse_knobs( } } syn::NestedMeta::Meta(syn::Meta::Path(path)) => { - let ident = path.get_ident(); - if ident.is_none() { - let msg = "Must have specified ident"; - return Err(syn::Error::new_spanned(path, msg)); - } - let name = ident.unwrap().to_string().to_lowercase(); + let name = path + .get_ident() + .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))? + .to_string() + .to_lowercase(); let msg = match name.as_str() { "threaded_scheduler" | "multi_thread" => { format!( @@ -276,7 +289,11 @@ fn parse_knobs( } } - let config = config.build()?; + config.build() +} + +fn parse_knobs(mut input: syn::ItemFn, is_test: bool, config: FinalConfig) -> TokenStream { + input.sig.asyncness = None; // If type mismatch occurs, the current rustc points to the last statement. let (last_stmt_start_span, last_stmt_end_span) = { @@ -321,16 +338,32 @@ fn parse_knobs( let body = &input.block; let brace_token = input.block.brace_token; + let (tail_return, tail_semicolon) = match body.stmts.last() { + Some(syn::Stmt::Semi(expr, _)) => match expr { + syn::Expr::Return(_) => (quote! { return }, quote! { ; }), + _ => match &input.sig.output { + syn::ReturnType::Type(_, ty) if matches!(&**ty, syn::Type::Tuple(ty) if ty.elems.is_empty()) => + { + (quote! {}, quote! { ; }) // unit + } + syn::ReturnType::Default => (quote! {}, quote! { ; }), // unit + syn::ReturnType::Type(..) => (quote! {}, quote! {}), // ! or another + }, + }, + _ => (quote! {}, quote! {}), + }; input.block = syn::parse2(quote_spanned! 
{last_stmt_end_span=> { - #rt + let body = async #body; + #[allow(clippy::expect_used)] + #tail_return #rt .enable_all() .build() - .unwrap() - .block_on(async #body) + .expect("Failed building the Runtime") + .block_on(body)#tail_semicolon } }) - .unwrap(); + .expect("Parsing failure"); input.block.brace_token = brace_token; let result = quote! { @@ -338,36 +371,58 @@ fn parse_knobs( #input }; - Ok(result.into()) + result.into() +} + +fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream { + tokens.extend(TokenStream::from(error.into_compile_error())); + tokens } #[cfg(not(test))] // Work around for rust-lang/rust#62127 pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { - let input = syn::parse_macro_input!(item as syn::ItemFn); - let args = syn::parse_macro_input!(args as syn::AttributeArgs); + // If any of the steps for this macro fail, we still want to expand to an item that is as close + // to the expected output as possible. This helps out IDEs such that completions and other + // related features keep working. + let input: syn::ItemFn = match syn::parse(item.clone()) { + Ok(it) => it, + Err(e) => return token_stream_with_error(item, e), + }; - if input.sig.ident == "main" && !input.sig.inputs.is_empty() { + let config = if input.sig.ident == "main" && !input.sig.inputs.is_empty() { let msg = "the main function cannot accept arguments"; - return syn::Error::new_spanned(&input.sig.ident, msg) - .to_compile_error() - .into(); - } + Err(syn::Error::new_spanned(&input.sig.ident, msg)) + } else { + AttributeArgs::parse_terminated + .parse(args) + .and_then(|args| build_config(input.clone(), args, false, rt_multi_thread)) + }; - parse_knobs(input, args, false, rt_multi_thread).unwrap_or_else(|e| e.to_compile_error().into()) + match config { + Ok(config) => parse_knobs(input, false, config), + Err(e) => token_stream_with_error(parse_knobs(input, false, DEFAULT_ERROR_CONFIG), e), + } } pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { - let input = syn::parse_macro_input!(item as syn::ItemFn); - let args = syn::parse_macro_input!(args as syn::AttributeArgs); - - for attr in &input.attrs { - if attr.path.is_ident("test") { - let msg = "second test attribute is supplied"; - return syn::Error::new_spanned(&attr, msg) - .to_compile_error() - .into(); - } - } + // If any of the steps for this macro fail, we still want to expand to an item that is as close + // to the expected output as possible. This helps out IDEs such that completions and other + // related features keep working. 
+ let input: syn::ItemFn = match syn::parse(item.clone()) { + Ok(it) => it, + Err(e) => return token_stream_with_error(item, e), + }; + let config = if let Some(attr) = input.attrs.iter().find(|attr| attr.path.is_ident("test")) { + let msg = "second test attribute is supplied"; + Err(syn::Error::new_spanned(&attr, msg)) + } else { + AttributeArgs::parse_terminated + .parse(args) + .and_then(|args| build_config(input.clone(), args, true, rt_multi_thread)) + }; - parse_knobs(input, args, true, rt_multi_thread).unwrap_or_else(|e| e.to_compile_error().into()) + match config { + Ok(config) => parse_knobs(input, true, config), + Err(e) => token_stream_with_error(parse_knobs(input, true, DEFAULT_ERROR_CONFIG), e), + } } diff --git a/tokio/CHANGELOG.md b/tokio/CHANGELOG.md index b2d256c9ecb..806d440766b 100644 --- a/tokio/CHANGELOG.md +++ b/tokio/CHANGELOG.md @@ -1,3 +1,34 @@ +# 1.8.4 (November 15, 2021) + +This release backports a bug fix from 1.13.1. + +### Fixed + +- sync: fix a data race between `oneshot::Sender::send` and awaiting a + `oneshot::Receiver` when the oneshot has been closed ([#4226]) + +[#4226]: https://github.com/tokio-rs/tokio/pull/4226 + +# 1.8.3 (July 26, 2021) + +This release backports two fixes from 1.9.0 + +### Fixed + + - Fix leak if output of future panics on drop ([#3967]) + - Fix leak in `LocalSet` ([#3978]) + +[#3967]: https://github.com/tokio-rs/tokio/pull/3967 +[#3978]: https://github.com/tokio-rs/tokio/pull/3978 + +# 1.8.2 (July 19, 2021) + +Fixes a missed edge case from 1.8.1. + +### Fixed + +- runtime: drop canceled future on next poll (#3965) + # 1.8.1 (July 6, 2021) Forward ports 1.5.1 fixes. diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index 2945b6a387b..26f1a9ddfb2 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -7,12 +7,12 @@ name = "tokio" # - README.md # - Update CHANGELOG.md. # - Create "v1.0.x" git tag. -version = "1.8.1" +version = "1.8.4" edition = "2018" authors = ["Tokio Contributors "] license = "MIT" readme = "README.md" -documentation = "https://docs.rs/tokio/1.8.1/tokio/" +documentation = "https://docs.rs/tokio/1.8.4/tokio/" repository = "https://github.com/tokio-rs/tokio" homepage = "https://tokio.rs" description = """ @@ -109,7 +109,7 @@ signal-hook-registry = { version = "1.1.1", optional = true } [target.'cfg(unix)'.dev-dependencies] libc = { version = "0.2.42" } -nix = { version = "0.19.0" } +nix = { version = "0.22.0" } [target.'cfg(windows)'.dependencies.winapi] version = "0.3.8" @@ -123,6 +123,7 @@ version = "0.3.6" tokio-test = { version = "0.4.0", path = "../tokio-test" } tokio-stream = { version = "0.1", path = "../tokio-stream" } futures = { version = "0.3.0", features = ["async-await"] } +mockall = "0.10.2" proptest = "1" rand = "0.8.0" tempfile = "3.1.0" diff --git a/tokio/src/fs/file.rs b/tokio/src/fs/file.rs index 5c06e732b09..5286e6c5c5d 100644 --- a/tokio/src/fs/file.rs +++ b/tokio/src/fs/file.rs @@ -3,7 +3,7 @@ //! 
[`File`]: File use self::State::*; -use crate::fs::{asyncify, sys}; +use crate::fs::asyncify; use crate::io::blocking::Buf; use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use crate::sync::Mutex; @@ -19,6 +19,19 @@ use std::task::Context; use std::task::Poll; use std::task::Poll::*; +#[cfg(test)] +use super::mocks::spawn_blocking; +#[cfg(test)] +use super::mocks::JoinHandle; +#[cfg(test)] +use super::mocks::MockFile as StdFile; +#[cfg(not(test))] +use crate::blocking::spawn_blocking; +#[cfg(not(test))] +use crate::blocking::JoinHandle; +#[cfg(not(test))] +use std::fs::File as StdFile; + /// A reference to an open file on the filesystem. /// /// This is a specialized version of [`std::fs::File`][std] for usage from the @@ -78,7 +91,7 @@ use std::task::Poll::*; /// # } /// ``` pub struct File { - std: Arc, + std: Arc, inner: Mutex, } @@ -96,7 +109,7 @@ struct Inner { #[derive(Debug)] enum State { Idle(Option), - Busy(sys::Blocking<(Operation, Buf)>), + Busy(JoinHandle<(Operation, Buf)>), } #[derive(Debug)] @@ -142,7 +155,7 @@ impl File { /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt pub async fn open(path: impl AsRef) -> io::Result { let path = path.as_ref().to_owned(); - let std = asyncify(|| sys::File::open(path)).await?; + let std = asyncify(|| StdFile::open(path)).await?; Ok(File::from_std(std)) } @@ -182,7 +195,7 @@ impl File { /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt pub async fn create(path: impl AsRef) -> io::Result { let path = path.as_ref().to_owned(); - let std_file = asyncify(move || sys::File::create(path)).await?; + let std_file = asyncify(move || StdFile::create(path)).await?; Ok(File::from_std(std_file)) } @@ -199,7 +212,7 @@ impl File { /// let std_file = std::fs::File::open("foo.txt").unwrap(); /// let file = tokio::fs::File::from_std(std_file); /// ``` - pub fn from_std(std: sys::File) -> File { + pub fn from_std(std: StdFile) -> File { File { std: Arc::new(std), inner: Mutex::new(Inner { @@ -323,7 +336,7 @@ impl File { let std = self.std.clone(); - inner.state = Busy(sys::run(move || { + inner.state = Busy(spawn_blocking(move || { let res = if let Some(seek) = seek { (&*std).seek(seek).and_then(|_| std.set_len(size)) } else { @@ -409,7 +422,7 @@ impl File { /// # Ok(()) /// # } /// ``` - pub async fn into_std(mut self) -> sys::File { + pub async fn into_std(mut self) -> StdFile { self.inner.get_mut().complete_inflight().await; Arc::try_unwrap(self.std).expect("Arc::try_unwrap failed") } @@ -434,7 +447,7 @@ impl File { /// # Ok(()) /// # } /// ``` - pub fn try_into_std(mut self) -> Result { + pub fn try_into_std(mut self) -> Result { match Arc::try_unwrap(self.std) { Ok(file) => Ok(file), Err(std_file_arc) => { @@ -502,7 +515,7 @@ impl AsyncRead for File { buf.ensure_capacity_for(dst); let std = me.std.clone(); - inner.state = Busy(sys::run(move || { + inner.state = Busy(spawn_blocking(move || { let res = buf.read_from(&mut &*std); (Operation::Read(res), buf) })); @@ -569,7 +582,7 @@ impl AsyncSeek for File { let std = me.std.clone(); - inner.state = Busy(sys::run(move || { + inner.state = Busy(spawn_blocking(move || { let res = (&*std).seek(pos); (Operation::Seek(res), buf) })); @@ -636,7 +649,7 @@ impl AsyncWrite for File { let n = buf.copy_from(src); let std = me.std.clone(); - inner.state = Busy(sys::run(move || { + inner.state = Busy(spawn_blocking(move || { let res = if let Some(seek) = seek { (&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std)) } else { @@ -685,8 +698,8 @@ impl AsyncWrite for File { } } -impl From for File { - fn 
from(std: sys::File) -> Self { +impl From for File { + fn from(std: StdFile) -> Self { Self::from_std(std) } } @@ -709,7 +722,7 @@ impl std::os::unix::io::AsRawFd for File { #[cfg(unix)] impl std::os::unix::io::FromRawFd for File { unsafe fn from_raw_fd(fd: std::os::unix::io::RawFd) -> Self { - sys::File::from_raw_fd(fd).into() + StdFile::from_raw_fd(fd).into() } } @@ -723,7 +736,7 @@ impl std::os::windows::io::AsRawHandle for File { #[cfg(windows)] impl std::os::windows::io::FromRawHandle for File { unsafe fn from_raw_handle(handle: std::os::windows::io::RawHandle) -> Self { - sys::File::from_raw_handle(handle).into() + StdFile::from_raw_handle(handle).into() } } @@ -756,3 +769,6 @@ impl Inner { } } } + +#[cfg(test)] +mod tests; diff --git a/tokio/tests/fs_file_mocked.rs b/tokio/src/fs/file/tests.rs similarity index 55% rename from tokio/tests/fs_file_mocked.rs rename to tokio/src/fs/file/tests.rs index 77715327d8a..28b5ffe77af 100644 --- a/tokio/tests/fs_file_mocked.rs +++ b/tokio/src/fs/file/tests.rs @@ -1,80 +1,21 @@ -#![warn(rust_2018_idioms)] -#![cfg(feature = "full")] - -macro_rules! ready { - ($e:expr $(,)?) => { - match $e { - std::task::Poll::Ready(t) => t, - std::task::Poll::Pending => return std::task::Poll::Pending, - } - }; -} - -#[macro_export] -macro_rules! cfg_fs { - ($($item:item)*) => { $($item)* } -} - -#[macro_export] -macro_rules! cfg_io_std { - ($($item:item)*) => { $($item)* } -} - -use futures::future; - -// Load source -#[allow(warnings)] -#[path = "../src/fs/file.rs"] -mod file; -use file::File; - -#[allow(warnings)] -#[path = "../src/io/blocking.rs"] -mod blocking; - -// Load mocked types -mod support { - pub(crate) mod mock_file; - pub(crate) mod mock_pool; -} -pub(crate) use support::mock_pool as pool; - -// Place them where the source expects them -pub(crate) mod io { - pub(crate) use tokio::io::*; - - pub(crate) use crate::blocking; - - pub(crate) mod sys { - pub(crate) use crate::support::mock_pool::{run, Blocking}; - } -} -pub(crate) mod fs { - pub(crate) mod sys { - pub(crate) use crate::support::mock_file::File; - pub(crate) use crate::support::mock_pool::{run, Blocking}; - } - - pub(crate) use crate::support::mock_pool::asyncify; -} -pub(crate) mod sync { - pub(crate) use tokio::sync::Mutex; -} -use fs::sys; - -use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; -use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task}; - -use std::io::SeekFrom; +use super::*; +use crate::{ + fs::mocks::*, + io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}, +}; +use mockall::{predicate::eq, Sequence}; +use tokio_test::{assert_pending, assert_ready_err, assert_ready_ok, task}; const HELLO: &[u8] = b"hello world..."; const FOO: &[u8] = b"foo bar baz..."; #[test] fn open_read() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO); - + let mut file = MockFile::default(); + file.expect_inner_read().once().returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); let mut buf = [0; 1024]; @@ -83,12 +24,10 @@ fn open_read() { assert_eq!(0, pool::len()); assert_pending!(t.poll()); - assert_eq!(1, mock.remaining()); assert_eq!(1, pool::len()); pool::run_one(); - assert_eq!(0, mock.remaining()); assert!(t.is_woken()); let n = assert_ready_ok!(t.poll()); @@ -98,9 +37,11 @@ fn open_read() { #[test] fn read_twice_before_dispatch() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO); - + let mut file = MockFile::default(); + 
file.expect_inner_read().once().returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); let mut buf = [0; 1024]; @@ -120,8 +61,11 @@ fn read_twice_before_dispatch() { #[test] fn read_with_smaller_buf() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO); + let mut file = MockFile::default(); + file.expect_inner_read().once().returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); @@ -153,8 +97,22 @@ fn read_with_smaller_buf() { #[test] fn read_with_bigger_buf() { - let (mock, file) = sys::File::mock(); - mock.read(&HELLO[..4]).read(&HELLO[4..]); + let mut seq = Sequence::new(); + let mut file = MockFile::default(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..4].copy_from_slice(&HELLO[..4]); + Ok(4) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len() - 4].copy_from_slice(&HELLO[4..]); + Ok(HELLO.len() - 4) + }); let mut file = File::from_std(file); @@ -194,8 +152,19 @@ fn read_with_bigger_buf() { #[test] fn read_err_then_read_success() { - let (mock, file) = sys::File::mock(); - mock.read_err().read(&HELLO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); @@ -225,8 +194,11 @@ fn read_err_then_read_success() { #[test] fn open_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO); + let mut file = MockFile::default(); + file.expect_inner_write() + .once() + .with(eq(HELLO)) + .returning(|buf| Ok(buf.len())); let mut file = File::from_std(file); @@ -235,12 +207,10 @@ fn open_write() { assert_eq!(0, pool::len()); assert_ready_ok!(t.poll()); - assert_eq!(1, mock.remaining()); assert_eq!(1, pool::len()); pool::run_one(); - assert_eq!(0, mock.remaining()); assert!(!t.is_woken()); let mut t = task::spawn(file.flush()); @@ -249,7 +219,7 @@ fn open_write() { #[test] fn flush_while_idle() { - let (_mock, file) = sys::File::mock(); + let file = MockFile::default(); let mut file = File::from_std(file); @@ -271,13 +241,42 @@ fn read_with_buffer_larger_than_max() { for i in 0..(chunk_d - 1) { data.push((i % 151) as u8); } - - let (mock, file) = sys::File::mock(); - mock.read(&data[0..chunk_a]) - .read(&data[chunk_a..chunk_b]) - .read(&data[chunk_b..chunk_c]) - .read(&data[chunk_c..]); - + let data = Arc::new(data); + let d0 = data.clone(); + let d1 = data.clone(); + let d2 = data.clone(); + let d3 = data.clone(); + + let mut seq = Sequence::new(); + let mut file = MockFile::default(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[0..chunk_a].copy_from_slice(&d0[0..chunk_a]); + Ok(chunk_a) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[..chunk_a].copy_from_slice(&d1[chunk_a..chunk_b]); + Ok(chunk_b - chunk_a) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[..chunk_a].copy_from_slice(&d2[chunk_b..chunk_c]); + Ok(chunk_c - chunk_b) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[..chunk_a - 
1].copy_from_slice(&d3[chunk_c..]); + Ok(chunk_a - 1) + }); let mut file = File::from_std(file); let mut actual = vec![0; chunk_d]; @@ -296,8 +295,7 @@ fn read_with_buffer_larger_than_max() { pos += n; } - assert_eq!(mock.remaining(), 0); - assert_eq!(data, &actual[..data.len()]); + assert_eq!(&data[..], &actual[..data.len()]); } #[test] @@ -314,12 +312,34 @@ fn write_with_buffer_larger_than_max() { for i in 0..(chunk_d - 1) { data.push((i % 151) as u8); } - - let (mock, file) = sys::File::mock(); - mock.write(&data[0..chunk_a]) - .write(&data[chunk_a..chunk_b]) - .write(&data[chunk_b..chunk_c]) - .write(&data[chunk_c..]); + let data = Arc::new(data); + let d0 = data.clone(); + let d1 = data.clone(); + let d2 = data.clone(); + let d3 = data.clone(); + + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d0[0..chunk_a]) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d1[chunk_a..chunk_b]) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d2[chunk_b..chunk_c]) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d3[chunk_c..chunk_d - 1]) + .returning(|buf| Ok(buf.len())); let mut file = File::from_std(file); @@ -344,14 +364,22 @@ fn write_with_buffer_larger_than_max() { } pool::run_one(); - - assert_eq!(mock.remaining(), 0); } #[test] fn write_twice_before_dispatch() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|buf| Ok(buf.len())); let mut file = File::from_std(file); @@ -380,10 +408,24 @@ fn write_twice_before_dispatch() { #[test] fn incomplete_read_followed_by_write() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO) - .seek_current_ok(-(HELLO.len() as i64), 0) - .write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Current(-(HELLO.len() as i64)))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); + file.expect_inner_write() + .once() + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); let mut file = File::from_std(file); @@ -406,8 +448,25 @@ fn incomplete_read_followed_by_write() { #[test] fn incomplete_partial_read_followed_by_write() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO).seek_current_ok(-10, 0).write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .in_sequence(&mut seq) + .with(eq(SeekFrom::Current(-10))) + .returning(|_| Ok(0)); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); let mut file = File::from_std(file); @@ -433,10 +492,25 @@ fn incomplete_partial_read_followed_by_write() { #[test] fn 
incomplete_read_followed_by_flush() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO) - .seek_current_ok(-(HELLO.len() as i64), 0) - .write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .in_sequence(&mut seq) + .with(eq(SeekFrom::Current(-(HELLO.len() as i64)))) + .returning(|_| Ok(0)); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); let mut file = File::from_std(file); @@ -458,8 +532,18 @@ fn incomplete_read_followed_by_flush() { #[test] fn incomplete_flush_followed_by_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); let mut file = File::from_std(file); @@ -484,8 +568,10 @@ fn incomplete_flush_followed_by_write() { #[test] fn read_err() { - let (mock, file) = sys::File::mock(); - mock.read_err(); + let mut file = MockFile::default(); + file.expect_inner_read() + .once() + .returning(|_| Err(io::ErrorKind::Other.into())); let mut file = File::from_std(file); @@ -502,8 +588,10 @@ fn read_err() { #[test] fn write_write_err() { - let (mock, file) = sys::File::mock(); - mock.write_err(); + let mut file = MockFile::default(); + file.expect_inner_write() + .once() + .returning(|_| Err(io::ErrorKind::Other.into())); let mut file = File::from_std(file); @@ -518,8 +606,19 @@ fn write_write_err() { #[test] fn write_read_write_err() { - let (mock, file) = sys::File::mock(); - mock.write_err().read(HELLO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); @@ -541,8 +640,19 @@ fn write_read_write_err() { #[test] fn write_read_flush_err() { - let (mock, file) = sys::File::mock(); - mock.write_err().read(HELLO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); @@ -564,8 +674,17 @@ fn write_read_flush_err() { #[test] fn write_seek_write_err() { - let (mock, file) = sys::File::mock(); - mock.write_err().seek_start_ok(0); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Start(0))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); let mut file = File::from_std(file); @@ -587,8 +706,17 @@ fn write_seek_write_err() { #[test] fn write_seek_flush_err() { - let (mock, file) = sys::File::mock(); - mock.write_err().seek_start_ok(0); + let mut file = 
MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Start(0))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); let mut file = File::from_std(file); @@ -610,8 +738,14 @@ fn write_seek_flush_err() { #[test] fn sync_all_ordered_after_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).sync_all(); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_all().once().returning(|| Ok(())); let mut file = File::from_std(file); let mut t = task::spawn(file.write(HELLO)); @@ -635,8 +769,16 @@ fn sync_all_ordered_after_write() { #[test] fn sync_all_err_ordered_after_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).sync_all_err(); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_all() + .once() + .returning(|| Err(io::ErrorKind::Other.into())); let mut file = File::from_std(file); let mut t = task::spawn(file.write(HELLO)); @@ -660,8 +802,14 @@ fn sync_all_err_ordered_after_write() { #[test] fn sync_data_ordered_after_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).sync_data(); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_data().once().returning(|| Ok(())); let mut file = File::from_std(file); let mut t = task::spawn(file.write(HELLO)); @@ -685,8 +833,16 @@ fn sync_data_ordered_after_write() { #[test] fn sync_data_err_ordered_after_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).sync_data_err(); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_data() + .once() + .returning(|| Err(io::ErrorKind::Other.into())); let mut file = File::from_std(file); let mut t = task::spawn(file.write(HELLO)); @@ -710,17 +866,15 @@ fn sync_data_err_ordered_after_write() { #[test] fn open_set_len_ok() { - let (mock, file) = sys::File::mock(); - mock.set_len(123); + let mut file = MockFile::default(); + file.expect_set_len().with(eq(123)).returning(|_| Ok(())); let file = File::from_std(file); let mut t = task::spawn(file.set_len(123)); assert_pending!(t.poll()); - assert_eq!(1, mock.remaining()); pool::run_one(); - assert_eq!(0, mock.remaining()); assert!(t.is_woken()); assert_ready_ok!(t.poll()); @@ -728,17 +882,17 @@ fn open_set_len_ok() { #[test] fn open_set_len_err() { - let (mock, file) = sys::File::mock(); - mock.set_len_err(123); + let mut file = MockFile::default(); + file.expect_set_len() + .with(eq(123)) + .returning(|_| Err(io::ErrorKind::Other.into())); let file = File::from_std(file); let mut t = task::spawn(file.set_len(123)); assert_pending!(t.poll()); - assert_eq!(1, mock.remaining()); pool::run_one(); - assert_eq!(0, mock.remaining()); assert!(t.is_woken()); assert_ready_err!(t.poll()); @@ -746,11 +900,32 @@ fn open_set_len_err() { #[test] fn partial_read_set_len_ok() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO) - 
.seek_current_ok(-14, 0) - .set_len(123) - .read(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Current(-(HELLO.len() as i64)))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); + file.expect_set_len() + .once() + .in_sequence(&mut seq) + .with(eq(123)) + .returning(|_| Ok(())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..FOO.len()].copy_from_slice(FOO); + Ok(FOO.len()) + }); let mut buf = [0; 32]; let mut file = File::from_std(file); diff --git a/tokio/src/fs/mocks.rs b/tokio/src/fs/mocks.rs new file mode 100644 index 00000000000..68ef4f3a7a4 --- /dev/null +++ b/tokio/src/fs/mocks.rs @@ -0,0 +1,136 @@ +//! Mock version of std::fs::File; +use mockall::mock; + +use crate::sync::oneshot; +use std::{ + cell::RefCell, + collections::VecDeque, + fs::{Metadata, Permissions}, + future::Future, + io::{self, Read, Seek, SeekFrom, Write}, + path::PathBuf, + pin::Pin, + task::{Context, Poll}, +}; + +mock! { + #[derive(Debug)] + pub File { + pub fn create(pb: PathBuf) -> io::Result; + // These inner_ methods exist because std::fs::File has two + // implementations for each of these methods: one on "&mut self" and + // one on "&&self". Defining both of those in terms of an inner_ method + // allows us to specify the expectation the same way, regardless of + // which method is used. + pub fn inner_flush(&self) -> io::Result<()>; + pub fn inner_read(&self, dst: &mut [u8]) -> io::Result; + pub fn inner_seek(&self, pos: SeekFrom) -> io::Result; + pub fn inner_write(&self, src: &[u8]) -> io::Result; + pub fn metadata(&self) -> io::Result; + pub fn open(pb: PathBuf) -> io::Result; + pub fn set_len(&self, size: u64) -> io::Result<()>; + pub fn set_permissions(&self, _perm: Permissions) -> io::Result<()>; + pub fn sync_all(&self) -> io::Result<()>; + pub fn sync_data(&self) -> io::Result<()>; + pub fn try_clone(&self) -> io::Result; + } + #[cfg(windows)] + impl std::os::windows::io::AsRawHandle for File { + fn as_raw_handle(&self) -> std::os::windows::io::RawHandle; + } + #[cfg(windows)] + impl std::os::windows::io::FromRawHandle for File { + unsafe fn from_raw_handle(h: std::os::windows::io::RawHandle) -> Self; + } + #[cfg(unix)] + impl std::os::unix::io::AsRawFd for File { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd; + } + + #[cfg(unix)] + impl std::os::unix::io::FromRawFd for File { + unsafe fn from_raw_fd(h: std::os::unix::io::RawFd) -> Self; + } +} + +impl Read for MockFile { + fn read(&mut self, dst: &mut [u8]) -> io::Result { + self.inner_read(dst) + } +} + +impl Read for &'_ MockFile { + fn read(&mut self, dst: &mut [u8]) -> io::Result { + self.inner_read(dst) + } +} + +impl Seek for &'_ MockFile { + fn seek(&mut self, pos: SeekFrom) -> io::Result { + self.inner_seek(pos) + } +} + +impl Write for &'_ MockFile { + fn write(&mut self, src: &[u8]) -> io::Result { + self.inner_write(src) + } + + fn flush(&mut self) -> io::Result<()> { + self.inner_flush() + } +} + +thread_local! 
{ + static QUEUE: RefCell>> = RefCell::new(VecDeque::new()) +} + +#[derive(Debug)] +pub(super) struct JoinHandle { + rx: oneshot::Receiver, +} + +pub(super) fn spawn_blocking(f: F) -> JoinHandle +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + let (tx, rx) = oneshot::channel(); + let task = Box::new(move || { + let _ = tx.send(f()); + }); + + QUEUE.with(|cell| cell.borrow_mut().push_back(task)); + + JoinHandle { rx } +} + +impl Future for JoinHandle { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + use std::task::Poll::*; + + match Pin::new(&mut self.rx).poll(cx) { + Ready(Ok(v)) => Ready(Ok(v)), + Ready(Err(e)) => panic!("error = {:?}", e), + Pending => Pending, + } + } +} + +pub(super) mod pool { + use super::*; + + pub(in super::super) fn len() -> usize { + QUEUE.with(|cell| cell.borrow().len()) + } + + pub(in super::super) fn run_one() { + let task = QUEUE + .with(|cell| cell.borrow_mut().pop_front()) + .expect("expected task to run, but none ready"); + + task(); + } +} diff --git a/tokio/src/fs/mod.rs b/tokio/src/fs/mod.rs index d4f00749028..ca0264b3678 100644 --- a/tokio/src/fs/mod.rs +++ b/tokio/src/fs/mod.rs @@ -84,6 +84,9 @@ pub use self::write::write; mod copy; pub use self::copy::copy; +#[cfg(test)] +mod mocks; + feature! { #![unix] @@ -103,12 +106,17 @@ feature! { use std::io; +#[cfg(not(test))] +use crate::blocking::spawn_blocking; +#[cfg(test)] +use mocks::spawn_blocking; + pub(crate) async fn asyncify(f: F) -> io::Result where F: FnOnce() -> io::Result + Send + 'static, T: Send + 'static, { - match sys::run(f).await { + match spawn_blocking(f).await { Ok(res) => res, Err(_) => Err(io::Error::new( io::ErrorKind::Other, @@ -116,12 +124,3 @@ where )), } } - -/// Types in this module can be mocked out in tests. -mod sys { - pub(crate) use std::fs::File; - - // TODO: don't rename - pub(crate) use crate::blocking::spawn_blocking as run; - pub(crate) use crate::blocking::JoinHandle as Blocking; -} diff --git a/tokio/src/fs/open_options.rs b/tokio/src/fs/open_options.rs index fa37a60dff6..3e73529ecf6 100644 --- a/tokio/src/fs/open_options.rs +++ b/tokio/src/fs/open_options.rs @@ -3,6 +3,13 @@ use crate::fs::{asyncify, File}; use std::io; use std::path::Path; +#[cfg(test)] +mod mock_open_options; +#[cfg(test)] +use mock_open_options::MockOpenOptions as StdOpenOptions; +#[cfg(not(test))] +use std::fs::OpenOptions as StdOpenOptions; + /// Options and flags which can be used to configure how a file is opened. /// /// This builder exposes the ability to configure how a [`File`] is opened and @@ -69,7 +76,7 @@ use std::path::Path; /// } /// ``` #[derive(Clone, Debug)] -pub struct OpenOptions(std::fs::OpenOptions); +pub struct OpenOptions(StdOpenOptions); impl OpenOptions { /// Creates a blank new set of options ready for configuration. @@ -89,7 +96,7 @@ impl OpenOptions { /// let future = options.read(true).open("foo.txt"); /// ``` pub fn new() -> OpenOptions { - OpenOptions(std::fs::OpenOptions::new()) + OpenOptions(StdOpenOptions::new()) } /// Sets the option for read access. @@ -384,7 +391,7 @@ impl OpenOptions { } /// Returns a mutable reference to the underlying `std::fs::OpenOptions` - pub(super) fn as_inner_mut(&mut self) -> &mut std::fs::OpenOptions { + pub(super) fn as_inner_mut(&mut self) -> &mut StdOpenOptions { &mut self.0 } } @@ -645,8 +652,8 @@ feature! 
{ } } -impl From for OpenOptions { - fn from(options: std::fs::OpenOptions) -> OpenOptions { +impl From for OpenOptions { + fn from(options: StdOpenOptions) -> OpenOptions { OpenOptions(options) } } diff --git a/tokio/src/fs/open_options/mock_open_options.rs b/tokio/src/fs/open_options/mock_open_options.rs new file mode 100644 index 00000000000..cbbda0ec256 --- /dev/null +++ b/tokio/src/fs/open_options/mock_open_options.rs @@ -0,0 +1,38 @@ +//! Mock version of std::fs::OpenOptions; +use mockall::mock; + +use crate::fs::mocks::MockFile; +#[cfg(unix)] +use std::os::unix::fs::OpenOptionsExt; +#[cfg(windows)] +use std::os::windows::fs::OpenOptionsExt; +use std::{io, path::Path}; + +mock! { + #[derive(Debug)] + pub OpenOptions { + pub fn append(&mut self, append: bool) -> &mut Self; + pub fn create(&mut self, create: bool) -> &mut Self; + pub fn create_new(&mut self, create_new: bool) -> &mut Self; + pub fn open + 'static>(&self, path: P) -> io::Result; + pub fn read(&mut self, read: bool) -> &mut Self; + pub fn truncate(&mut self, truncate: bool) -> &mut Self; + pub fn write(&mut self, write: bool) -> &mut Self; + } + impl Clone for OpenOptions { + fn clone(&self) -> Self; + } + #[cfg(unix)] + impl OpenOptionsExt for OpenOptions { + fn custom_flags(&mut self, flags: i32) -> &mut Self; + fn mode(&mut self, mode: u32) -> &mut Self; + } + #[cfg(windows)] + impl OpenOptionsExt for OpenOptions { + fn access_mode(&mut self, access: u32) -> &mut Self; + fn share_mode(&mut self, val: u32) -> &mut Self; + fn custom_flags(&mut self, flags: u32) -> &mut Self; + fn attributes(&mut self, val: u32) -> &mut Self; + fn security_qos_flags(&mut self, flags: u32) -> &mut Self; + } +} diff --git a/tokio/src/fs/read_dir.rs b/tokio/src/fs/read_dir.rs index aedaf7b921e..c1cb665dee3 100644 --- a/tokio/src/fs/read_dir.rs +++ b/tokio/src/fs/read_dir.rs @@ -1,4 +1,4 @@ -use crate::fs::{asyncify, sys}; +use crate::fs::asyncify; use std::ffi::OsString; use std::fs::{FileType, Metadata}; @@ -10,6 +10,15 @@ use std::sync::Arc; use std::task::Context; use std::task::Poll; +#[cfg(test)] +use super::mocks::spawn_blocking; +#[cfg(test)] +use super::mocks::JoinHandle; +#[cfg(not(test))] +use crate::blocking::spawn_blocking; +#[cfg(not(test))] +use crate::blocking::JoinHandle; + /// Returns a stream over the entries within a directory. 
/// /// This is an async version of [`std::fs::read_dir`](std::fs::read_dir) @@ -45,7 +54,7 @@ pub struct ReadDir(State); #[derive(Debug)] enum State { Idle(Option), - Pending(sys::Blocking<(Option>, std::fs::ReadDir)>), + Pending(JoinHandle<(Option>, std::fs::ReadDir)>), } impl ReadDir { @@ -79,7 +88,7 @@ impl ReadDir { State::Idle(ref mut std) => { let mut std = std.take().unwrap(); - self.0 = State::Pending(sys::run(move || { + self.0 = State::Pending(spawn_blocking(move || { let ret = std.next(); (ret, std) })); diff --git a/tokio/src/loom/std/atomic_u64.rs b/tokio/src/loom/std/atomic_u64.rs index a86a195b1d2..7eb457a2405 100644 --- a/tokio/src/loom/std/atomic_u64.rs +++ b/tokio/src/loom/std/atomic_u64.rs @@ -15,8 +15,8 @@ mod imp { #[cfg(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc"))] mod imp { + use crate::loom::sync::Mutex; use std::sync::atomic::Ordering; - use std::sync::Mutex; #[derive(Debug)] pub(crate) struct AtomicU64 { @@ -31,15 +31,15 @@ mod imp { } pub(crate) fn load(&self, _: Ordering) -> u64 { - *self.inner.lock().unwrap() + *self.inner.lock() } pub(crate) fn store(&self, val: u64, _: Ordering) { - *self.inner.lock().unwrap() = val; + *self.inner.lock() = val; } pub(crate) fn fetch_or(&self, val: u64, _: Ordering) -> u64 { - let mut lock = self.inner.lock().unwrap(); + let mut lock = self.inner.lock(); let prev = *lock; *lock = prev | val; prev @@ -52,7 +52,7 @@ mod imp { _success: Ordering, _failure: Ordering, ) -> Result { - let mut lock = self.inner.lock().unwrap(); + let mut lock = self.inner.lock(); if *lock == current { *lock = new; diff --git a/tokio/src/runtime/shell.rs b/tokio/src/runtime/shell.rs deleted file mode 100644 index 486d4fa5bbe..00000000000 --- a/tokio/src/runtime/shell.rs +++ /dev/null @@ -1,132 +0,0 @@ -#![allow(clippy::redundant_clone)] - -use crate::future::poll_fn; -use crate::park::{Park, Unpark}; -use crate::runtime::driver::Driver; -use crate::sync::Notify; -use crate::util::{waker_ref, Wake}; - -use std::sync::{Arc, Mutex}; -use std::task::Context; -use std::task::Poll::{Pending, Ready}; -use std::{future::Future, sync::PoisonError}; - -#[derive(Debug)] -pub(super) struct Shell { - driver: Mutex>, - - notify: Notify, - - /// TODO: don't store this - unpark: Arc, -} - -#[derive(Debug)] -struct Handle(::Unpark); - -impl Shell { - pub(super) fn new(driver: Driver) -> Shell { - let unpark = Arc::new(Handle(driver.unpark())); - - Shell { - driver: Mutex::new(Some(driver)), - notify: Notify::new(), - unpark, - } - } - - pub(super) fn block_on(&self, f: F) -> F::Output - where - F: Future, - { - let mut enter = crate::runtime::enter(true); - - pin!(f); - - loop { - if let Some(driver) = &mut self.take_driver() { - return driver.block_on(f); - } else { - let notified = self.notify.notified(); - pin!(notified); - - if let Some(out) = enter - .block_on(poll_fn(|cx| { - if notified.as_mut().poll(cx).is_ready() { - return Ready(None); - } - - if let Ready(out) = f.as_mut().poll(cx) { - return Ready(Some(out)); - } - - Pending - })) - .expect("Failed to `Enter::block_on`") - { - return out; - } - } - } - } - - fn take_driver(&self) -> Option> { - let mut lock = self.driver.lock().unwrap(); - let driver = lock.take()?; - - Some(DriverGuard { - inner: Some(driver), - shell: &self, - }) - } -} - -impl Wake for Handle { - /// Wake by value - fn wake(self: Arc) { - Wake::wake_by_ref(&self); - } - - /// Wake by reference - fn wake_by_ref(arc_self: &Arc) { - arc_self.0.unpark(); - } -} - -struct DriverGuard<'a> { - inner: Option, - 
shell: &'a Shell, -} - -impl DriverGuard<'_> { - fn block_on(&mut self, f: F) -> F::Output { - let driver = self.inner.as_mut().unwrap(); - - pin!(f); - - let waker = waker_ref(&self.shell.unpark); - let mut cx = Context::from_waker(&waker); - - loop { - if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) { - return v; - } - - driver.park().unwrap(); - } - } -} - -impl Drop for DriverGuard<'_> { - fn drop(&mut self) { - if let Some(inner) = self.inner.take() { - self.shell - .driver - .lock() - .unwrap_or_else(PoisonError::into_inner) - .replace(inner); - - self.shell.notify.notify_one(); - } - } -} diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 7f1c4e4cb0c..8cd649dc7f5 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -112,6 +112,8 @@ where } pub(super) fn drop_join_handle_slow(self) { + let mut maybe_panic = None; + // Try to unset `JOIN_INTEREST`. This must be done as a first step in // case the task concurrently completed. if self.header().state.unset_join_interested().is_err() { @@ -120,11 +122,20 @@ where // the scheduler or `JoinHandle`. i.e. if the output remains in the // task structure until the task is deallocated, it may be dropped // by a Waker on any arbitrary thread. - self.core().stage.drop_future_or_output(); + let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + self.core().stage.drop_future_or_output(); + })); + if let Err(panic) = panic { + maybe_panic = Some(panic); + } } // Drop the `JoinHandle` reference, possibly deallocating the task self.drop_reference(); + + if let Some(panic) = maybe_panic { + panic::resume_unwind(panic); + } } // ===== waker behavior ===== @@ -183,17 +194,25 @@ where // ====== internal ====== fn complete(self, output: super::Result, is_join_interested: bool) { - if is_join_interested { - // Store the output. The future has already been dropped - // - // Safety: Mutual exclusion is obtained by having transitioned the task - // state -> Running - let stage = &self.core().stage; - stage.store_output(output); - - // Transition to `Complete`, notifying the `JoinHandle` if necessary. - transition_to_complete(self.header(), stage, &self.trailer()); - } + // We catch panics here because dropping the output may panic. + // + // Dropping the output can also happen in the first branch inside + // transition_to_complete. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + if is_join_interested { + // Store the output. The future has already been dropped + // + // Safety: Mutual exclusion is obtained by having transitioned the task + // state -> Running + let stage = &self.core().stage; + stage.store_output(output); + + // Transition to `Complete`, notifying the `JoinHandle` if necessary. + transition_to_complete(self.header(), stage, &self.trailer()); + } else { + drop(output); + } + })); // The task has completed execution and will no longer be scheduled. 
// @@ -420,7 +439,7 @@ fn poll_future( cx: Context<'_>, ) -> PollFuture { if snapshot.is_cancelled() { - PollFuture::Complete(Err(JoinError::cancelled()), snapshot.is_join_interested()) + PollFuture::Complete(Err(cancel_task(core)), snapshot.is_join_interested()) } else { let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { struct Guard<'a, T: Future> { diff --git a/tokio/src/runtime/tests/loom_local.rs b/tokio/src/runtime/tests/loom_local.rs new file mode 100644 index 00000000000..d9a07a45f05 --- /dev/null +++ b/tokio/src/runtime/tests/loom_local.rs @@ -0,0 +1,47 @@ +use crate::runtime::tests::loom_oneshot as oneshot; +use crate::runtime::Builder; +use crate::task::LocalSet; + +use std::task::Poll; + +/// Waking a runtime will attempt to push a task into a queue of notifications +/// in the runtime, however the tasks in such a queue usually have a reference +/// to the runtime itself. This means that if they are not properly removed at +/// runtime shutdown, this will cause a memory leak. +/// +/// This test verifies that waking something during shutdown of a LocalSet does +/// not result in tasks lingering in the queue once shutdown is complete. This +/// is verified using loom's leak finder. +#[test] +fn wake_during_shutdown() { + loom::model(|| { + let rt = Builder::new_current_thread().build().unwrap(); + let ls = LocalSet::new(); + + let (send, recv) = oneshot::channel(); + + ls.spawn_local(async move { + let mut send = Some(send); + + let () = futures::future::poll_fn(|cx| { + if let Some(send) = send.take() { + send.send(cx.waker().clone()); + } + + Poll::Pending + }) + .await; + }); + + let handle = loom::thread::spawn(move || { + let waker = recv.recv(); + waker.wake(); + }); + + ls.block_on(&rt, crate::task::yield_now()); + + drop(ls); + handle.join().unwrap(); + drop(rt); + }); +} diff --git a/tokio/src/runtime/tests/loom_oneshot.rs b/tokio/src/runtime/tests/loom_oneshot.rs index c126fe479af..87eb6386425 100644 --- a/tokio/src/runtime/tests/loom_oneshot.rs +++ b/tokio/src/runtime/tests/loom_oneshot.rs @@ -1,7 +1,6 @@ +use crate::loom::sync::{Arc, Mutex}; use loom::sync::Notify; -use std::sync::{Arc, Mutex}; - pub(crate) fn channel() -> (Sender, Receiver) { let inner = Arc::new(Inner { notify: Notify::new(), @@ -31,7 +30,7 @@ struct Inner { impl Sender { pub(crate) fn send(self, value: T) { - *self.inner.value.lock().unwrap() = Some(value); + *self.inner.value.lock() = Some(value); self.inner.notify.notify(); } } @@ -39,7 +38,7 @@ impl Sender { impl Receiver { pub(crate) fn recv(self) -> T { loop { - if let Some(v) = self.inner.value.lock().unwrap().take() { + if let Some(v) = self.inner.value.lock().take() { return v; } diff --git a/tokio/src/runtime/tests/mod.rs b/tokio/src/runtime/tests/mod.rs index 596e47dfd00..3f2cc9825e8 100644 --- a/tokio/src/runtime/tests/mod.rs +++ b/tokio/src/runtime/tests/mod.rs @@ -21,6 +21,7 @@ mod joinable_wrapper { cfg_loom! { mod loom_basic_scheduler; + mod loom_local; mod loom_blocking; mod loom_oneshot; mod loom_pool; @@ -31,6 +32,9 @@ cfg_loom! { cfg_not_loom! 
{ mod queue; + #[cfg(not(miri))] + mod task_combinations; + #[cfg(miri)] mod task; } diff --git a/tokio/src/runtime/tests/task_combinations.rs b/tokio/src/runtime/tests/task_combinations.rs new file mode 100644 index 00000000000..76ce2330c2c --- /dev/null +++ b/tokio/src/runtime/tests/task_combinations.rs @@ -0,0 +1,380 @@ +use std::future::Future; +use std::panic; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::runtime::Builder; +use crate::sync::oneshot; +use crate::task::JoinHandle; + +use futures::future::FutureExt; + +// Enums for each option in the combinations being tested + +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiRuntime { + CurrentThread, + Multi1, + Multi2, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiLocalSet { + Yes, + No, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiTask { + PanicOnRun, + PanicOnDrop, + PanicOnRunAndDrop, + NoPanic, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiOutput { + PanicOnDrop, + NoPanic, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiJoinInterest { + Polled, + NotPolled, +} +#[allow(clippy::enum_variant_names)] // we aren't using glob imports +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiJoinHandle { + DropImmediately = 1, + DropFirstPoll = 2, + DropAfterNoConsume = 3, + DropAfterConsume = 4, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiAbort { + NotAborted = 0, + AbortedImmediately = 1, + AbortedFirstPoll = 2, + AbortedAfterFinish = 3, + AbortedAfterConsumeOutput = 4, +} + +#[test] +fn test_combinations() { + let mut rt = &[ + CombiRuntime::CurrentThread, + CombiRuntime::Multi1, + CombiRuntime::Multi2, + ][..]; + + if cfg!(miri) { + rt = &[CombiRuntime::CurrentThread]; + } + + let ls = [CombiLocalSet::Yes, CombiLocalSet::No]; + let task = [ + CombiTask::NoPanic, + CombiTask::PanicOnRun, + CombiTask::PanicOnDrop, + CombiTask::PanicOnRunAndDrop, + ]; + let output = [CombiOutput::NoPanic, CombiOutput::PanicOnDrop]; + let ji = [CombiJoinInterest::Polled, CombiJoinInterest::NotPolled]; + let jh = [ + CombiJoinHandle::DropImmediately, + CombiJoinHandle::DropFirstPoll, + CombiJoinHandle::DropAfterNoConsume, + CombiJoinHandle::DropAfterConsume, + ]; + let abort = [ + CombiAbort::NotAborted, + CombiAbort::AbortedImmediately, + CombiAbort::AbortedFirstPoll, + CombiAbort::AbortedAfterFinish, + CombiAbort::AbortedAfterConsumeOutput, + ]; + + for rt in rt.iter().copied() { + for ls in ls.iter().copied() { + for task in task.iter().copied() { + for output in output.iter().copied() { + for ji in ji.iter().copied() { + for jh in jh.iter().copied() { + for abort in abort.iter().copied() { + test_combination(rt, ls, task, output, ji, jh, abort); + } + } + } + } + } + } + } +} + +fn test_combination( + rt: CombiRuntime, + ls: CombiLocalSet, + task: CombiTask, + output: CombiOutput, + ji: CombiJoinInterest, + jh: CombiJoinHandle, + abort: CombiAbort, +) { + if (jh as usize) < (abort as usize) { + // drop before abort not possible + return; + } + if (task == CombiTask::PanicOnDrop) && (output == CombiOutput::PanicOnDrop) { + // this causes double panic + return; + } + if (task == CombiTask::PanicOnRunAndDrop) && (abort != CombiAbort::AbortedImmediately) { + // this causes double panic + return; + } + + println!("Runtime {:?}, LocalSet {:?}, Task {:?}, Output {:?}, JoinInterest {:?}, JoinHandle {:?}, Abort {:?}", rt, ls, task, output, ji, jh, abort); + + // A runtime optionally with a LocalSet + struct Rt { + rt: crate::runtime::Runtime, + ls: Option, + } + impl Rt { + fn new(rt: 
CombiRuntime, ls: CombiLocalSet) -> Self { + let rt = match rt { + CombiRuntime::CurrentThread => Builder::new_current_thread().build().unwrap(), + CombiRuntime::Multi1 => Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(), + CombiRuntime::Multi2 => Builder::new_multi_thread() + .worker_threads(2) + .build() + .unwrap(), + }; + + let ls = match ls { + CombiLocalSet::Yes => Some(crate::task::LocalSet::new()), + CombiLocalSet::No => None, + }; + + Self { rt, ls } + } + fn block_on(&self, task: T) -> T::Output + where + T: Future, + { + match &self.ls { + Some(ls) => ls.block_on(&self.rt, task), + None => self.rt.block_on(task), + } + } + fn spawn(&self, task: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + match &self.ls { + Some(ls) => ls.spawn_local(task), + None => self.rt.spawn(task), + } + } + } + + // The type used for the output of the future + struct Output { + panic_on_drop: bool, + on_drop: Option>, + } + impl Output { + fn disarm(&mut self) { + self.panic_on_drop = false; + } + } + impl Drop for Output { + fn drop(&mut self) { + let _ = self.on_drop.take().unwrap().send(()); + if self.panic_on_drop { + panic!("Panicking in Output"); + } + } + } + + // A wrapper around the future that is spawned + struct FutWrapper { + inner: F, + on_drop: Option>, + panic_on_drop: bool, + } + impl Future for FutWrapper { + type Output = F::Output; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + unsafe { + let me = Pin::into_inner_unchecked(self); + let inner = Pin::new_unchecked(&mut me.inner); + inner.poll(cx) + } + } + } + impl Drop for FutWrapper { + fn drop(&mut self) { + let _: Result<(), ()> = self.on_drop.take().unwrap().send(()); + if self.panic_on_drop { + panic!("Panicking in FutWrapper"); + } + } + } + + // The channels passed to the task + struct Signals { + on_first_poll: Option>, + wait_complete: Option>, + on_output_drop: Option>, + } + + // The task we will spawn + async fn my_task(mut signal: Signals, task: CombiTask, out: CombiOutput) -> Output { + // Signal that we have been polled once + let _ = signal.on_first_poll.take().unwrap().send(()); + + // Wait for a signal, then complete the future + let _ = signal.wait_complete.take().unwrap().await; + + // If the task gets past wait_complete without yielding, then aborts + // may not be caught without this yield_now. 
+ crate::task::yield_now().await; + + if task == CombiTask::PanicOnRun || task == CombiTask::PanicOnRunAndDrop { + panic!("Panicking in my_task on {:?}", std::thread::current().id()); + } + + Output { + panic_on_drop: out == CombiOutput::PanicOnDrop, + on_drop: signal.on_output_drop.take(), + } + } + + let rt = Rt::new(rt, ls); + + let (on_first_poll, wait_first_poll) = oneshot::channel(); + let (on_complete, wait_complete) = oneshot::channel(); + let (on_future_drop, wait_future_drop) = oneshot::channel(); + let (on_output_drop, wait_output_drop) = oneshot::channel(); + let signal = Signals { + on_first_poll: Some(on_first_poll), + wait_complete: Some(wait_complete), + on_output_drop: Some(on_output_drop), + }; + + // === Spawn task === + let mut handle = Some(rt.spawn(FutWrapper { + inner: my_task(signal, task, output), + on_drop: Some(on_future_drop), + panic_on_drop: task == CombiTask::PanicOnDrop || task == CombiTask::PanicOnRunAndDrop, + })); + + // Keep track of whether the task has been killed with an abort + let mut aborted = false; + + // If we want to poll the JoinHandle, do it now + if ji == CombiJoinInterest::Polled { + assert!( + handle.as_mut().unwrap().now_or_never().is_none(), + "Polling handle succeeded" + ); + } + + if abort == CombiAbort::AbortedImmediately { + handle.as_mut().unwrap().abort(); + aborted = true; + } + if jh == CombiJoinHandle::DropImmediately { + drop(handle.take().unwrap()); + } + + // === Wait for first poll === + let got_polled = rt.block_on(wait_first_poll).is_ok(); + if !got_polled { + // it's possible that we are aborted but still got polled + assert!( + aborted, + "Task completed without ever being polled but was not aborted." + ); + } + + if abort == CombiAbort::AbortedFirstPoll { + handle.as_mut().unwrap().abort(); + aborted = true; + } + if jh == CombiJoinHandle::DropFirstPoll { + drop(handle.take().unwrap()); + } + + // Signal the future that it can return now + let _ = on_complete.send(()); + // === Wait for future to be dropped === + assert!( + rt.block_on(wait_future_drop).is_ok(), + "The future should always be dropped." + ); + + if abort == CombiAbort::AbortedAfterFinish { + // Don't set aborted to true here as the task already finished + handle.as_mut().unwrap().abort(); + } + if jh == CombiJoinHandle::DropAfterNoConsume { + // The runtime will usually have dropped every ref-count at this point, + // in which case dropping the JoinHandle drops the output. + // + // (But it might race and still hold a ref-count) + let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + drop(handle.take().unwrap()); + })); + if panic.is_err() { + assert!( + (output == CombiOutput::PanicOnDrop) + && (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) + && !aborted, + "Dropping JoinHandle shouldn't panic here" + ); + } + } + + // Check whether we drop after consuming the output + if jh == CombiJoinHandle::DropAfterConsume { + // Using as_mut() to not immediately drop the handle + let result = rt.block_on(handle.as_mut().unwrap()); + + match result { + Ok(mut output) => { + // Don't panic here. 
+ output.disarm(); + assert!(!aborted, "Task was aborted but returned output"); + } + Err(err) if err.is_cancelled() => assert!(aborted, "Cancelled output but not aborted"), + Err(err) if err.is_panic() => { + assert!( + (task == CombiTask::PanicOnRun) + || (task == CombiTask::PanicOnDrop) + || (task == CombiTask::PanicOnRunAndDrop) + || (output == CombiOutput::PanicOnDrop), + "Panic but nothing should panic" + ); + } + _ => unreachable!(), + } + + let handle = handle.take().unwrap(); + if abort == CombiAbort::AbortedAfterConsumeOutput { + handle.abort(); + } + drop(handle); + } + + // The output should have been dropped now. Check whether the output + // object was created at all. + let output_created = rt.block_on(wait_output_drop).is_ok(); + assert_eq!( + output_created, + (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) && !aborted, + "Creation of output object" + ); +} diff --git a/tokio/src/sync/barrier.rs b/tokio/src/sync/barrier.rs index e3c95f6a6b6..0e39dac8bb5 100644 --- a/tokio/src/sync/barrier.rs +++ b/tokio/src/sync/barrier.rs @@ -1,7 +1,6 @@ +use crate::loom::sync::Mutex; use crate::sync::watch; -use std::sync::Mutex; - /// A barrier enables multiple tasks to synchronize the beginning of some computation. /// /// ``` @@ -94,7 +93,7 @@ impl Barrier { // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across // a yield point, and thus marks the returned future as !Send. let generation = { - let mut state = self.state.lock().unwrap(); + let mut state = self.state.lock(); let generation = state.generation; state.arrived += 1; if state.arrived == self.n { diff --git a/tokio/src/sync/oneshot.rs b/tokio/src/sync/oneshot.rs index 0df6037a1c2..cb4649d86df 100644 --- a/tokio/src/sync/oneshot.rs +++ b/tokio/src/sync/oneshot.rs @@ -179,9 +179,19 @@ struct Inner { value: UnsafeCell>, /// The task to notify when the receiver drops without consuming the value. + /// + /// ## Safety + /// + /// The `TX_TASK_SET` bit in the `state` field is set if this field is + /// initialized. If that bit is unset, this field may be uninitialized. tx_task: Task, /// The task to notify when the value is sent. + /// + /// ## Safety + /// + /// The `RX_TASK_SET` bit in the `state` field is set if this field is + /// initialized. If that bit is unset, this field may be uninitialized. rx_task: Task, } @@ -311,11 +321,24 @@ impl Sender { let inner = self.inner.take().unwrap(); inner.value.with_mut(|ptr| unsafe { + // SAFETY: The receiver will not access the `UnsafeCell` unless the + // channel has been marked as "complete" (the `VALUE_SENT` state bit + // is set). + // That bit is only set by the sender later on in this method, and + // calling this method consumes `self`. Therefore, if it was possible to + // call this method, we know that the `VALUE_SENT` bit is unset, and + // the receiver is not currently accessing the `UnsafeCell`. *ptr = Some(t); }); if !inner.complete() { unsafe { + // SAFETY: The receiver will not access the `UnsafeCell` unless + // the channel has been marked as "complete". Calling + // `complete()` will return true if this bit is set, and false + // if it is not set. Thus, if `complete()` returned false, it is + // safe for us to access the value, because we know that the + // receiver will not. 
return Err(inner.consume_value().unwrap()); } } @@ -661,6 +684,11 @@ impl Receiver { let state = State::load(&inner.state, Acquire); if state.is_complete() { + // SAFETY: If `state.is_complete()` returns true, then the + // `VALUE_SENT` bit has been set and the sender side of the + // channel will no longer attempt to access the inner + // `UnsafeCell`. Therefore, it is now safe for us to access the + // cell. match unsafe { inner.consume_value() } { Some(value) => Ok(value), None => Err(TryRecvError::Closed), @@ -751,6 +779,11 @@ impl Inner { State::set_rx_task(&self.state); coop.made_progress(); + // SAFETY: If `state.is_complete()` returns true, then the + // `VALUE_SENT` bit has been set and the sender side of the + // channel will no longer attempt to access the inner + // `UnsafeCell`. Therefore, it is now safe for us to access the + // cell. return match unsafe { self.consume_value() } { Some(value) => Ready(Ok(value)), None => Ready(Err(RecvError(()))), @@ -797,6 +830,14 @@ impl Inner { } /// Consumes the value. This function does not check `state`. + /// + /// # Safety + /// + /// Calling this method concurrently on multiple threads will result in a + /// data race. The `VALUE_SENT` state bit is used to ensure that only the + /// sender *or* the receiver will call this method at a given point in time. + /// If `VALUE_SENT` is not set, then only the sender may call this method; + /// if it is set, then only the receiver may call this method. unsafe fn consume_value(&self) -> Option<T> { self.value.with_mut(|ptr| (*ptr).take()) } @@ -837,9 +878,28 @@ impl fmt::Debug for Inner { } } +/// Indicates that a waker for the receiving task has been set. +/// +/// # Safety +/// +/// If this bit is not set, the `rx_task` field may be uninitialized. const RX_TASK_SET: usize = 0b00001; +/// Indicates that a value has been stored in the channel's inner `UnsafeCell`. +/// +/// # Safety +/// +/// This bit controls which side of the channel is permitted to access the +/// `UnsafeCell`. If it is set, the `UnsafeCell` may ONLY be accessed by the +/// receiver. If this bit is NOT set, the `UnsafeCell` may ONLY be accessed by +/// the sender. const VALUE_SENT: usize = 0b00010; const CLOSED: usize = 0b00100; + +/// Indicates that a waker for the sending task has been set. +/// +/// # Safety +/// +/// If this bit is not set, the `tx_task` field may be uninitialized. const TX_TASK_SET: usize = 0b01000; impl State { @@ -852,11 +912,38 @@ impl State { } fn set_complete(cell: &AtomicUsize) -> State { - // TODO: This could be `Release`, followed by an `Acquire` fence *if* - // the `RX_TASK_SET` flag is set. However, `loom` does not support - // fences yet. - let val = cell.fetch_or(VALUE_SENT, AcqRel); - State(val) + // This method is a compare-and-swap loop rather than a fetch-or like + // other `set_$WHATEVER` methods on `State`. This is because we must + // check if the state has been closed before setting the `VALUE_SENT` + // bit. + // + // We don't want to set the `VALUE_SENT` bit if the `CLOSED` + // bit is already set, because `VALUE_SENT` will tell the receiver that + // it's okay to access the inner `UnsafeCell`. Immediately after calling + // `set_complete`, if the channel was closed, the sender will _also_ + // access the `UnsafeCell` to take the value back out, so if a + // `poll_recv` or `try_recv` call is occurring concurrently, both + // threads may try to access the `UnsafeCell` if we were to set the + // `VALUE_SENT` bit on a closed channel.
+ let mut state = cell.load(Ordering::Relaxed); + loop { + if State(state).is_closed() { + break; + } + // TODO: This could be `Release`, followed by an `Acquire` fence *if* + // the `RX_TASK_SET` flag is set. However, `loom` does not support + // fences yet. + match cell.compare_exchange_weak( + state, + state | VALUE_SENT, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => break, + Err(actual) => state = actual, + } + } + State(state) } fn is_rx_task_set(self) -> bool { diff --git a/tokio/src/sync/tests/loom_oneshot.rs b/tokio/src/sync/tests/loom_oneshot.rs index 9729cfb73d3..c5f79720794 100644 --- a/tokio/src/sync/tests/loom_oneshot.rs +++ b/tokio/src/sync/tests/loom_oneshot.rs @@ -55,6 +55,35 @@ fn changing_rx_task() { }); } +#[test] +fn try_recv_close() { + // reproduces https://github.com/tokio-rs/tokio/issues/4225 + loom::model(|| { + let (tx, mut rx) = oneshot::channel(); + thread::spawn(move || { + let _ = tx.send(()); + }); + + rx.close(); + let _ = rx.try_recv(); + }) +} + +#[test] +fn recv_closed() { + // reproduces https://github.com/tokio-rs/tokio/issues/4225 + loom::model(|| { + let (tx, mut rx) = oneshot::channel(); + + thread::spawn(move || { + let _ = tx.send(1); + }); + + rx.close(); + let _ = block_on(rx); + }); +} + // TODO: Move this into `oneshot` proper. use std::future::Future; diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 49b0ec6c4d4..37c2c508ad3 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -1,4 +1,5 @@ //! Runs `!Send` futures on the current thread. +use crate::loom::sync::{Arc, Mutex}; use crate::runtime::task::{self, JoinHandle, Task}; use crate::sync::AtomicWaker; use crate::util::linked_list::{Link, LinkedList}; @@ -9,7 +10,6 @@ use std::fmt; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; -use std::sync::{Arc, Mutex}; use std::task::Poll; use pin_project_lite::pin_project; @@ -242,7 +242,7 @@ struct Tasks { /// LocalSet state shared between threads. struct Shared { /// Remote run queue sender - queue: Mutex<VecDeque<task::Notified<Arc<Shared>>>>, + queue: Mutex<Option<VecDeque<task::Notified<Arc<Shared>>>>>, /// Wake the `LocalSet` task waker: AtomicWaker, @@ -338,7 +338,7 @@ impl LocalSet { queue: VecDeque::with_capacity(INITIAL_CAPACITY), }), shared: Arc::new(Shared { - queue: Mutex::new(VecDeque::with_capacity(INITIAL_CAPACITY)), + queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))), waker: AtomicWaker::new(), }), }, @@ -538,8 +538,8 @@ impl LocalSet { .shared .queue .lock() - .unwrap() - .pop_front() + .as_mut() + .and_then(|queue| queue.pop_front()) .or_else(|| self.context.tasks.borrow_mut().queue.pop_front()) } else { self.context @@ -547,7 +547,14 @@ impl LocalSet { .borrow_mut() .queue .pop_front() - .or_else(|| self.context.shared.queue.lock().unwrap().pop_front()) + .or_else(|| { + self.context + .shared + .queue + .lock() + .as_mut() + .and_then(|queue| queue.pop_front()) + }) } } @@ -611,7 +618,10 @@ impl Drop for LocalSet { task.shutdown(); } - for task in self.context.shared.queue.lock().unwrap().drain(..) { + // Take the queue from the Shared object to prevent pushing + // notifications to it in the future. + let queue = self.context.shared.queue.lock().take().unwrap(); + for task in queue { task.shutdown(); } @@ -661,8 +671,16 @@ impl Shared { cx.tasks.borrow_mut().queue.push_back(task); } _ => { - self.queue.lock().unwrap().push_back(task); - self.waker.wake(); + // First check whether the queue is still there (if not, the + // LocalSet has been dropped). If it is, push to it; otherwise, + // do nothing.
+ let mut lock = self.queue.lock(); + + if let Some(queue) = lock.as_mut() { + queue.push_back(task); + drop(lock); + self.waker.wake(); + } } }); } diff --git a/tokio/src/time/clock.rs b/tokio/src/time/clock.rs index b9ec5c5aab3..a44d75f3ce1 100644 --- a/tokio/src/time/clock.rs +++ b/tokio/src/time/clock.rs @@ -29,7 +29,7 @@ cfg_not_test_util! { cfg_test_util! { use crate::time::{Duration, Instant}; - use std::sync::{Arc, Mutex}; + use crate::loom::sync::{Arc, Mutex}; cfg_rt! { fn clock() -> Option { @@ -102,7 +102,7 @@ cfg_test_util! { /// runtime. pub fn resume() { let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); - let mut inner = clock.inner.lock().unwrap(); + let mut inner = clock.inner.lock(); if inner.unfrozen.is_some() { panic!("time is not frozen"); @@ -164,7 +164,7 @@ cfg_test_util! { } pub(crate) fn pause(&self) { - let mut inner = self.inner.lock().unwrap(); + let mut inner = self.inner.lock(); if !inner.enable_pausing { drop(inner); // avoid poisoning the lock @@ -178,12 +178,12 @@ cfg_test_util! { } pub(crate) fn is_paused(&self) -> bool { - let inner = self.inner.lock().unwrap(); + let inner = self.inner.lock(); inner.unfrozen.is_none() } pub(crate) fn advance(&self, duration: Duration) { - let mut inner = self.inner.lock().unwrap(); + let mut inner = self.inner.lock(); if inner.unfrozen.is_some() { panic!("time is not frozen"); @@ -193,7 +193,7 @@ cfg_test_util! { } pub(crate) fn now(&self) -> Instant { - let inner = self.inner.lock().unwrap(); + let inner = self.inner.lock(); let mut ret = inner.base; diff --git a/tokio/tests/io_async_fd.rs b/tokio/tests/io_async_fd.rs index d1586bb36d5..dc21e426f45 100644 --- a/tokio/tests/io_async_fd.rs +++ b/tokio/tests/io_async_fd.rs @@ -13,7 +13,6 @@ use std::{ task::{Context, Waker}, }; -use nix::errno::Errno; use nix::unistd::{close, read, write}; use futures::{poll, FutureExt}; @@ -56,10 +55,6 @@ impl TestWaker { } } -fn is_blocking(e: &nix::Error) -> bool { - Some(Errno::EAGAIN) == e.as_errno() -} - #[derive(Debug)] struct FileDescriptor { fd: RawFd, @@ -73,11 +68,7 @@ impl AsRawFd for FileDescriptor { impl Read for &FileDescriptor { fn read(&mut self, buf: &mut [u8]) -> io::Result { - match read(self.fd, buf) { - Ok(n) => Ok(n), - Err(e) if is_blocking(&e) => Err(ErrorKind::WouldBlock.into()), - Err(e) => Err(io::Error::new(ErrorKind::Other, e)), - } + read(self.fd, buf).map_err(io::Error::from) } } @@ -89,11 +80,7 @@ impl Read for FileDescriptor { impl Write for &FileDescriptor { fn write(&mut self, buf: &[u8]) -> io::Result { - match write(self.fd, buf) { - Ok(n) => Ok(n), - Err(e) if is_blocking(&e) => Err(ErrorKind::WouldBlock.into()), - Err(e) => Err(io::Error::new(ErrorKind::Other, e)), - } + write(self.fd, buf).map_err(io::Error::from) } fn flush(&mut self) -> io::Result<()> { diff --git a/tokio/tests/macros_test.rs b/tokio/tests/macros_test.rs index 7212c7ba183..bca2c9198a0 100644 --- a/tokio/tests/macros_test.rs +++ b/tokio/tests/macros_test.rs @@ -30,3 +30,19 @@ fn trait_method() { } ().f() } + +// https://github.com/tokio-rs/tokio/issues/4175 +#[tokio::main] +pub async fn issue_4175_main_1() -> ! 
{ + panic!(); +} +#[tokio::main] +pub async fn issue_4175_main_2() -> std::io::Result<()> { + panic!(); +} +#[allow(unreachable_code)] +#[tokio::test] +pub async fn issue_4175_test() -> std::io::Result<()> { + return Ok(()); + panic!(); +} diff --git a/tokio/tests/support/mock_file.rs b/tokio/tests/support/mock_file.rs deleted file mode 100644 index 1ce326b62aa..00000000000 --- a/tokio/tests/support/mock_file.rs +++ /dev/null @@ -1,295 +0,0 @@ -#![allow(clippy::unnecessary_operation)] - -use std::collections::VecDeque; -use std::fmt; -use std::fs::{Metadata, Permissions}; -use std::io; -use std::io::prelude::*; -use std::io::SeekFrom; -use std::path::PathBuf; -use std::sync::{Arc, Mutex}; - -pub struct File { - shared: Arc>, -} - -pub struct Handle { - shared: Arc>, -} - -struct Shared { - calls: VecDeque, -} - -#[derive(Debug)] -enum Call { - Read(io::Result>), - Write(io::Result>), - Seek(SeekFrom, io::Result), - SyncAll(io::Result<()>), - SyncData(io::Result<()>), - SetLen(u64, io::Result<()>), -} - -impl Handle { - pub fn read(&self, data: &[u8]) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::Read(Ok(data.to_owned()))); - self - } - - pub fn read_err(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::Read(Err(io::ErrorKind::Other.into()))); - self - } - - pub fn write(&self, data: &[u8]) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::Write(Ok(data.to_owned()))); - self - } - - pub fn write_err(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::Write(Err(io::ErrorKind::Other.into()))); - self - } - - pub fn seek_start_ok(&self, offset: u64) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::Seek(SeekFrom::Start(offset), Ok(offset))); - self - } - - pub fn seek_current_ok(&self, offset: i64, ret: u64) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::Seek(SeekFrom::Current(offset), Ok(ret))); - self - } - - pub fn sync_all(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::SyncAll(Ok(()))); - self - } - - pub fn sync_all_err(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::SyncAll(Err(io::ErrorKind::Other.into()))); - self - } - - pub fn sync_data(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::SyncData(Ok(()))); - self - } - - pub fn sync_data_err(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::SyncData(Err(io::ErrorKind::Other.into()))); - self - } - - pub fn set_len(&self, size: u64) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::SetLen(size, Ok(()))); - self - } - - pub fn set_len_err(&self, size: u64) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::SetLen(size, Err(io::ErrorKind::Other.into()))); - self - } - - pub fn remaining(&self) -> usize { - let s = self.shared.lock().unwrap(); - s.calls.len() - } -} - -impl Drop for Handle { - fn drop(&mut self) { - if !std::thread::panicking() { - let s = self.shared.lock().unwrap(); - assert_eq!(0, s.calls.len()); - } - } -} - -impl File { - pub fn open(_: PathBuf) -> io::Result { - unimplemented!(); - } - - pub fn create(_: PathBuf) -> io::Result { - unimplemented!(); - } - - pub fn mock() -> (Handle, File) { - let shared = Arc::new(Mutex::new(Shared { - calls: VecDeque::new(), - })); - - let handle = Handle { - 
shared: shared.clone(), - }; - let file = File { shared }; - - (handle, file) - } - - pub fn sync_all(&self) -> io::Result<()> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(SyncAll(ret)) => ret, - Some(op) => panic!("expected next call to be {:?}; was sync_all", op), - None => panic!("did not expect call"), - } - } - - pub fn sync_data(&self) -> io::Result<()> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(SyncData(ret)) => ret, - Some(op) => panic!("expected next call to be {:?}; was sync_all", op), - None => panic!("did not expect call"), - } - } - - pub fn set_len(&self, size: u64) -> io::Result<()> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(SetLen(arg, ret)) => { - assert_eq!(arg, size); - ret - } - Some(op) => panic!("expected next call to be {:?}; was sync_all", op), - None => panic!("did not expect call"), - } - } - - pub fn metadata(&self) -> io::Result { - unimplemented!(); - } - - pub fn set_permissions(&self, _perm: Permissions) -> io::Result<()> { - unimplemented!(); - } - - pub fn try_clone(&self) -> io::Result { - unimplemented!(); - } -} - -impl Read for &'_ File { - fn read(&mut self, dst: &mut [u8]) -> io::Result { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(Read(Ok(data))) => { - assert!(dst.len() >= data.len()); - assert!(dst.len() <= 16 * 1024, "actual = {}", dst.len()); // max buffer - - &mut dst[..data.len()].copy_from_slice(&data); - Ok(data.len()) - } - Some(Read(Err(e))) => Err(e), - Some(op) => panic!("expected next call to be {:?}; was a read", op), - None => panic!("did not expect call"), - } - } -} - -impl Write for &'_ File { - fn write(&mut self, src: &[u8]) -> io::Result { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(Write(Ok(data))) => { - assert_eq!(src, &data[..]); - Ok(src.len()) - } - Some(Write(Err(e))) => Err(e), - Some(op) => panic!("expected next call to be {:?}; was write", op), - None => panic!("did not expect call"), - } - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - -impl Seek for &'_ File { - fn seek(&mut self, pos: SeekFrom) -> io::Result { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(Seek(expect, res)) => { - assert_eq!(expect, pos); - res - } - Some(op) => panic!("expected call {:?}; was `seek`", op), - None => panic!("did not expect call; was `seek`"), - } - } -} - -impl fmt::Debug for File { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("mock::File").finish() - } -} - -#[cfg(unix)] -impl std::os::unix::io::AsRawFd for File { - fn as_raw_fd(&self) -> std::os::unix::io::RawFd { - unimplemented!(); - } -} - -#[cfg(unix)] -impl std::os::unix::io::FromRawFd for File { - unsafe fn from_raw_fd(_: std::os::unix::io::RawFd) -> Self { - unimplemented!(); - } -} - -#[cfg(windows)] -impl std::os::windows::io::AsRawHandle for File { - fn as_raw_handle(&self) -> std::os::windows::io::RawHandle { - unimplemented!(); - } -} - -#[cfg(windows)] -impl std::os::windows::io::FromRawHandle for File { - unsafe fn from_raw_handle(_: std::os::windows::io::RawHandle) -> Self { - unimplemented!(); - } -} diff --git a/tokio/tests/support/mock_pool.rs b/tokio/tests/support/mock_pool.rs deleted file mode 100644 index e1fdb426417..00000000000 --- 
a/tokio/tests/support/mock_pool.rs +++ /dev/null @@ -1,66 +0,0 @@ -use tokio::sync::oneshot; - -use std::cell::RefCell; -use std::collections::VecDeque; -use std::future::Future; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; - -thread_local! { - static QUEUE: RefCell>> = RefCell::new(VecDeque::new()) -} - -#[derive(Debug)] -pub(crate) struct Blocking { - rx: oneshot::Receiver, -} - -pub(crate) fn run(f: F) -> Blocking -where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, -{ - let (tx, rx) = oneshot::channel(); - let task = Box::new(move || { - let _ = tx.send(f()); - }); - - QUEUE.with(|cell| cell.borrow_mut().push_back(task)); - - Blocking { rx } -} - -impl Future for Blocking { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - use std::task::Poll::*; - - match Pin::new(&mut self.rx).poll(cx) { - Ready(Ok(v)) => Ready(Ok(v)), - Ready(Err(e)) => panic!("error = {:?}", e), - Pending => Pending, - } - } -} - -pub(crate) async fn asyncify(f: F) -> io::Result -where - F: FnOnce() -> io::Result + Send + 'static, - T: Send + 'static, -{ - run(f).await? -} - -pub(crate) fn len() -> usize { - QUEUE.with(|cell| cell.borrow().len()) -} - -pub(crate) fn run_one() { - let task = QUEUE - .with(|cell| cell.borrow_mut().pop_front()) - .expect("expected task to run, but none ready"); - - task(); -} diff --git a/tokio/tests/task_abort.rs b/tokio/tests/task_abort.rs index c524dc287d1..cdaa405b86a 100644 --- a/tokio/tests/task_abort.rs +++ b/tokio/tests/task_abort.rs @@ -1,14 +1,25 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] +use std::sync::Arc; use std::thread::sleep; use std::time::Duration; +use tokio::runtime::Builder; + +struct PanicOnDrop; + +impl Drop for PanicOnDrop { + fn drop(&mut self) { + panic!("Well what did you expect would happen..."); + } +} + /// Checks that a suspended task can be aborted without panicking as reported in /// issue #3157: . #[test] fn test_abort_without_panic_3157() { - let rt = tokio::runtime::Builder::new_multi_thread() + let rt = Builder::new_multi_thread() .enable_time() .worker_threads(1) .build() @@ -44,9 +55,7 @@ fn test_abort_without_panic_3662() { } } - let rt = tokio::runtime::Builder::new_current_thread() - .build() - .unwrap(); + let rt = Builder::new_current_thread().build().unwrap(); rt.block_on(async move { let drop_flag = Arc::new(AtomicBool::new(false)); @@ -119,9 +128,7 @@ fn remote_abort_local_set_3929() { } } - let rt = tokio::runtime::Builder::new_current_thread() - .build() - .unwrap(); + let rt = Builder::new_current_thread().build().unwrap(); let local = tokio::task::LocalSet::new(); let check = DropCheck::new(); @@ -138,3 +145,80 @@ fn remote_abort_local_set_3929() { rt.block_on(local); jh2.join().unwrap(); } + +/// Checks that a suspended task can be aborted even if the `JoinHandle` is immediately dropped. +/// issue #3964: . +#[test] +fn test_abort_wakes_task_3964() { + let rt = Builder::new_current_thread().enable_time().build().unwrap(); + + rt.block_on(async move { + let notify_dropped = Arc::new(()); + let weak_notify_dropped = Arc::downgrade(¬ify_dropped); + + let handle = tokio::spawn(async move { + // Make sure the Arc is moved into the task + let _notify_dropped = notify_dropped; + println!("task started"); + tokio::time::sleep(std::time::Duration::new(100, 0)).await + }); + + // wait for task to sleep. + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + handle.abort(); + drop(handle); + + // wait for task to abort. 
+ tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + // Check that the Arc has been dropped. + assert!(weak_notify_dropped.upgrade().is_none()); + }); +} + +/// Checks that aborting a task whose destructor panics does not allow the +/// panic to escape the task. +#[test] +fn test_abort_task_that_panics_on_drop_contained() { + let rt = Builder::new_current_thread().enable_time().build().unwrap(); + + rt.block_on(async move { + let handle = tokio::spawn(async move { + // Make sure the Arc is moved into the task + let _panic_dropped = PanicOnDrop; + println!("task started"); + tokio::time::sleep(std::time::Duration::new(100, 0)).await + }); + + // wait for task to sleep. + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + handle.abort(); + drop(handle); + + // wait for task to abort. + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + }); +} + +/// Checks that aborting a task whose destructor panics has the expected result. +#[test] +fn test_abort_task_that_panics_on_drop_returned() { + let rt = Builder::new_current_thread().enable_time().build().unwrap(); + + rt.block_on(async move { + let handle = tokio::spawn(async move { + // Make sure the Arc is moved into the task + let _panic_dropped = PanicOnDrop; + println!("task started"); + tokio::time::sleep(std::time::Duration::new(100, 0)).await + }); + + // wait for task to sleep. + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + handle.abort(); + assert!(handle.await.unwrap_err().is_panic()); + }); +} diff --git a/tokio/tests/uds_datagram.rs b/tokio/tests/uds_datagram.rs index 10314bebf9c..4d2846865f5 100644 --- a/tokio/tests/uds_datagram.rs +++ b/tokio/tests/uds_datagram.rs @@ -87,9 +87,12 @@ async fn try_send_recv_never_block() -> io::Result<()> { dgram1.writable().await.unwrap(); match dgram1.try_send(payload) { - Err(err) => match err.kind() { - io::ErrorKind::WouldBlock | io::ErrorKind::Other => break, - _ => unreachable!("unexpected error {:?}", err), + Err(err) => match (err.kind(), err.raw_os_error()) { + (io::ErrorKind::WouldBlock, _) => break, + (_, Some(libc::ENOBUFS)) => break, + _ => { + panic!("unexpected error {:?}", err); + } }, Ok(len) => { assert_eq!(len, payload.len()); @@ -291,9 +294,12 @@ async fn try_recv_buf_never_block() -> io::Result<()> { dgram1.writable().await.unwrap(); match dgram1.try_send(payload) { - Err(err) => match err.kind() { - io::ErrorKind::WouldBlock | io::ErrorKind::Other => break, - _ => unreachable!("unexpected error {:?}", err), + Err(err) => match (err.kind(), err.raw_os_error()) { + (io::ErrorKind::WouldBlock, _) => break, + (_, Some(libc::ENOBUFS)) => break, + _ => { + panic!("unexpected error {:?}", err); + } }, Ok(len) => { assert_eq!(len, payload.len());