From 1bb4d2316232a77d048e8797ba04d796911ffe30 Mon Sep 17 00:00:00 2001 From: Alice Ryhl Date: Tue, 1 Feb 2022 23:17:09 +0100 Subject: [PATCH] task: add `JoinSet` for managing sets of tasks(#4335) Adds `JoinSet` for managing multiple spawned tasks and joining them in completion order. Closes: #3903 --- tokio/Cargo.toml | 2 +- tokio/src/runtime/basic_scheduler.rs | 7 +- tokio/src/runtime/task/harness.rs | 7 + tokio/src/runtime/task/join.rs | 12 +- tokio/src/runtime/task/raw.rs | 16 + tokio/src/runtime/tests/loom_join_set.rs | 82 ++++ tokio/src/runtime/tests/mod.rs | 3 +- tokio/src/task/join_set.rs | 242 ++++++++++++ tokio/src/task/mod.rs | 3 + tokio/src/util/idle_notified_set.rs | 463 +++++++++++++++++++++++ tokio/src/util/linked_list.rs | 5 +- tokio/src/util/mod.rs | 3 + tokio/src/util/wake.rs | 7 +- tokio/tests/async_send_sync.rs | 31 +- tokio/tests/task_join_set.rs | 192 ++++++++++ 15 files changed, 1052 insertions(+), 23 deletions(-) create mode 100644 tokio/src/runtime/tests/loom_join_set.rs create mode 100644 tokio/src/task/join_set.rs create mode 100644 tokio/src/util/idle_notified_set.rs create mode 100644 tokio/tests/task_join_set.rs diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index bd34bb56b0c..bee041d4807 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -143,7 +143,7 @@ wasm-bindgen-test = "0.3.0" mio-aio = { version = "0.6.0", features = ["tokio"] } [target.'cfg(loom)'.dev-dependencies] -loom = { version = "0.5", features = ["futures", "checkpoint"] } +loom = { version = "0.5.2", features = ["futures", "checkpoint"] } [package.metadata.docs.rs] all-features = true diff --git a/tokio/src/runtime/basic_scheduler.rs b/tokio/src/runtime/basic_scheduler.rs index 00193329137..6eb7e8b2a28 100644 --- a/tokio/src/runtime/basic_scheduler.rs +++ b/tokio/src/runtime/basic_scheduler.rs @@ -1,6 +1,6 @@ use crate::future::poll_fn; use crate::loom::sync::atomic::AtomicBool; -use crate::loom::sync::Mutex; +use crate::loom::sync::{Arc, Mutex}; use crate::park::{Park, Unpark}; use crate::runtime::context::EnterGuard; use crate::runtime::driver::Driver; @@ -16,7 +16,6 @@ use std::collections::VecDeque; use std::fmt; use std::future::Future; use std::sync::atomic::Ordering::{AcqRel, Release}; -use std::sync::Arc; use std::task::Poll::{Pending, Ready}; use std::time::Duration; @@ -481,8 +480,8 @@ impl Schedule for Arc { } impl Wake for Shared { - fn wake(self: Arc) { - Wake::wake_by_ref(&self) + fn wake(arc_self: Arc) { + Wake::wake_by_ref(&arc_self) } /// Wake by reference diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 530e6b13c54..cb1750dee58 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -164,6 +164,13 @@ where } } + /// Try to set the waker notified when the task is complete. Returns true if + /// the task has already completed. If this call returns false, then the + /// waker will not be notified. + pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool { + can_read_output(self.header(), self.trailer(), waker) + } + pub(super) fn drop_join_handle_slow(self) { // Try to unset `JOIN_INTEREST`. This must be done as a first step in // case the task concurrently completed. diff --git a/tokio/src/runtime/task/join.rs b/tokio/src/runtime/task/join.rs index cfb2214ca01..8beed2eaacb 100644 --- a/tokio/src/runtime/task/join.rs +++ b/tokio/src/runtime/task/join.rs @@ -5,7 +5,7 @@ use std::future::Future; use std::marker::PhantomData; use std::panic::{RefUnwindSafe, UnwindSafe}; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{Context, Poll, Waker}; cfg_rt! { /// An owned permission to join on a task (await its termination). @@ -200,6 +200,16 @@ impl JoinHandle { raw.remote_abort(); } } + + /// Set the waker that is notified when the task completes. + pub(crate) fn set_join_waker(&mut self, waker: &Waker) { + if let Some(raw) = self.raw { + if raw.try_set_join_waker(waker) { + // In this case the task has already completed. We wake the waker immediately. + waker.wake_by_ref(); + } + } + } } impl Unpin for JoinHandle {} diff --git a/tokio/src/runtime/task/raw.rs b/tokio/src/runtime/task/raw.rs index fbc9574f1a4..30fe4156300 100644 --- a/tokio/src/runtime/task/raw.rs +++ b/tokio/src/runtime/task/raw.rs @@ -19,6 +19,11 @@ pub(super) struct Vtable { /// Reads the task output, if complete. pub(super) try_read_output: unsafe fn(NonNull
, *mut (), &Waker), + /// Try to set the waker notified when the task is complete. Returns true if + /// the task has already completed. If this call returns false, then the + /// waker will not be notified. + pub(super) try_set_join_waker: unsafe fn(NonNull
, &Waker) -> bool, + /// The join handle has been dropped. pub(super) drop_join_handle_slow: unsafe fn(NonNull
), @@ -35,6 +40,7 @@ pub(super) fn vtable() -> &'static Vtable { poll: poll::, dealloc: dealloc::, try_read_output: try_read_output::, + try_set_join_waker: try_set_join_waker::, drop_join_handle_slow: drop_join_handle_slow::, remote_abort: remote_abort::, shutdown: shutdown::, @@ -84,6 +90,11 @@ impl RawTask { (vtable.try_read_output)(self.ptr, dst, waker); } + pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool { + let vtable = self.header().vtable; + unsafe { (vtable.try_set_join_waker)(self.ptr, waker) } + } + pub(super) fn drop_join_handle_slow(self) { let vtable = self.header().vtable; unsafe { (vtable.drop_join_handle_slow)(self.ptr) } @@ -129,6 +140,11 @@ unsafe fn try_read_output( harness.try_read_output(out, waker); } +unsafe fn try_set_join_waker(ptr: NonNull
, waker: &Waker) -> bool { + let harness = Harness::::from_raw(ptr); + harness.try_set_join_waker(waker) +} + unsafe fn drop_join_handle_slow(ptr: NonNull
) { let harness = Harness::::from_raw(ptr); harness.drop_join_handle_slow() diff --git a/tokio/src/runtime/tests/loom_join_set.rs b/tokio/src/runtime/tests/loom_join_set.rs new file mode 100644 index 00000000000..e87ddb01406 --- /dev/null +++ b/tokio/src/runtime/tests/loom_join_set.rs @@ -0,0 +1,82 @@ +use crate::runtime::Builder; +use crate::task::JoinSet; + +#[test] +fn test_join_set() { + loom::model(|| { + let rt = Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(); + let mut set = JoinSet::new(); + + rt.block_on(async { + assert_eq!(set.len(), 0); + set.spawn(async { () }); + assert_eq!(set.len(), 1); + set.spawn(async { () }); + assert_eq!(set.len(), 2); + let () = set.join_one().await.unwrap().unwrap(); + assert_eq!(set.len(), 1); + set.spawn(async { () }); + assert_eq!(set.len(), 2); + let () = set.join_one().await.unwrap().unwrap(); + assert_eq!(set.len(), 1); + let () = set.join_one().await.unwrap().unwrap(); + assert_eq!(set.len(), 0); + set.spawn(async { () }); + assert_eq!(set.len(), 1); + }); + + drop(set); + drop(rt); + }); +} + +#[test] +fn abort_all_during_completion() { + use std::sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }; + + // These booleans assert that at least one execution had the task complete first, and that at + // least one execution had the task be cancelled before it completed. + let complete_happened = Arc::new(AtomicBool::new(false)); + let cancel_happened = Arc::new(AtomicBool::new(false)); + + { + let complete_happened = complete_happened.clone(); + let cancel_happened = cancel_happened.clone(); + loom::model(move || { + let rt = Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(); + + let mut set = JoinSet::new(); + + rt.block_on(async { + set.spawn(async { () }); + set.abort_all(); + + match set.join_one().await { + Ok(Some(())) => complete_happened.store(true, SeqCst), + Err(err) if err.is_cancelled() => cancel_happened.store(true, SeqCst), + Err(err) => panic!("fail: {}", err), + Ok(None) => { + unreachable!("Aborting the task does not remove it from the JoinSet.") + } + } + + assert!(matches!(set.join_one().await, Ok(None))); + }); + + drop(set); + drop(rt); + }); + } + + assert!(complete_happened.load(SeqCst)); + assert!(cancel_happened.load(SeqCst)); +} diff --git a/tokio/src/runtime/tests/mod.rs b/tokio/src/runtime/tests/mod.rs index be36d6ffe4d..4b49698a86a 100644 --- a/tokio/src/runtime/tests/mod.rs +++ b/tokio/src/runtime/tests/mod.rs @@ -30,12 +30,13 @@ mod unowned_wrapper { cfg_loom! { mod loom_basic_scheduler; - mod loom_local; mod loom_blocking; + mod loom_local; mod loom_oneshot; mod loom_pool; mod loom_queue; mod loom_shutdown_join; + mod loom_join_set; } cfg_not_loom! { diff --git a/tokio/src/task/join_set.rs b/tokio/src/task/join_set.rs new file mode 100644 index 00000000000..3c5de553ebe --- /dev/null +++ b/tokio/src/task/join_set.rs @@ -0,0 +1,242 @@ +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::runtime::Handle; +use crate::task::{JoinError, JoinHandle, LocalSet}; +use crate::util::IdleNotifiedSet; + +/// A collection of tasks spawned on a Tokio runtime. +/// +/// A `JoinSet` can be used to await the completion of some or all of the tasks +/// in the set. The set is not ordered, and the tasks will be returned in the +/// order they complete. +/// +/// All of the tasks must have the same return type `T`. +/// +/// When the `JoinSet` is dropped, all tasks in the `JoinSet` are immediately aborted. +/// +/// # Examples +/// +/// Spawn multiple tasks and wait for them. +/// +/// ``` +/// use tokio::task::JoinSet; +/// +/// #[tokio::main] +/// async fn main() { +/// let mut set = JoinSet::new(); +/// +/// for i in 0..10 { +/// set.spawn(async move { i }); +/// } +/// +/// let mut seen = [false; 10]; +/// while let Some(res) = set.join_one().await.unwrap() { +/// seen[res] = true; +/// } +/// +/// for i in 0..10 { +/// assert!(seen[i]); +/// } +/// } +/// ``` +pub struct JoinSet { + inner: IdleNotifiedSet>, +} + +impl JoinSet { + /// Create a new `JoinSet`. + pub fn new() -> Self { + Self { + inner: IdleNotifiedSet::new(), + } + } + + /// Returns the number of tasks currently in the `JoinSet`. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Returns whether the `JoinSet` is empty. + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } +} + +impl JoinSet { + /// Spawn the provided task on the `JoinSet`. + /// + /// # Panics + /// + /// This method panics if called outside of a Tokio runtime. + pub fn spawn(&mut self, task: F) + where + F: Future, + F: Send + 'static, + T: Send, + { + self.insert(crate::spawn(task)); + } + + /// Spawn the provided task on the provided runtime and store it in this `JoinSet`. + pub fn spawn_on(&mut self, task: F, handle: &Handle) + where + F: Future, + F: Send + 'static, + T: Send, + { + self.insert(handle.spawn(task)); + } + + /// Spawn the provided task on the current [`LocalSet`] and store it in this `JoinSet`. + /// + /// # Panics + /// + /// This method panics if it is called outside of a `LocalSet`. + /// + /// [`LocalSet`]: crate::task::LocalSet + pub fn spawn_local(&mut self, task: F) + where + F: Future, + F: 'static, + { + self.insert(crate::task::spawn_local(task)); + } + + /// Spawn the provided task on the provided [`LocalSet`] and store it in this `JoinSet`. + /// + /// [`LocalSet`]: crate::task::LocalSet + pub fn spawn_local_on(&mut self, task: F, local_set: &LocalSet) + where + F: Future, + F: 'static, + { + self.insert(local_set.spawn_local(task)); + } + + fn insert(&mut self, jh: JoinHandle) { + let mut entry = self.inner.insert_idle(jh); + + // Set the waker that is notified when the task completes. + entry.with_value_and_context(|jh, ctx| jh.set_join_waker(ctx.waker())); + } + + /// Waits until one of the tasks in the set completes and returns its output. + /// + /// Returns `None` if the set is empty. + /// + /// # Cancel Safety + /// + /// This method is cancel safe. If `join_one` is used as the event in a `tokio::select!` + /// statement and some other branch completes first, it is guaranteed that no tasks were + /// removed from this `JoinSet`. + pub async fn join_one(&mut self) -> Result, JoinError> { + crate::future::poll_fn(|cx| self.poll_join_one(cx)).await + } + + /// Aborts all tasks and waits for them to finish shutting down. + /// + /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_one`] in + /// a loop until it returns `Ok(None)`. + /// + /// This method ignores any panics in the tasks shutting down. When this call returns, the + /// `JoinSet` will be empty. + /// + /// [`abort_all`]: fn@Self::abort_all + /// [`join_one`]: fn@Self::join_one + pub async fn shutdown(&mut self) { + self.abort_all(); + while self.join_one().await.transpose().is_some() {} + } + + /// Aborts all tasks on this `JoinSet`. + /// + /// This does not remove the tasks from the `JoinSet`. To wait for the tasks to complete + /// cancellation, you should call `join_one` in a loop until the `JoinSet` is empty. + pub fn abort_all(&mut self) { + self.inner.for_each(|jh| jh.abort()); + } + + /// Removes all tasks from this `JoinSet` without aborting them. + /// + /// The tasks removed by this call will continue to run in the background even if the `JoinSet` + /// is dropped. + pub fn detach_all(&mut self) { + self.inner.drain(drop); + } + + /// Polls for one of the tasks in the set to complete. + /// + /// If this returns `Poll::Ready(Ok(Some(_)))` or `Poll::Ready(Err(_))`, then the task that + /// completed is removed from the set. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled + /// to receive a wakeup when a task in the `JoinSet` completes. Note that on multiple calls to + /// `poll_join_one`, only the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. + /// + /// # Returns + /// + /// This function returns: + /// + /// * `Poll::Pending` if the `JoinSet` is not empty but there is no task whose output is + /// available right now. + /// * `Poll::Ready(Ok(Some(value)))` if one of the tasks in this `JoinSet` has completed. The + /// `value` is the return value of one of the tasks that completed. + /// * `Poll::Ready(Err(err))` if one of the tasks in this `JoinSet` has panicked or been + /// aborted. + /// * `Poll::Ready(Ok(None))` if the `JoinSet` is empty. + /// + /// Note that this method may return `Poll::Pending` even if one of the tasks has completed. + /// This can happen if the [coop budget] is reached. + /// + /// [coop budget]: crate::task#cooperative-scheduling + fn poll_join_one(&mut self, cx: &mut Context<'_>) -> Poll, JoinError>> { + // The call to `pop_notified` moves the entry to the `idle` list. It is moved back to + // the `notified` list if the waker is notified in the `poll` call below. + let mut entry = match self.inner.pop_notified(cx.waker()) { + Some(entry) => entry, + None => { + if self.is_empty() { + return Poll::Ready(Ok(None)); + } else { + // The waker was set by `pop_notified`. + return Poll::Pending; + } + } + }; + + let res = entry.with_value_and_context(|jh, ctx| Pin::new(jh).poll(ctx)); + + if let Poll::Ready(res) = res { + entry.remove(); + Poll::Ready(Some(res).transpose()) + } else { + // A JoinHandle generally won't emit a wakeup without being ready unless + // the coop limit has been reached. We yield to the executor in this + // case. + cx.waker().wake_by_ref(); + Poll::Pending + } + } +} + +impl Drop for JoinSet { + fn drop(&mut self) { + self.inner.drain(|join_handle| join_handle.abort()); + } +} + +impl fmt::Debug for JoinSet { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JoinSet").field("len", &self.len()).finish() + } +} + +impl Default for JoinSet { + fn default() -> Self { + Self::new() + } +} diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index ea9878736db..7d254190148 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -300,6 +300,9 @@ cfg_rt! { mod unconstrained; pub use unconstrained::{unconstrained, Unconstrained}; + mod join_set; + pub use join_set::JoinSet; + cfg_trace! { mod builder; pub use builder::Builder; diff --git a/tokio/src/util/idle_notified_set.rs b/tokio/src/util/idle_notified_set.rs new file mode 100644 index 00000000000..71f3a32a853 --- /dev/null +++ b/tokio/src/util/idle_notified_set.rs @@ -0,0 +1,463 @@ +//! This module defines an `IdleNotifiedSet`, which is a collection of elements. +//! Each element is intended to correspond to a task, and the collection will +//! keep track of which tasks have had their waker notified, and which have not. +//! +//! Each entry in the set holds some user-specified value. The value's type is +//! specified using the `T` parameter. It will usually be a `JoinHandle` or +//! similar. + +use std::marker::PhantomPinned; +use std::mem::ManuallyDrop; +use std::ptr::NonNull; +use std::task::{Context, Waker}; + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::{Arc, Mutex}; +use crate::util::linked_list::{self, Link}; +use crate::util::{waker_ref, Wake}; + +type LinkedList = + linked_list::LinkedList, as linked_list::Link>::Target>; + +/// This is the main handle to the collection. +pub(crate) struct IdleNotifiedSet { + lists: Arc>, + length: usize, +} + +/// A handle to an entry that is guaranteed to be stored in the idle or notified +/// list of its `IdleNotifiedSet`. This value borrows the `IdleNotifiedSet` +/// mutably to prevent the entry from being moved to the `Neither` list, which +/// only the `IdleNotifiedSet` may do. +/// +/// The main consequence of being stored in one of the lists is that the `value` +/// field has not yet been consumed. +/// +/// Note: This entry can be moved from the idle to the notified list while this +/// object exists by waking its waker. +pub(crate) struct EntryInOneOfTheLists<'a, T> { + entry: Arc>, + set: &'a mut IdleNotifiedSet, +} + +type Lists = Mutex>; + +/// The linked lists hold strong references to the ListEntry items, and the +/// ListEntry items also hold a strong reference back to the Lists object, but +/// the destructor of the `IdleNotifiedSet` will clear the two lists, so once +/// that object is destroyed, no ref-cycles will remain. +struct ListsInner { + notified: LinkedList, + idle: LinkedList, + /// Whenever an element in the `notified` list is woken, this waker will be + /// notified and consumed, if it exists. + waker: Option, +} + +/// Which of the two lists in the shared Lists object is this entry stored in? +/// +/// If the value is `Idle`, then an entry's waker may move it to the notified +/// list. Otherwise, only the `IdleNotifiedSet` may move it. +/// +/// If the value is `Neither`, then it is still possible that the entry is in +/// some third external list (this happens in `drain`). +#[derive(Copy, Clone, Eq, PartialEq)] +enum List { + Notified, + Idle, + Neither, +} + +/// An entry in the list. +/// +/// # Safety +/// +/// The `my_list` field must only be accessed while holding the mutex in +/// `parent`. It is an invariant that the value of `my_list` corresponds to +/// which linked list in the `parent` holds this entry. Once this field takes +/// the value `Neither`, then it may never be modified again. +/// +/// If the value of `my_list` is `Notified` or `Idle`, then the `pointers` field +/// must only be accessed while holding the mutex. If the value of `my_list` is +/// `Neither`, then the `pointers` field may be accessed by the +/// `IdleNotifiedSet` (this happens inside `drain`). +/// +/// The `value` field is owned by the `IdleNotifiedSet` and may only be accessed +/// by the `IdleNotifiedSet`. The operation that sets the value of `my_list` to +/// `Neither` assumes ownership of the `value`, and it must either drop it or +/// move it out from this entry to prevent it from getting leaked. (Since the +/// two linked lists are emptied in the destructor of `IdleNotifiedSet`, the +/// value should not be leaked.) +/// +/// This type is `#[repr(C)]` because its `linked_list::Link` implementation +/// requires that `pointers` is the first field. +#[repr(C)] +struct ListEntry { + /// The linked list pointers of the list this entry is in. + pointers: linked_list::Pointers>, + /// Pointer to the shared `Lists` struct. + parent: Arc>, + /// The value stored in this entry. + value: UnsafeCell>, + /// Used to remember which list this entry is in. + my_list: UnsafeCell, + /// Required by the `linked_list::Pointers` field. + _pin: PhantomPinned, +} + +// With mutable access to the `IdleNotifiedSet`, you can get mutable access to +// the values. +unsafe impl Send for IdleNotifiedSet {} +// With the current API we strictly speaking don't even need `T: Sync`, but we +// require it anyway to support adding &self APIs that access the values in the +// future. +unsafe impl Sync for IdleNotifiedSet {} + +// These impls control when it is safe to create a Waker. Since the waker does +// not allow access to the value in any way (including its destructor), it is +// not necessary for `T` to be Send or Sync. +unsafe impl Send for ListEntry {} +unsafe impl Sync for ListEntry {} + +impl IdleNotifiedSet { + /// Create a new IdleNotifiedSet. + pub(crate) fn new() -> Self { + let lists = Mutex::new(ListsInner { + notified: LinkedList::new(), + idle: LinkedList::new(), + waker: None, + }); + + IdleNotifiedSet { + lists: Arc::new(lists), + length: 0, + } + } + + pub(crate) fn len(&self) -> usize { + self.length + } + + pub(crate) fn is_empty(&self) -> bool { + self.length == 0 + } + + /// Insert the given value into the `idle` list. + pub(crate) fn insert_idle(&mut self, value: T) -> EntryInOneOfTheLists<'_, T> { + self.length += 1; + + let entry = Arc::new(ListEntry { + parent: self.lists.clone(), + value: UnsafeCell::new(ManuallyDrop::new(value)), + my_list: UnsafeCell::new(List::Idle), + pointers: linked_list::Pointers::new(), + _pin: PhantomPinned, + }); + + { + let mut lock = self.lists.lock(); + lock.idle.push_front(entry.clone()); + } + + // Safety: We just put the entry in the idle list, so it is in one of the lists. + EntryInOneOfTheLists { entry, set: self } + } + + /// Pop an entry from the notified list to poll it. The entry is moved to + /// the idle list atomically. + pub(crate) fn pop_notified(&mut self, waker: &Waker) -> Option> { + // We don't decrement the length because this call moves the entry to + // the idle list rather than removing it. + if self.length == 0 { + // Fast path. + return None; + } + + let mut lock = self.lists.lock(); + + let should_update_waker = match lock.waker.as_mut() { + Some(cur_waker) => !waker.will_wake(cur_waker), + None => true, + }; + if should_update_waker { + lock.waker = Some(waker.clone()); + } + + // Pop the entry, returning None if empty. + let entry = lock.notified.pop_back()?; + + lock.idle.push_front(entry.clone()); + + // Safety: We are holding the lock. + entry.my_list.with_mut(|ptr| unsafe { + *ptr = List::Idle; + }); + + drop(lock); + + // Safety: We just put the entry in the idle list, so it is in one of the lists. + Some(EntryInOneOfTheLists { entry, set: self }) + } + + /// Call a function on every element in this list. + pub(crate) fn for_each(&mut self, mut func: F) { + fn get_ptrs(list: &mut LinkedList, ptrs: &mut Vec<*mut T>) { + let mut node = list.last(); + + while let Some(entry) = node { + ptrs.push(entry.value.with_mut(|ptr| { + let ptr: *mut ManuallyDrop = ptr; + let ptr: *mut T = ptr.cast(); + ptr + })); + + let prev = entry.pointers.get_prev(); + node = prev.map(|prev| unsafe { &*prev.as_ptr() }); + } + } + + // Atomically get a raw pointer to the value of every entry. + // + // Since this only locks the mutex once, it is not possible for a value + // to get moved from the idle list to the notified list during the + // operation, which would otherwise result in some value being listed + // twice. + let mut ptrs = Vec::with_capacity(self.len()); + { + let mut lock = self.lists.lock(); + + get_ptrs(&mut lock.idle, &mut ptrs); + get_ptrs(&mut lock.notified, &mut ptrs); + } + debug_assert_eq!(ptrs.len(), ptrs.capacity()); + + for ptr in ptrs { + // Safety: When we grabbed the pointers, the entries were in one of + // the two lists. This means that their value was valid at the time, + // and it must still be valid because we are the IdleNotifiedSet, + // and only we can remove an entry from the two lists. (It's + // possible that an entry is moved from one list to the other during + // this loop, but that is ok.) + func(unsafe { &mut *ptr }); + } + } + + /// Remove all entries in both lists, applying some function to each element. + /// + /// The closure is called on all elements even if it panics. Having it panic + /// twice is a double-panic, and will abort the application. + pub(crate) fn drain(&mut self, func: F) { + if self.length == 0 { + // Fast path. + return; + } + self.length = 0; + + // The LinkedList is not cleared on panic, so we use a bomb to clear it. + // + // This value has the invariant that any entry in its `all_entries` list + // has `my_list` set to `Neither` and that the value has not yet been + // dropped. + struct AllEntries { + all_entries: LinkedList, + func: F, + } + + impl AllEntries { + fn pop_next(&mut self) -> bool { + if let Some(entry) = self.all_entries.pop_back() { + // Safety: We just took this value from the list, so we can + // destroy the value in the entry. + entry + .value + .with_mut(|ptr| unsafe { (self.func)(ManuallyDrop::take(&mut *ptr)) }); + true + } else { + false + } + } + } + + impl Drop for AllEntries { + fn drop(&mut self) { + while self.pop_next() {} + } + } + + let mut all_entries = AllEntries { + all_entries: LinkedList::new(), + func, + }; + + // Atomically move all entries to the new linked list in the AllEntries + // object. + { + let mut lock = self.lists.lock(); + unsafe { + // Safety: We are holding the lock and `all_entries` is a new + // LinkedList. + move_to_new_list(&mut lock.idle, &mut all_entries.all_entries); + move_to_new_list(&mut lock.notified, &mut all_entries.all_entries); + } + } + + // Keep destroying entries in the list until it is empty. + // + // If the closure panics, then the destructor of the `AllEntries` bomb + // ensures that we keep running the destructor on the remaining values. + // A second panic will abort the program. + while all_entries.pop_next() {} + } +} + +/// # Safety +/// +/// The mutex for the entries must be held, and the target list must be such +/// that setting `my_list` to `Neither` is ok. +unsafe fn move_to_new_list(from: &mut LinkedList, to: &mut LinkedList) { + while let Some(entry) = from.pop_back() { + entry.my_list.with_mut(|ptr| { + *ptr = List::Neither; + }); + to.push_front(entry); + } +} + +impl<'a, T> EntryInOneOfTheLists<'a, T> { + /// Remove this entry from the list it is in, returning the value associated + /// with the entry. + /// + /// This consumes the value, since it is no longer guaranteed to be in a + /// list. + pub(crate) fn remove(self) -> T { + self.set.length -= 1; + + { + let mut lock = self.set.lists.lock(); + + // Safety: We are holding the lock so there is no race, and we will + // remove the entry afterwards to uphold invariants. + let old_my_list = self.entry.my_list.with_mut(|ptr| unsafe { + let old_my_list = *ptr; + *ptr = List::Neither; + old_my_list + }); + + let list = match old_my_list { + List::Idle => &mut lock.idle, + List::Notified => &mut lock.notified, + // An entry in one of the lists is in one of the lists. + List::Neither => unreachable!(), + }; + + unsafe { + // Safety: We just checked that the entry is in this particular + // list. + list.remove(ListEntry::as_raw(&self.entry)).unwrap(); + } + } + + // By setting `my_list` to `Neither`, we have taken ownership of the + // value. We return it to the caller. + // + // Safety: We have a mutable reference to the `IdleNotifiedSet` that + // owns this entry, so we can use its permission to access the value. + self.entry + .value + .with_mut(|ptr| unsafe { ManuallyDrop::take(&mut *ptr) }) + } + + /// Access the value in this entry together with a context for its waker. + pub(crate) fn with_value_and_context(&mut self, func: F) -> U + where + F: FnOnce(&mut T, &mut Context<'_>) -> U, + T: 'static, + { + let waker = waker_ref(&self.entry); + + let mut context = Context::from_waker(&waker); + + // Safety: We have a mutable reference to the `IdleNotifiedSet` that + // owns this entry, so we can use its permission to access the value. + self.entry + .value + .with_mut(|ptr| unsafe { func(&mut *ptr, &mut context) }) + } +} + +impl Drop for IdleNotifiedSet { + fn drop(&mut self) { + // Clear both lists. + self.drain(drop); + + #[cfg(debug_assertions)] + if !std::thread::panicking() { + let lock = self.lists.lock(); + assert!(lock.idle.is_empty()); + assert!(lock.notified.is_empty()); + } + } +} + +impl Wake for ListEntry { + fn wake_by_ref(me: &Arc) { + let mut lock = me.parent.lock(); + + // Safety: We are holding the lock and we will update the lists to + // maintain invariants. + let old_my_list = me.my_list.with_mut(|ptr| unsafe { + let old_my_list = *ptr; + if old_my_list == List::Idle { + *ptr = List::Notified; + } + old_my_list + }); + + if old_my_list == List::Idle { + // We move ourself to the notified list. + let me = unsafe { + // Safety: We just checked that we are in this particular list. + lock.idle.remove(NonNull::from(&**me)).unwrap() + }; + lock.notified.push_front(me); + + if let Some(waker) = lock.waker.take() { + drop(lock); + waker.wake(); + } + } + } + + fn wake(me: Arc) { + Self::wake_by_ref(&me) + } +} + +/// # Safety +/// +/// `ListEntry` is forced to be !Unpin. +unsafe impl linked_list::Link for ListEntry { + type Handle = Arc>; + type Target = ListEntry; + + fn as_raw(handle: &Self::Handle) -> NonNull> { + let ptr: *const ListEntry = Arc::as_ptr(handle); + // Safety: We can't get a null pointer from `Arc::as_ptr`. + unsafe { NonNull::new_unchecked(ptr as *mut ListEntry) } + } + + unsafe fn from_raw(ptr: NonNull>) -> Arc> { + Arc::from_raw(ptr.as_ptr()) + } + + unsafe fn pointers( + target: NonNull>, + ) -> NonNull>> { + // Safety: The pointers struct is the first field and ListEntry is + // `#[repr(C)]` so this cast is safe. + // + // We do this rather than doing a field access since `std::ptr::addr_of` + // is too new for our MSRV. + target.cast() + } +} diff --git a/tokio/src/util/linked_list.rs b/tokio/src/util/linked_list.rs index b84bcdcf7a5..cd74d376ae3 100644 --- a/tokio/src/util/linked_list.rs +++ b/tokio/src/util/linked_list.rs @@ -219,6 +219,7 @@ impl fmt::Debug for LinkedList { #[cfg(any( feature = "fs", + feature = "rt", all(unix, feature = "process"), feature = "signal", feature = "sync", @@ -296,7 +297,7 @@ impl Pointers { } } - fn get_prev(&self) -> Option> { + pub(crate) fn get_prev(&self) -> Option> { // SAFETY: prev is the first field in PointersInner, which is #[repr(C)]. unsafe { let inner = self.inner.get(); @@ -304,7 +305,7 @@ impl Pointers { ptr::read(prev) } } - fn get_next(&self) -> Option> { + pub(crate) fn get_next(&self) -> Option> { // SAFETY: next is the second field in PointersInner, which is #[repr(C)]. unsafe { let inner = self.inner.get(); diff --git a/tokio/src/util/mod.rs b/tokio/src/util/mod.rs index f0a79a7cca9..41a3bce051f 100644 --- a/tokio/src/util/mod.rs +++ b/tokio/src/util/mod.rs @@ -44,6 +44,9 @@ pub(crate) mod linked_list; mod rand; cfg_rt! { + mod idle_notified_set; + pub(crate) use idle_notified_set::IdleNotifiedSet; + mod wake; pub(crate) use wake::WakerRef; pub(crate) use wake::{waker_ref, Wake}; diff --git a/tokio/src/util/wake.rs b/tokio/src/util/wake.rs index 8f89668c61a..bf1d4bdf92a 100644 --- a/tokio/src/util/wake.rs +++ b/tokio/src/util/wake.rs @@ -1,13 +1,14 @@ +use crate::loom::sync::Arc; + use std::marker::PhantomData; use std::mem::ManuallyDrop; use std::ops::Deref; -use std::sync::Arc; use std::task::{RawWaker, RawWakerVTable, Waker}; /// Simplified waking interface based on Arcs. -pub(crate) trait Wake: Send + Sync { +pub(crate) trait Wake: Send + Sync + Sized + 'static { /// Wake by value. - fn wake(self: Arc); + fn wake(arc_self: Arc); /// Wake by reference. fn wake_by_ref(arc_self: &Arc); diff --git a/tokio/tests/async_send_sync.rs b/tokio/tests/async_send_sync.rs index aa14970eb1b..c7e15c4775e 100644 --- a/tokio/tests/async_send_sync.rs +++ b/tokio/tests/async_send_sync.rs @@ -435,24 +435,33 @@ async_assert_fn!(tokio::sync::watch::Sender::closed(_): !Send & !Sync & !Unp async_assert_fn!(tokio::sync::watch::Sender::closed(_): !Send & !Sync & !Unpin); async_assert_fn!(tokio::sync::watch::Sender::closed(_): Send & Sync & !Unpin); -async_assert_fn!(tokio::task::LocalKey::scope(_, u32, BoxFutureSync<()>): Send & Sync & !Unpin); -async_assert_fn!(tokio::task::LocalKey::scope(_, u32, BoxFutureSend<()>): Send & !Sync & !Unpin); -async_assert_fn!(tokio::task::LocalKey::scope(_, u32, BoxFuture<()>): !Send & !Sync & !Unpin); -async_assert_fn!(tokio::task::LocalKey>::scope(_, Cell, BoxFutureSync<()>): Send & !Sync & !Unpin); -async_assert_fn!(tokio::task::LocalKey>::scope(_, Cell, BoxFutureSend<()>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::JoinSet>::join_one(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::task::JoinSet>::shutdown(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::task::JoinSet>::join_one(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::JoinSet>::shutdown(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::JoinSet::join_one(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::task::JoinSet::shutdown(_): Send & Sync & !Unpin); async_assert_fn!(tokio::task::LocalKey>::scope(_, Cell, BoxFuture<()>): !Send & !Sync & !Unpin); -async_assert_fn!(tokio::task::LocalKey>::scope(_, Rc, BoxFutureSync<()>): !Send & !Sync & !Unpin); -async_assert_fn!(tokio::task::LocalKey>::scope(_, Rc, BoxFutureSend<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey>::scope(_, Cell, BoxFutureSend<()>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey>::scope(_, Cell, BoxFutureSync<()>): Send & !Sync & !Unpin); async_assert_fn!(tokio::task::LocalKey>::scope(_, Rc, BoxFuture<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey>::scope(_, Rc, BoxFutureSend<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey>::scope(_, Rc, BoxFutureSync<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey::scope(_, u32, BoxFuture<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey::scope(_, u32, BoxFutureSend<()>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey::scope(_, u32, BoxFutureSync<()>): Send & Sync & !Unpin); async_assert_fn!(tokio::task::LocalSet::run_until(_, BoxFutureSync<()>): !Send & !Sync & !Unpin); async_assert_fn!(tokio::task::unconstrained(BoxFuture<()>): !Send & !Sync & Unpin); async_assert_fn!(tokio::task::unconstrained(BoxFutureSend<()>): Send & !Sync & Unpin); async_assert_fn!(tokio::task::unconstrained(BoxFutureSync<()>): Send & Sync & Unpin); -assert_value!(tokio::task::LocalSet: !Send & !Sync & Unpin); -assert_value!(tokio::task::JoinHandle: Send & Sync & Unpin); -assert_value!(tokio::task::JoinHandle: Send & Sync & Unpin); -assert_value!(tokio::task::JoinHandle: !Send & !Sync & Unpin); assert_value!(tokio::task::JoinError: Send & Sync & Unpin); +assert_value!(tokio::task::JoinHandle: !Send & !Sync & Unpin); +assert_value!(tokio::task::JoinHandle: Send & Sync & Unpin); +assert_value!(tokio::task::JoinHandle: Send & Sync & Unpin); +assert_value!(tokio::task::JoinSet: Send & Sync & Unpin); +assert_value!(tokio::task::JoinSet: Send & Sync & Unpin); +assert_value!(tokio::task::JoinSet: !Send & !Sync & Unpin); +assert_value!(tokio::task::LocalSet: !Send & !Sync & Unpin); assert_value!(tokio::runtime::Builder: Send & Sync & Unpin); assert_value!(tokio::runtime::EnterGuard<'_>: Send & Sync & Unpin); diff --git a/tokio/tests/task_join_set.rs b/tokio/tests/task_join_set.rs new file mode 100644 index 00000000000..574f216620e --- /dev/null +++ b/tokio/tests/task_join_set.rs @@ -0,0 +1,192 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::sync::oneshot; +use tokio::task::JoinSet; +use tokio::time::Duration; + +use futures::future::FutureExt; + +fn rt() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap() +} + +#[tokio::test(start_paused = true)] +async fn test_with_sleep() { + let mut set = JoinSet::new(); + + for i in 0..10 { + set.spawn(async move { i }); + assert_eq!(set.len(), 1 + i); + } + set.detach_all(); + assert_eq!(set.len(), 0); + + assert!(matches!(set.join_one().await, Ok(None))); + + for i in 0..10 { + set.spawn(async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + i + }); + assert_eq!(set.len(), 1 + i); + } + + let mut seen = [false; 10]; + while let Some(res) = set.join_one().await.unwrap() { + seen[res] = true; + } + + for was_seen in &seen { + assert!(was_seen); + } + assert!(matches!(set.join_one().await, Ok(None))); + + // Do it again. + for i in 0..10 { + set.spawn(async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + i + }); + } + + let mut seen = [false; 10]; + while let Some(res) = set.join_one().await.unwrap() { + seen[res] = true; + } + + for was_seen in &seen { + assert!(was_seen); + } + assert!(matches!(set.join_one().await, Ok(None))); +} + +#[tokio::test] +async fn test_abort_on_drop() { + let mut set = JoinSet::new(); + + let mut recvs = Vec::new(); + + for _ in 0..16 { + let (send, recv) = oneshot::channel::<()>(); + recvs.push(recv); + + set.spawn(async { + // This task will never complete on its own. + futures::future::pending::<()>().await; + drop(send); + }); + } + + drop(set); + + for recv in recvs { + // The task is aborted soon and we will receive an error. + assert!(recv.await.is_err()); + } +} + +#[tokio::test] +async fn alternating() { + let mut set = JoinSet::new(); + + assert_eq!(set.len(), 0); + set.spawn(async {}); + assert_eq!(set.len(), 1); + set.spawn(async {}); + assert_eq!(set.len(), 2); + + for _ in 0..16 { + let () = set.join_one().await.unwrap().unwrap(); + assert_eq!(set.len(), 1); + set.spawn(async {}); + assert_eq!(set.len(), 2); + } +} + +#[test] +fn runtime_gone() { + let mut set = JoinSet::new(); + { + let rt = rt(); + set.spawn_on(async { 1 }, rt.handle()); + drop(rt); + } + + assert!(rt().block_on(set.join_one()).unwrap_err().is_cancelled()); +} + +// This ensures that `join_one` works correctly when the coop budget is +// exhausted. +#[tokio::test(flavor = "current_thread")] +async fn join_set_coop() { + // Large enough to trigger coop. + const TASK_NUM: u32 = 1000; + + static SEM: tokio::sync::Semaphore = tokio::sync::Semaphore::const_new(0); + + let mut set = JoinSet::new(); + + for _ in 0..TASK_NUM { + set.spawn(async { + SEM.add_permits(1); + }); + } + + // Wait for all tasks to complete. + // + // Since this is a `current_thread` runtime, there's no race condition + // between the last permit being added and the task completing. + let _ = SEM.acquire_many(TASK_NUM).await.unwrap(); + + let mut count = 0; + let mut coop_count = 0; + loop { + match set.join_one().now_or_never() { + Some(Ok(Some(()))) => {} + Some(Err(err)) => panic!("failed: {}", err), + None => { + coop_count += 1; + tokio::task::yield_now().await; + continue; + } + Some(Ok(None)) => break, + } + + count += 1; + } + assert!(coop_count >= 1); + assert_eq!(count, TASK_NUM); +} + +#[tokio::test(start_paused = true)] +async fn abort_all() { + let mut set: JoinSet<()> = JoinSet::new(); + + for _ in 0..5 { + set.spawn(futures::future::pending()); + } + for _ in 0..5 { + set.spawn(async { + tokio::time::sleep(Duration::from_secs(1)).await; + }); + } + + // The join set will now have 5 pending tasks and 5 ready tasks. + tokio::time::sleep(Duration::from_secs(2)).await; + + set.abort_all(); + assert_eq!(set.len(), 10); + + let mut count = 0; + while let Some(res) = set.join_one().await.transpose() { + if let Err(err) = res { + assert!(err.is_cancelled()); + } + count += 1; + } + assert_eq!(count, 10); + assert_eq!(set.len(), 0); +}