Skip to content

Commit

Permalink
fs: guarantee that File::write will attempt the write even if the r…
Browse files Browse the repository at this point in the history
…untime shuts down (#4316)
  • Loading branch information
BraulioVM committed Jan 25, 2022
1 parent 9e38ebc commit 7aad428
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 19 deletions.
15 changes: 15 additions & 0 deletions tokio/src/blocking.rs
@@ -1,5 +1,11 @@
cfg_rt! {
pub(crate) use crate::runtime::spawn_blocking;

cfg_fs! {
#[allow(unused_imports)]
pub(crate) use crate::runtime::spawn_mandatory_blocking;
}

pub(crate) use crate::task::JoinHandle;
}

Expand All @@ -16,7 +22,16 @@ cfg_not_rt! {
{
assert_send_sync::<JoinHandle<std::cell::Cell<()>>>();
panic!("requires the `rt` Tokio feature flag")
}

cfg_fs! {
pub(crate) fn spawn_mandatory_blocking<F, R>(_f: F) -> Option<JoinHandle<R>>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
panic!("requires the `rt` Tokio feature flag")
}
}

pub(crate) struct JoinHandle<R> {
Expand Down
17 changes: 11 additions & 6 deletions tokio/src/fs/file.rs
Expand Up @@ -19,17 +19,17 @@ use std::task::Context;
use std::task::Poll;
use std::task::Poll::*;

#[cfg(test)]
use super::mocks::spawn_blocking;
#[cfg(test)]
use super::mocks::JoinHandle;
#[cfg(test)]
use super::mocks::MockFile as StdFile;
#[cfg(not(test))]
use crate::blocking::spawn_blocking;
#[cfg(test)]
use super::mocks::{spawn_blocking, spawn_mandatory_blocking};
#[cfg(not(test))]
use crate::blocking::JoinHandle;
#[cfg(not(test))]
use crate::blocking::{spawn_blocking, spawn_mandatory_blocking};
#[cfg(not(test))]
use std::fs::File as StdFile;

/// A reference to an open file on the filesystem.
Expand Down Expand Up @@ -649,15 +649,20 @@ impl AsyncWrite for File {
let n = buf.copy_from(src);
let std = me.std.clone();

inner.state = Busy(spawn_blocking(move || {
let blocking_task_join_handle = spawn_mandatory_blocking(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std))
} else {
buf.write_to(&mut &*std)
};

(Operation::Write(res), buf)
}));
})
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "background task failed")
})?;

inner.state = Busy(blocking_task_join_handle);

return Ready(Ok(n));
}
Expand Down
15 changes: 15 additions & 0 deletions tokio/src/fs/mocks.rs
Expand Up @@ -105,6 +105,21 @@ where
JoinHandle { rx }
}

pub(super) fn spawn_mandatory_blocking<F, R>(f: F) -> Option<JoinHandle<R>>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let (tx, rx) = oneshot::channel();
let task = Box::new(move || {
let _ = tx.send(f());
});

QUEUE.with(|cell| cell.borrow_mut().push_back(task));

Some(JoinHandle { rx })
}

impl<T> Future for JoinHandle<T> {
type Output = Result<T, io::Error>;

Expand Down
6 changes: 5 additions & 1 deletion tokio/src/runtime/blocking/mod.rs
Expand Up @@ -4,7 +4,11 @@
//! compilation.

mod pool;
pub(crate) use pool::{spawn_blocking, BlockingPool, Spawner};
pub(crate) use pool::{spawn_blocking, BlockingPool, Mandatory, Spawner, Task};

cfg_fs! {
pub(crate) use pool::spawn_mandatory_blocking;
}

mod schedule;
mod shutdown;
Expand Down
59 changes: 55 additions & 4 deletions tokio/src/runtime/blocking/pool.rs
Expand Up @@ -70,11 +70,40 @@ struct Shared {
worker_thread_index: usize,
}

type Task = task::UnownedTask<NoopSchedule>;
pub(crate) struct Task {
task: task::UnownedTask<NoopSchedule>,
mandatory: Mandatory,
}

#[derive(PartialEq, Eq)]
pub(crate) enum Mandatory {
#[cfg_attr(not(fs), allow(dead_code))]
Mandatory,
NonMandatory,
}

impl Task {
pub(crate) fn new(task: task::UnownedTask<NoopSchedule>, mandatory: Mandatory) -> Task {
Task { task, mandatory }
}

fn run(self) {
self.task.run();
}

fn shutdown_or_run_if_mandatory(self) {
match self.mandatory {
Mandatory::NonMandatory => self.task.shutdown(),
Mandatory::Mandatory => self.task.run(),
}
}
}

const KEEP_ALIVE: Duration = Duration::from_secs(10);

/// Runs the provided function on an executor dedicated to blocking operations.
/// Tasks will be scheduled as non-mandatory, meaning they may not get executed
/// in case of runtime shutdown.
pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
Expand All @@ -84,6 +113,25 @@ where
rt.spawn_blocking(func)
}

cfg_fs! {
#[cfg_attr(any(
all(loom, not(test)), // the function is covered by loom tests
test
), allow(dead_code))]
/// Runs the provided function on an executor dedicated to blocking
/// operations. Tasks will be scheduled as mandatory, meaning they are
/// guaranteed to run unless a shutdown is already taking place. In case a
/// shutdown is already taking place, `None` will be returned.
pub(crate) fn spawn_mandatory_blocking<F, R>(func: F) -> Option<JoinHandle<R>>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let rt = context::current();
rt.spawn_mandatory_blocking(func)
}
}

