diff --git a/tokio/src/blocking.rs b/tokio/src/blocking.rs index f88b1db11cc..f172399d5ef 100644 --- a/tokio/src/blocking.rs +++ b/tokio/src/blocking.rs @@ -1,5 +1,11 @@ cfg_rt! { pub(crate) use crate::runtime::spawn_blocking; + + cfg_fs! { + #[allow(unused_imports)] + pub(crate) use crate::runtime::spawn_mandatory_blocking; + } + pub(crate) use crate::task::JoinHandle; } @@ -16,7 +22,16 @@ cfg_not_rt! { { assert_send_sync::>>(); panic!("requires the `rt` Tokio feature flag") + } + cfg_fs! { + pub(crate) fn spawn_mandatory_blocking(_f: F) -> Option> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + panic!("requires the `rt` Tokio feature flag") + } } pub(crate) struct JoinHandle { diff --git a/tokio/src/fs/file.rs b/tokio/src/fs/file.rs index 61071cf6309..2c38e8059f5 100644 --- a/tokio/src/fs/file.rs +++ b/tokio/src/fs/file.rs @@ -19,17 +19,17 @@ use std::task::Context; use std::task::Poll; use std::task::Poll::*; -#[cfg(test)] -use super::mocks::spawn_blocking; #[cfg(test)] use super::mocks::JoinHandle; #[cfg(test)] use super::mocks::MockFile as StdFile; -#[cfg(not(test))] -use crate::blocking::spawn_blocking; +#[cfg(test)] +use super::mocks::{spawn_blocking, spawn_mandatory_blocking}; #[cfg(not(test))] use crate::blocking::JoinHandle; #[cfg(not(test))] +use crate::blocking::{spawn_blocking, spawn_mandatory_blocking}; +#[cfg(not(test))] use std::fs::File as StdFile; /// A reference to an open file on the filesystem. @@ -649,7 +649,7 @@ impl AsyncWrite for File { let n = buf.copy_from(src); let std = me.std.clone(); - inner.state = Busy(spawn_blocking(move || { + let blocking_task_join_handle = spawn_mandatory_blocking(move || { let res = if let Some(seek) = seek { (&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std)) } else { @@ -657,7 +657,12 @@ impl AsyncWrite for File { }; (Operation::Write(res), buf) - })); + }) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "background task failed") + })?; + + inner.state = Busy(blocking_task_join_handle); return Ready(Ok(n)); } diff --git a/tokio/src/fs/mocks.rs b/tokio/src/fs/mocks.rs index 68ef4f3a7a4..b1861726778 100644 --- a/tokio/src/fs/mocks.rs +++ b/tokio/src/fs/mocks.rs @@ -105,6 +105,21 @@ where JoinHandle { rx } } +pub(super) fn spawn_mandatory_blocking(f: F) -> Option> +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + let (tx, rx) = oneshot::channel(); + let task = Box::new(move || { + let _ = tx.send(f()); + }); + + QUEUE.with(|cell| cell.borrow_mut().push_back(task)); + + Some(JoinHandle { rx }) +} + impl Future for JoinHandle { type Output = Result; diff --git a/tokio/src/runtime/blocking/mod.rs b/tokio/src/runtime/blocking/mod.rs index 670ec3a4b34..15fe05c9ade 100644 --- a/tokio/src/runtime/blocking/mod.rs +++ b/tokio/src/runtime/blocking/mod.rs @@ -4,7 +4,11 @@ //! compilation. mod pool; -pub(crate) use pool::{spawn_blocking, BlockingPool, Spawner}; +pub(crate) use pool::{spawn_blocking, BlockingPool, Mandatory, Spawner, Task}; + +cfg_fs! { + pub(crate) use pool::spawn_mandatory_blocking; +} mod schedule; mod shutdown; diff --git a/tokio/src/runtime/blocking/pool.rs b/tokio/src/runtime/blocking/pool.rs index bb6c1ee6606..f0d7d4e0fe9 100644 --- a/tokio/src/runtime/blocking/pool.rs +++ b/tokio/src/runtime/blocking/pool.rs @@ -70,11 +70,40 @@ struct Shared { worker_thread_index: usize, } -type Task = task::UnownedTask; +pub(crate) struct Task { + task: task::UnownedTask, + mandatory: Mandatory, +} + +#[derive(PartialEq, Eq)] +pub(crate) enum Mandatory { + #[cfg_attr(not(fs), allow(dead_code))] + Mandatory, + NonMandatory, +} + +impl Task { + pub(crate) fn new(task: task::UnownedTask, mandatory: Mandatory) -> Task { + Task { task, mandatory } + } + + fn run(self) { + self.task.run(); + } + + fn shutdown_or_run_if_mandatory(self) { + match self.mandatory { + Mandatory::NonMandatory => self.task.shutdown(), + Mandatory::Mandatory => self.task.run(), + } + } +} const KEEP_ALIVE: Duration = Duration::from_secs(10); /// Runs the provided function on an executor dedicated to blocking operations. +/// Tasks will be scheduled as non-mandatory, meaning they may not get executed +/// in case of runtime shutdown. pub(crate) fn spawn_blocking(func: F) -> JoinHandle where F: FnOnce() -> R + Send + 'static, @@ -84,6 +113,25 @@ where rt.spawn_blocking(func) } +cfg_fs! { + #[cfg_attr(any( + all(loom, not(test)), // the function is covered by loom tests + test + ), allow(dead_code))] + /// Runs the provided function on an executor dedicated to blocking + /// operations. Tasks will be scheduled as mandatory, meaning they are + /// guaranteed to run unless a shutdown is already taking place. In case a + /// shutdown is already taking place, `None` will be returned. + pub(crate) fn spawn_mandatory_blocking(func: F) -> Option> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let rt = context::current(); + rt.spawn_mandatory_blocking(func) + } +} + // ===== impl BlockingPool ===== impl BlockingPool { @@ -176,8 +224,10 @@ impl Spawner { let mut shared = self.inner.shared.lock(); if shared.shutdown { - // Shutdown the task - task.shutdown(); + // Shutdown the task: it's fine to shutdown this task (even if + // mandatory) because it was scheduled after the shutdown of the + // runtime began. + task.task.shutdown(); // no need to even push this task; it would never get picked up return Err(()); @@ -302,7 +352,8 @@ impl Inner { // Drain the queue while let Some(task) = shared.queue.pop_front() { drop(shared); - task.shutdown(); + + task.shutdown_or_run_if_mandatory(); shared = self.shared.lock(); } diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 3481a2552f3..9dbe6774dd0 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -189,15 +189,56 @@ impl Handle { F: FnOnce() -> R + Send + 'static, R: Send + 'static, { - if cfg!(debug_assertions) && std::mem::size_of::() > 2048 { - self.spawn_blocking_inner(Box::new(func), None) - } else { - self.spawn_blocking_inner(func, None) + let (join_handle, _was_spawned) = + if cfg!(debug_assertions) && std::mem::size_of::() > 2048 { + self.spawn_blocking_inner(Box::new(func), blocking::Mandatory::NonMandatory, None) + } else { + self.spawn_blocking_inner(func, blocking::Mandatory::NonMandatory, None) + }; + + join_handle + } + + cfg_fs! { + #[track_caller] + #[cfg_attr(any( + all(loom, not(test)), // the function is covered by loom tests + test + ), allow(dead_code))] + pub(crate) fn spawn_mandatory_blocking(&self, func: F) -> Option> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let (join_handle, was_spawned) = if cfg!(debug_assertions) && std::mem::size_of::() > 2048 { + self.spawn_blocking_inner( + Box::new(func), + blocking::Mandatory::Mandatory, + None + ) + } else { + self.spawn_blocking_inner( + func, + blocking::Mandatory::Mandatory, + None + ) + }; + + if was_spawned { + Some(join_handle) + } else { + None + } } } #[track_caller] - pub(crate) fn spawn_blocking_inner(&self, func: F, name: Option<&str>) -> JoinHandle + pub(crate) fn spawn_blocking_inner( + &self, + func: F, + is_mandatory: blocking::Mandatory, + name: Option<&str>, + ) -> (JoinHandle, bool) where F: FnOnce() -> R + Send + 'static, R: Send + 'static, @@ -223,8 +264,10 @@ impl Handle { let _ = name; let (task, handle) = task::unowned(fut, NoopSchedule); - let _ = self.blocking_spawner.spawn(task, self); - handle + let spawned = self + .blocking_spawner + .spawn(blocking::Task::new(task, is_mandatory), self); + (handle, spawned.is_ok()) } /// Runs a future to completion on this `Handle`'s associated `Runtime`. diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index c8d97e1b19a..b607d72f0dc 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -201,6 +201,14 @@ cfg_rt! { use blocking::BlockingPool; pub(crate) use blocking::spawn_blocking; + cfg_trace! { + pub(crate) use blocking::Mandatory; + } + + cfg_fs! { + pub(crate) use blocking::spawn_mandatory_blocking; + } + mod builder; pub use self::builder::Builder; diff --git a/tokio/src/runtime/tests/loom_blocking.rs b/tokio/src/runtime/tests/loom_blocking.rs index 8fb54c5657e..89de85e4362 100644 --- a/tokio/src/runtime/tests/loom_blocking.rs +++ b/tokio/src/runtime/tests/loom_blocking.rs @@ -23,6 +23,56 @@ fn blocking_shutdown() { }); } +#[test] +fn spawn_mandatory_blocking_should_always_run() { + use crate::runtime::tests::loom_oneshot; + loom::model(|| { + let rt = runtime::Builder::new_current_thread().build().unwrap(); + + let (tx, rx) = loom_oneshot::channel(); + let _enter = rt.enter(); + runtime::spawn_blocking(|| {}); + runtime::spawn_mandatory_blocking(move || { + let _ = tx.send(()); + }) + .unwrap(); + + drop(rt); + + // This call will deadlock if `spawn_mandatory_blocking` doesn't run. + let () = rx.recv(); + }); +} + +#[test] +fn spawn_mandatory_blocking_should_run_even_when_shutting_down_from_other_thread() { + use crate::runtime::tests::loom_oneshot; + loom::model(|| { + let rt = runtime::Builder::new_current_thread().build().unwrap(); + let handle = rt.handle().clone(); + + // Drop the runtime in a different thread + { + loom::thread::spawn(move || { + drop(rt); + }); + } + + let _enter = handle.enter(); + let (tx, rx) = loom_oneshot::channel(); + let handle = runtime::spawn_mandatory_blocking(move || { + let _ = tx.send(()); + }); + + // handle.is_some() means that `spawn_mandatory_blocking` + // promised us to run the blocking task + if handle.is_some() { + // This call will deadlock if `spawn_mandatory_blocking` doesn't run. + let () = rx.recv(); + } + }); +} + fn mk_runtime(num_threads: usize) -> Runtime { runtime::Builder::new_multi_thread() .worker_threads(num_threads) diff --git a/tokio/src/task/builder.rs b/tokio/src/task/builder.rs index 0a7fe3c371a..5a128420ee2 100644 --- a/tokio/src/task/builder.rs +++ b/tokio/src/task/builder.rs @@ -107,6 +107,9 @@ impl<'a> Builder<'a> { Function: FnOnce() -> Output + Send + 'static, Output: Send + 'static, { - context::current().spawn_blocking_inner(function, self.name) + use crate::runtime::Mandatory; + let (join_handle, _was_spawned) = + context::current().spawn_blocking_inner(function, Mandatory::NonMandatory, self.name); + join_handle } }