Skip to content

Commit

Permalink
task: add JoinSet for managing sets of tasks(#4335)
Browse files Browse the repository at this point in the history
Adds `JoinSet` for managing multiple spawned tasks and joining them
in completion order.

Closes: #3903
  • Loading branch information
Darksonn committed Feb 1, 2022
1 parent f602410 commit 1bb4d23
Show file tree
Hide file tree
Showing 15 changed files with 1,052 additions and 23 deletions.
2 changes: 1 addition & 1 deletion tokio/Cargo.toml
Expand Up @@ -143,7 +143,7 @@ wasm-bindgen-test = "0.3.0"
mio-aio = { version = "0.6.0", features = ["tokio"] }

[target.'cfg(loom)'.dev-dependencies]
loom = { version = "0.5", features = ["futures", "checkpoint"] }
loom = { version = "0.5.2", features = ["futures", "checkpoint"] }

[package.metadata.docs.rs]
all-features = true
Expand Down
7 changes: 3 additions & 4 deletions tokio/src/runtime/basic_scheduler.rs
@@ -1,6 +1,6 @@
use crate::future::poll_fn;
use crate::loom::sync::atomic::AtomicBool;
use crate::loom::sync::Mutex;
use crate::loom::sync::{Arc, Mutex};
use crate::park::{Park, Unpark};
use crate::runtime::context::EnterGuard;
use crate::runtime::driver::Driver;
Expand All @@ -16,7 +16,6 @@ use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::sync::atomic::Ordering::{AcqRel, Release};
use std::sync::Arc;
use std::task::Poll::{Pending, Ready};
use std::time::Duration;

Expand Down Expand Up @@ -481,8 +480,8 @@ impl Schedule for Arc<Shared> {
}

impl Wake for Shared {
fn wake(self: Arc<Self>) {
Wake::wake_by_ref(&self)
fn wake(arc_self: Arc<Self>) {
Wake::wake_by_ref(&arc_self)
}

/// Wake by reference
Expand Down
7 changes: 7 additions & 0 deletions tokio/src/runtime/task/harness.rs
Expand Up @@ -164,6 +164,13 @@ where
}
}

/// Try to set the waker notified when the task is complete. Returns true if
/// the task has already completed. If this call returns false, then the
/// waker will not be notified.
pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool {
can_read_output(self.header(), self.trailer(), waker)
}

pub(super) fn drop_join_handle_slow(self) {
// Try to unset `JOIN_INTEREST`. This must be done as a first step in
// case the task concurrently completed.
Expand Down
12 changes: 11 additions & 1 deletion tokio/src/runtime/task/join.rs
Expand Up @@ -5,7 +5,7 @@ use std::future::Future;
use std::marker::PhantomData;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::task::{Context, Poll, Waker};

cfg_rt! {
/// An owned permission to join on a task (await its termination).
Expand Down Expand Up @@ -200,6 +200,16 @@ impl<T> JoinHandle<T> {
raw.remote_abort();
}
}

/// Set the waker that is notified when the task completes.
pub(crate) fn set_join_waker(&mut self, waker: &Waker) {
if let Some(raw) = self.raw {
if raw.try_set_join_waker(waker) {
// In this case the task has already completed. We wake the waker immediately.
waker.wake_by_ref();
}
}
}
}

impl<T> Unpin for JoinHandle<T> {}
Expand Down
16 changes: 16 additions & 0 deletions tokio/src/runtime/task/raw.rs
Expand Up @@ -19,6 +19,11 @@ pub(super) struct Vtable {
/// Reads the task output, if complete.
pub(super) try_read_output: unsafe fn(NonNull<Header>, *mut (), &Waker),

/// Try to set the waker notified when the task is complete. Returns true if
/// the task has already completed. If this call returns false, then the
/// waker will not be notified.
pub(super) try_set_join_waker: unsafe fn(NonNull<Header>, &Waker) -> bool,

/// The join handle has been dropped.
pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>),

Expand All @@ -35,6 +40,7 @@ pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable {
poll: poll::<T, S>,
dealloc: dealloc::<T, S>,
try_read_output: try_read_output::<T, S>,
try_set_join_waker: try_set_join_waker::<T, S>,
drop_join_handle_slow: drop_join_handle_slow::<T, S>,
remote_abort: remote_abort::<T, S>,
shutdown: shutdown::<T, S>,
Expand Down Expand Up @@ -84,6 +90,11 @@ impl RawTask {
(vtable.try_read_output)(self.ptr, dst, waker);
}

pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool {
let vtable = self.header().vtable;
unsafe { (vtable.try_set_join_waker)(self.ptr, waker) }
}

pub(super) fn drop_join_handle_slow(self) {
let vtable = self.header().vtable;
unsafe { (vtable.drop_join_handle_slow)(self.ptr) }
Expand Down Expand Up @@ -129,6 +140,11 @@ unsafe fn try_read_output<T: Future, S: Schedule>(
harness.try_read_output(out, waker);
}

unsafe fn try_set_join_waker<T: Future, S: Schedule>(ptr: NonNull<Header>, waker: &Waker) -> bool {
let harness = Harness::<T, S>::from_raw(ptr);
harness.try_set_join_waker(waker)
}

unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) {
let harness = Harness::<T, S>::from_raw(ptr);
harness.drop_join_handle_slow()
Expand Down
82 changes: 82 additions & 0 deletions tokio/src/runtime/tests/loom_join_set.rs
@@ -0,0 +1,82 @@
use crate::runtime::Builder;
use crate::task::JoinSet;

#[test]
fn test_join_set() {
loom::model(|| {
let rt = Builder::new_multi_thread()
.worker_threads(1)
.build()
.unwrap();
let mut set = JoinSet::new();

rt.block_on(async {
assert_eq!(set.len(), 0);
set.spawn(async { () });
assert_eq!(set.len(), 1);
set.spawn(async { () });
assert_eq!(set.len(), 2);
let () = set.join_one().await.unwrap().unwrap();
assert_eq!(set.len(), 1);
set.spawn(async { () });
assert_eq!(set.len(), 2);
let () = set.join_one().await.unwrap().unwrap();
assert_eq!(set.len(), 1);
let () = set.join_one().await.unwrap().unwrap();
assert_eq!(set.len(), 0);
set.spawn(async { () });
assert_eq!(set.len(), 1);
});

drop(set);
drop(rt);
});
}

#[test]
fn abort_all_during_completion() {
use std::sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
};

// These booleans assert that at least one execution had the task complete first, and that at
// least one execution had the task be cancelled before it completed.
let complete_happened = Arc::new(AtomicBool::new(false));
let cancel_happened = Arc::new(AtomicBool::new(false));

{
let complete_happened = complete_happened.clone();
let cancel_happened = cancel_happened.clone();
loom::model(move || {
let rt = Builder::new_multi_thread()
.worker_threads(1)
.build()
.unwrap();

let mut set = JoinSet::new();

rt.block_on(async {
set.spawn(async { () });
set.abort_all();

match set.join_one().await {
Ok(Some(())) => complete_happened.store(true, SeqCst),
Err(err) if err.is_cancelled() => cancel_happened.store(true, SeqCst),
Err(err) => panic!("fail: {}", err),
Ok(None) => {
unreachable!("Aborting the task does not remove it from the JoinSet.")
}
}

assert!(matches!(set.join_one().await, Ok(None)));
});

drop(set);
drop(rt);
});
}

assert!(complete_happened.load(SeqCst));
assert!(cancel_happened.load(SeqCst));
}
3 changes: 2 additions & 1 deletion tokio/src/runtime/tests/mod.rs
Expand Up @@ -30,12 +30,13 @@ mod unowned_wrapper {

cfg_loom! {
mod loom_basic_scheduler;
mod loom_local;
mod loom_blocking;
mod loom_local;
mod loom_oneshot;
mod loom_pool;
mod loom_queue;
mod loom_shutdown_join;
mod loom_join_set;
}

cfg_not_loom! {
Expand Down

0 comments on commit 1bb4d23

Please sign in to comment.