// ===== impl BlockingPool =====

impl BlockingPool {
Expand Down Expand Up @@ -176,8 +224,10 @@ impl Spawner {
let mut shared = self.inner.shared.lock();

if shared.shutdown {
// Shutdown the task
task.shutdown();
// Shutdown the task: it's fine to shutdown this task (even if
// mandatory) because it was scheduled after the shutdown of the
// runtime began.
task.task.shutdown();

// no need to even push this task; it would never get picked up
return Err(());
Expand Down Expand Up @@ -302,7 +352,8 @@ impl Inner {
// Drain the queue
while let Some(task) = shared.queue.pop_front() {
drop(shared);
task.shutdown();

task.shutdown_or_run_if_mandatory();

shared = self.shared.lock();
}
Expand Down
57 changes: 50 additions & 7 deletions tokio/src/runtime/handle.rs
Expand Up @@ -189,15 +189,56 @@ impl Handle {
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
self.spawn_blocking_inner(Box::new(func), None)
} else {
self.spawn_blocking_inner(func, None)
let (join_handle, _was_spawned) =
if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
self.spawn_blocking_inner(Box::new(func), blocking::Mandatory::NonMandatory, None)
} else {
self.spawn_blocking_inner(func, blocking::Mandatory::NonMandatory, None)
};

join_handle
}

cfg_fs! {
#[track_caller]
#[cfg_attr(any(
all(loom, not(test)), // the function is covered by loom tests
test
), allow(dead_code))]
pub(crate) fn spawn_mandatory_blocking<F, R>(&self, func: F) -> Option<JoinHandle<R>>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let (join_handle, was_spawned) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
self.spawn_blocking_inner(
Box::new(func),
blocking::Mandatory::Mandatory,
None
)
} else {
self.spawn_blocking_inner(
func,
blocking::Mandatory::Mandatory,
None
)
};

if was_spawned {
Some(join_handle)
} else {
None
}
}
}

#[track_caller]
pub(crate) fn spawn_blocking_inner<F, R>(&self, func: F, name: Option<&str>) -> JoinHandle<R>
pub(crate) fn spawn_blocking_inner<F, R>(
&self,
func: F,
is_mandatory: blocking::Mandatory,
name: Option<&str>,
) -> (JoinHandle<R>, bool)
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
Expand All @@ -223,8 +264,10 @@ impl Handle {
let _ = name;

let (task, handle) = task::unowned(fut, NoopSchedule);
let _ = self.blocking_spawner.spawn(task, self);
handle
let spawned = self
.blocking_spawner
.spawn(blocking::Task::new(task, is_mandatory), self);
(handle, spawned.is_ok())
}

/// Runs a future to completion on this `Handle`'s associated `Runtime`.
Expand Down
8 changes: 8 additions & 0 deletions tokio/src/runtime/mod.rs
Expand Up @@ -201,6 +201,14 @@ cfg_rt! {
use blocking::BlockingPool;
pub(crate) use blocking::spawn_blocking;

cfg_trace! {
pub(crate) use blocking::Mandatory;
}

cfg_fs! {
pub(crate) use blocking::spawn_mandatory_blocking;
}

mod builder;
pub use self::builder::Builder;

Expand Down
50 changes: 50 additions & 0 deletions tokio/src/runtime/tests/loom_blocking.rs
Expand Up @@ -23,6 +23,56 @@ fn blocking_shutdown() {
});
}

#[test]
fn spawn_mandatory_blocking_should_always_run() {
use crate::runtime::tests::loom_oneshot;
loom::model(|| {
let rt = runtime::Builder::new_current_thread().build().unwrap();

let (tx, rx) = loom_oneshot::channel();
let _enter = rt.enter();
runtime::spawn_blocking(|| {});
runtime::spawn_mandatory_blocking(move || {
let _ = tx.send(());
})
.unwrap();

drop(rt);

// This call will deadlock if `spawn_mandatory_blocking` doesn't run.
let () = rx.recv();
});
}

#[test]
fn spawn_mandatory_blocking_should_run_even_when_shutting_down_from_other_thread() {
use crate::runtime::tests::loom_oneshot;
loom::model(|| {
let rt = runtime::Builder::new_current_thread().build().unwrap();
let handle = rt.handle().clone();

// Drop the runtime in a different thread
{
loom::thread::spawn(move || {
drop(rt);
});
}

let _enter = handle.enter();
let (tx, rx) = loom_oneshot::channel();
let handle = runtime::spawn_mandatory_blocking(move || {
let _ = tx.send(());
});

// handle.is_some() means that `spawn_mandatory_blocking`
// promised us to run the blocking task
if handle.is_some() {
// This call will deadlock if `spawn_mandatory_blocking` doesn't run.
let () = rx.recv();
}
});
}

fn mk_runtime(num_threads: usize) -> Runtime {
runtime::Builder::new_multi_thread()
.worker_threads(num_threads)
Expand Down
5 changes: 4 additions & 1 deletion tokio/src/task/builder.rs
Expand Up @@ -107,6 +107,9 @@ impl<'a> Builder<'a> {
Function: FnOnce() -> Output + Send + 'static,
Output: Send + 'static,
{
context::current().spawn_blocking_inner(function, self.name)
use crate::runtime::Mandatory;
let (join_handle, _was_spawned) =
context::current().spawn_blocking_inner(function, Mandatory::NonMandatory, self.name);
join_handle
}
}

0 comments on commit 7aad428

Please sign in to comment.