diff --git a/.cargo/config b/.cargo/config new file mode 100644 index 00000000000..df8858986f3 --- /dev/null +++ b/.cargo/config @@ -0,0 +1,2 @@ +# [build] +# rustflags = ["--cfg", "tokio_unstable"] \ No newline at end of file diff --git a/README.md b/README.md index 1cce34aeeff..47522be3cb7 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ Make sure you activated the full features of the tokio crate on Cargo.toml: ```toml [dependencies] -tokio = { version = "1.17.0", features = ["full"] } +tokio = { version = "1.18.0", features = ["full"] } ``` Then, on your main.rs: diff --git a/tokio-util/Cargo.toml b/tokio-util/Cargo.toml index b4557586a0f..992ff234a8b 100644 --- a/tokio-util/Cargo.toml +++ b/tokio-util/Cargo.toml @@ -29,13 +29,12 @@ codec = ["tracing"] time = ["tokio/time","slab"] io = [] io-util = ["io", "tokio/rt", "tokio/io-util"] -rt = ["tokio/rt", "tokio/sync", "futures-util"] +rt = ["tokio/rt", "tokio/sync", "futures-util", "hashbrown"] __docs_rs = ["futures-util"] [dependencies] -tokio = { version = "1.6.0", path = "../tokio", features = ["sync"] } - +tokio = { version = "1.18.0", path = "../tokio", features = ["sync"] } bytes = "1.0.0" futures-core = "0.3.0" futures-sink = "0.3.0" @@ -45,6 +44,9 @@ pin-project-lite = "0.2.0" slab = { version = "0.4.4", optional = true } # Backs `DelayQueue` tracing = { version = "0.1.25", default-features = false, features = ["std"], optional = true } +[target.'cfg(tokio_unstable)'.dependencies] +hashbrown = { version = "0.12.0", optional = true } + [dev-dependencies] tokio = { version = "1.0.0", path = "../tokio", features = ["full"] } tokio-test = { version = "0.4.0", path = "../tokio-test" } diff --git a/tokio-util/src/task/join_map.rs b/tokio-util/src/task/join_map.rs new file mode 100644 index 00000000000..41c82f448ab --- /dev/null +++ b/tokio-util/src/task/join_map.rs @@ -0,0 +1,808 @@ +use hashbrown::hash_map::RawEntryMut; +use hashbrown::HashMap; +use std::borrow::Borrow; +use std::collections::hash_map::RandomState; +use std::fmt; +use std::future::Future; +use std::hash::{BuildHasher, Hash, Hasher}; +use tokio::runtime::Handle; +use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet}; + +/// A collection of tasks spawned on a Tokio runtime, associated with hash map +/// keys. +/// +/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the +/// addition of a set of keys associated with each task. These keys allow +/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the +/// `JoinMap` based on their keys, or [test whether a task corresponding to a +/// given key exists][contains] in the `JoinMap`. +/// +/// In addition, when tasks in the `JoinMap` complete, they will return the +/// associated key along with the value returned by the task, if any. +/// +/// A `JoinMap` can be used to await the completion of some or all of the tasks +/// in the map. The map is not ordered, and the tasks will be returned in the +/// order they complete. +/// +/// All of the tasks must have the same return type `V`. +/// +/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted. +/// +/// **Note**: This type depends on Tokio's [unstable API][unstable]. See [the +/// documentation on unstable features][unstable] for details on how to enable +/// Tokio's unstable features. +/// +/// # Examples +/// +/// Spawn multiple tasks and wait for them: +/// +/// ``` +/// use tokio_util::task::JoinMap; +/// +/// #[tokio::main] +/// async fn main() { +/// let mut map = JoinMap::new(); +/// +/// for i in 0..10 { +/// // Spawn a task on the `JoinMap` with `i` as its key. +/// map.spawn(i, async move { /* ... */ }); +/// } +/// +/// let mut seen = [false; 10]; +/// +/// // When a task completes, `join_one` returns the task's key along +/// // with its output. +/// while let Some((key, res)) = map.join_one().await { +/// seen[key] = true; +/// assert!(res.is_ok(), "task {} completed successfully!", key); +/// } +/// +/// for i in 0..10 { +/// assert!(seen[i]); +/// } +/// } +/// ``` +/// +/// Cancel tasks based on their keys: +/// +/// ``` +/// use tokio_util::task::JoinMap; +/// +/// #[tokio::main] +/// async fn main() { +/// let mut map = JoinMap::new(); +/// +/// map.spawn("hello world", async move { /* ... */ }); +/// map.spawn("goodbye world", async move { /* ... */}); +/// +/// // Look up the "goodbye world" task in the map and abort it. +/// let aborted = map.abort("goodbye world"); +/// +/// // `JoinMap::abort` returns `true` if a task existed for the +/// // provided key. +/// assert!(aborted); +/// +/// while let Some((key, res)) = map.join_one().await { +/// if key == "goodbye world" { +/// // The aborted task should complete with a cancelled `JoinError`. +/// assert!(res.unwrap_err().is_cancelled()); +/// } else { +/// // Other tasks should complete normally. +/// assert!(res.is_ok()); +/// } +/// } +/// } +/// ``` +/// +/// [`JoinSet`]: tokio::task::JoinSet +/// [unstable]: tokio#unstable-features +/// [abort]: fn@Self::abort +/// [abort_matching]: fn@Self::abort_matching +/// [contains]: fn@Self::contains_key +#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] +pub struct JoinMap { + /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`, + /// indexed by their keys and task IDs. + /// + /// The [`Key`] type contains both the task's `K`-typed key provided when + /// spawning tasks, and the task's IDs. The IDs are stored here to resolve + /// hash collisions when looking up tasks based on their pre-computed hash + /// (as stored in the `hashes_by_task` map). + tasks_by_key: HashMap, AbortHandle, S>, + + /// A map from task IDs to the hash of the key associated with that task. + /// + /// This map is used to perform reverse lookups of tasks in the + /// `tasks_by_key` map based on their task IDs. When a task terminates, the + /// ID is provided to us by the `JoinSet`, so we can look up the hash value + /// of that task's key, and then remove it from the `tasks_by_key` map using + /// the raw hash code, resolving collisions by comparing task IDs. + hashes_by_task: HashMap, + + /// The [`JoinSet`] that awaits the completion of tasks spawned on this + /// `JoinMap`. + tasks: JoinSet, +} + +/// A [`JoinMap`] key. +/// +/// This holds both a `K`-typed key (the actual key as seen by the user), _and_ +/// a task ID, so that hash collisions between `K`-typed keys can be resolved +/// using either `K`'s `Eq` impl *or* by checking the task IDs. +/// +/// This allows looking up a task using either an actual key (such as when the +/// user queries the map with a key), *or* using a task ID and a hash (such as +/// when removing completed tasks from the map). +#[derive(Debug)] +struct Key { + key: K, + id: Id, +} + +impl JoinMap { + /// Creates a new empty `JoinMap`. + /// + /// The `JoinMap` is initially created with a capacity of 0, so it will not + /// allocate until a task is first spawned on it. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// let map: JoinMap<&str, i32> = JoinMap::new(); + /// ``` + #[inline] + #[must_use] + pub fn new() -> Self { + Self::with_hasher(RandomState::new()) + } + + /// Creates an empty `JoinMap` with the specified capacity. + /// + /// The `JoinMap` will be able to hold at least `capacity` tasks without + /// reallocating. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10); + /// ``` + #[inline] + #[must_use] + pub fn with_capacity(capacity: usize) -> Self { + JoinMap::with_capacity_and_hasher(capacity, Default::default()) + } +} + +impl JoinMap { + /// Creates an empty `JoinMap` which will use the given hash builder to hash + /// keys. + /// + /// The created map has the default initial capacity. + /// + /// Warning: `hash_builder` is normally randomly generated, and + /// is designed to allow `JoinMap` to be resistant to attacks that + /// cause many collisions and very poor performance. Setting it + /// manually using this function can expose a DoS attack vector. + /// + /// The `hash_builder` passed should implement the [`BuildHasher`] trait for + /// the `JoinMap` to be useful, see its documentation for details. + #[inline] + #[must_use] + pub fn with_hasher(hash_builder: S) -> Self { + Self::with_capacity_and_hasher(0, hash_builder) + } + + /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder` + /// to hash the keys. + /// + /// The `JoinMap` will be able to hold at least `capacity` elements without + /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate. + /// + /// Warning: `hash_builder` is normally randomly generated, and + /// is designed to allow HashMaps to be resistant to attacks that + /// cause many collisions and very poor performance. Setting it + /// manually using this function can expose a DoS attack vector. + /// + /// The `hash_builder` passed should implement the [`BuildHasher`] trait for + /// the `JoinMap`to be useful, see its documentation for details. + /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio_util::task::JoinMap; + /// use std::collections::hash_map::RandomState; + /// + /// let s = RandomState::new(); + /// let mut map = JoinMap::with_capacity_and_hasher(10, s); + /// map.spawn(1, async move { "hello world!" }); + /// # } + /// ``` + #[inline] + #[must_use] + pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self { + Self { + tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()), + hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder), + tasks: JoinSet::new(), + } + } + + /// Returns the number of tasks currently in the `JoinMap`. + pub fn len(&self) -> usize { + let len = self.tasks_by_key.len(); + debug_assert_eq!(len, self.hashes_by_task.len()); + len + } + + /// Returns whether the `JoinMap` is empty. + pub fn is_empty(&self) -> bool { + let empty = self.tasks_by_key.is_empty(); + debug_assert_eq!(empty, self.hashes_by_task.is_empty()); + empty + } + + /// Returns the number of tasks the map can hold without reallocating. + /// + /// This number is a lower bound; the `JoinMap` might be able to hold + /// more, but is guaranteed to be able to hold at least this many. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// + /// let map: JoinMap = JoinMap::with_capacity(100); + /// assert!(map.capacity() >= 100); + /// ``` + #[inline] + pub fn capacity(&self) -> usize { + let capacity = self.tasks_by_key.capacity(); + debug_assert_eq!(capacity, self.hashes_by_task.capacity()); + capacity + } +} + +impl JoinMap +where + K: Hash + Eq, + V: 'static, + S: BuildHasher, +{ + /// Spawn the provided task and store it in this `JoinMap` with the provided + /// key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_one`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// # Panics + /// + /// This method panics if called outside of a Tokio runtime. + /// + /// [`join_one`]: Self::join_one + pub fn spawn(&mut self, key: K, task: F) + where + F: Future, + F: Send + 'static, + V: Send, + { + let task = self.tasks.spawn(task); + self.insert(key, task) + } + + /// Spawn the provided task on the provided runtime and store it in this + /// `JoinMap` with the provided key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_one`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// [`join_one`]: Self::join_one + pub fn spawn_on(&mut self, key: K, task: F, handle: &Handle) + where + F: Future, + F: Send + 'static, + V: Send, + { + let task = self.tasks.spawn_on(task, handle); + self.insert(key, task); + } + + /// Spawn the provided task on the current [`LocalSet`] and store it in this + /// `JoinMap` with the provided key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_one`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// # Panics + /// + /// This method panics if it is called outside of a `LocalSet`. + /// + /// [`LocalSet`]: tokio::task::LocalSet + /// [`join_one`]: Self::join_one + pub fn spawn_local(&mut self, key: K, task: F) + where + F: Future, + F: 'static, + { + let task = self.tasks.spawn_local(task); + self.insert(key, task); + } + + /// Spawn the provided task on the provided [`LocalSet`] and store it in + /// this `JoinMap` with the provided key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_one`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// [`LocalSet`]: tokio::task::LocalSet + /// [`join_one`]: Self::join_one + pub fn spawn_local_on(&mut self, key: K, task: F, local_set: &LocalSet) + where + F: Future, + F: 'static, + { + let task = self.tasks.spawn_local_on(task, local_set); + self.insert(key, task) + } + + fn insert(&mut self, key: K, abort: AbortHandle) { + let hash = self.hash(&key); + let id = abort.id(); + let map_key = Key { + id: id.clone(), + key, + }; + + // Insert the new key into the map of tasks by keys. + let entry = self + .tasks_by_key + .raw_entry_mut() + .from_hash(hash, |k| k.key == map_key.key); + match entry { + RawEntryMut::Occupied(mut occ) => { + // There was a previous task spawned with the same key! Cancel + // that task, and remove its ID from the map of hashes by task IDs. + let Key { id: prev_id, .. } = occ.insert_key(map_key); + occ.insert(abort).abort(); + let _prev_hash = self.hashes_by_task.remove(&prev_id); + debug_assert_eq!(Some(hash), _prev_hash); + } + RawEntryMut::Vacant(vac) => { + vac.insert(map_key, abort); + } + }; + + // Associate the key's hash with this task's ID, for looking up tasks by ID. + let _prev = self.hashes_by_task.insert(id, hash); + debug_assert!(_prev.is_none(), "no prior task should have had the same ID"); + } + + /// Waits until one of the tasks in the map completes and returns its + /// output, along with the key corresponding to that task. + /// + /// Returns `None` if the map is empty. + /// + /// # Cancel Safety + /// + /// This method is cancel safe. If `join_one` is used as the event in a [`tokio::select!`] + /// statement and some other branch completes first, it is guaranteed that no tasks were + /// removed from this `JoinMap`. + /// + /// # Returns + /// + /// This function returns: + /// + /// * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has + /// completed. The `value` is the return value of that ask, and `key` is + /// the key associated with the task. + /// * `Some((key, Err(err))` if one of the tasks in this JoinMap` has + /// panicked or been aborted. `key` is the key associated with the task + /// that panicked or was aborted. + /// * `None` if the `JoinMap` is empty. + /// + /// [`tokio::select!`]: tokio::select + pub async fn join_one(&mut self) -> Option<(K, Result)> { + let (res, id) = match self.tasks.join_one_with_id().await { + Ok(task) => { + let (id, output) = task?; + (Ok(output), id) + } + Err(e) => { + let id = e.id(); + (Err(e), id) + } + }; + let key = self.remove_by_id(id)?; + Some((key, res)) + } + + /// Aborts all tasks and waits for them to finish shutting down. + /// + /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_one`] in + /// a loop until it returns `None`. + /// + /// This method ignores any panics in the tasks shutting down. When this call returns, the + /// `JoinMap` will be empty. + /// + /// [`abort_all`]: fn@Self::abort_all + /// [`join_one`]: fn@Self::join_one + pub async fn shutdown(&mut self) { + self.abort_all(); + while self.join_one().await.is_some() {} + } + + /// Abort the task corresponding to the provided `key`. + /// + /// If this `JoinMap` contains a task corresponding to `key`, this method + /// will abort that task and return `true`. Otherwise, if no task exists for + /// `key`, this method returns `false`. + /// + /// # Examples + /// + /// Aborting a task by key: + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut map = JoinMap::new(); + /// + /// map.spawn("hello world", async move { /* ... */ }); + /// map.spawn("goodbye world", async move { /* ... */}); + /// + /// // Look up the "goodbye world" task in the map and abort it. + /// map.abort("goodbye world"); + /// + /// while let Some((key, res)) = map.join_one().await { + /// if key == "goodbye world" { + /// // The aborted task should complete with a cancelled `JoinError`. + /// assert!(res.unwrap_err().is_cancelled()); + /// } else { + /// // Other tasks should complete normally. + /// assert!(res.is_ok()); + /// } + /// } + /// # } + /// ``` + /// + /// `abort` returns `true` if a task was aborted: + /// ``` + /// use tokio_util::task::JoinMap; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut map = JoinMap::new(); + /// + /// map.spawn("hello world", async move { /* ... */ }); + /// map.spawn("goodbye world", async move { /* ... */}); + /// + /// // A task for the key "goodbye world" should exist in the map: + /// assert!(map.abort("goodbye world")); + /// + /// // Aborting a key that does not exist will return `false`: + /// assert!(!map.abort("goodbye universe")); + /// # } + /// ``` + pub fn abort(&mut self, key: &Q) -> bool + where + Q: Hash + Eq, + K: Borrow, + { + match self.get_by_key(key) { + Some((_, handle)) => { + handle.abort(); + true + } + None => false, + } + } + + /// Aborts all tasks with keys matching `predicate`. + /// + /// `predicate` is a function called with a reference to each key in the + /// map. If it returns `true` for a given key, the corresponding task will + /// be cancelled. + /// + /// # Examples + /// ``` + /// use tokio_util::task::JoinMap; + /// + /// # // use the current thread rt so that spawned tasks don't + /// # // complete in the background before they can be aborted. + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let mut map = JoinMap::new(); + /// + /// map.spawn("hello world", async move { + /// // ... + /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! + /// }); + /// map.spawn("goodbye world", async move { + /// // ... + /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! + /// }); + /// map.spawn("hello san francisco", async move { + /// // ... + /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! + /// }); + /// map.spawn("goodbye universe", async move { + /// // ... + /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! + /// }); + /// + /// // Abort all tasks whose keys begin with "goodbye" + /// map.abort_matching(|key| key.starts_with("goodbye")); + /// + /// let mut seen = 0; + /// while let Some((key, res)) = map.join_one().await { + /// seen += 1; + /// if key.starts_with("goodbye") { + /// // The aborted task should complete with a cancelled `JoinError`. + /// assert!(res.unwrap_err().is_cancelled()); + /// } else { + /// // Other tasks should complete normally. + /// assert!(key.starts_with("hello")); + /// assert!(res.is_ok()); + /// } + /// } + /// + /// // All spawned tasks should have completed. + /// assert_eq!(seen, 4); + /// # } + /// ``` + pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) { + // Note: this method iterates over the tasks and keys *without* removing + // any entries, so that the keys from aborted tasks can still be + // returned when calling `join_one` in the future. + for (Key { ref key, .. }, task) in &self.tasks_by_key { + if predicate(key) { + task.abort(); + } + } + } + + /// Returns `true` if this `JoinMap` contains a task for the provided key. + /// + /// If the task has completed, but its output hasn't yet been consumed by a + /// call to [`join_one`], this method will still return `true`. + /// + /// [`join_one`]: fn@Self::join_one + pub fn contains_key(&self, key: &Q) -> bool + where + Q: Hash + Eq, + K: Borrow, + { + self.get_by_key(key).is_some() + } + + /// Returns `true` if this `JoinMap` contains a task with the provided + /// [task ID]. + /// + /// If the task has completed, but its output hasn't yet been consumed by a + /// call to [`join_one`], this method will still return `true`. + /// + /// [`join_one`]: fn@Self::join_one + /// [task ID]: tokio::task::Id + pub fn contains_task(&self, task: &Id) -> bool { + self.get_by_id(task).is_some() + } + + /// Reserves capacity for at least `additional` more tasks to be spawned + /// on this `JoinMap` without reallocating for the map of task keys. The + /// collection may reserve more space to avoid frequent reallocations. + /// + /// Note that spawning a task will still cause an allocation for the task + /// itself. + /// + /// # Panics + /// + /// Panics if the new allocation size overflows [`usize`]. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// + /// let mut map: JoinMap<&str, i32> = JoinMap::new(); + /// map.reserve(10); + /// ``` + #[inline] + pub fn reserve(&mut self, additional: usize) { + self.tasks_by_key.reserve(additional); + self.hashes_by_task.reserve(additional); + } + + /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop + /// down as much as possible while maintaining the internal rules + /// and possibly leaving some space in accordance with the resize policy. + /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio_util::task::JoinMap; + /// + /// let mut map: JoinMap = JoinMap::with_capacity(100); + /// map.spawn(1, async move { 2 }); + /// map.spawn(3, async move { 4 }); + /// assert!(map.capacity() >= 100); + /// map.shrink_to_fit(); + /// assert!(map.capacity() >= 2); + /// # } + /// ``` + #[inline] + pub fn shrink_to_fit(&mut self) { + self.hashes_by_task.shrink_to_fit(); + self.tasks_by_key.shrink_to_fit(); + } + + /// Shrinks the capacity of the map with a lower limit. It will drop + /// down no lower than the supplied limit while maintaining the internal rules + /// and possibly leaving some space in accordance with the resize policy. + /// + /// If the current capacity is less than the lower limit, this is a no-op. + /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio_util::task::JoinMap; + /// + /// let mut map: JoinMap = JoinMap::with_capacity(100); + /// map.spawn(1, async move { 2 }); + /// map.spawn(3, async move { 4 }); + /// assert!(map.capacity() >= 100); + /// map.shrink_to(10); + /// assert!(map.capacity() >= 10); + /// map.shrink_to(0); + /// assert!(map.capacity() >= 2); + /// # } + /// ``` + #[inline] + pub fn shrink_to(&mut self, min_capacity: usize) { + self.hashes_by_task.shrink_to(min_capacity); + self.tasks_by_key.shrink_to(min_capacity) + } + + /// Look up a task in the map by its key, returning the key and abort handle. + fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key, &'map AbortHandle)> + where + Q: Hash + Eq, + K: Borrow, + { + let hash = self.hash(key); + self.tasks_by_key + .raw_entry() + .from_hash(hash, |k| k.key.borrow() == key) + } + + /// Look up a task in the map by its task ID, returning the key and abort handle. + fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key, &'map AbortHandle)> { + let hash = self.hashes_by_task.get(id)?; + self.tasks_by_key + .raw_entry() + .from_hash(*hash, |k| &k.id == id) + } + + /// Remove a task from the map by ID, returning the key for that task. + fn remove_by_id(&mut self, id: Id) -> Option { + // Get the hash for the given ID. + let hash = self.hashes_by_task.remove(&id)?; + + // Remove the entry for that hash. + let entry = self + .tasks_by_key + .raw_entry_mut() + .from_hash(hash, |k| k.id == id); + let (Key { id: _key_id, key }, handle) = match entry { + RawEntryMut::Occupied(entry) => entry.remove_entry(), + _ => return None, + }; + debug_assert_eq!(_key_id, id); + debug_assert_eq!(id, handle.id()); + self.hashes_by_task.remove(&id); + Some(key) + } + + /// Returns the hash for a given key. + #[inline] + fn hash(&self, key: &Q) -> u64 + where + Q: Hash, + { + let mut hasher = self.tasks_by_key.hasher().build_hasher(); + key.hash(&mut hasher); + hasher.finish() + } +} + +impl JoinMap +where + V: 'static, +{ + /// Aborts all tasks on this `JoinMap`. + /// + /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete + /// cancellation, you should call `join_one` in a loop until the `JoinMap` is empty. + pub fn abort_all(&mut self) { + self.tasks.abort_all() + } + + /// Removes all tasks from this `JoinMap` without aborting them. + /// + /// The tasks removed by this call will continue to run in the background even if the `JoinMap` + /// is dropped. They may still be aborted by key. + pub fn detach_all(&mut self) { + self.tasks.detach_all(); + self.tasks_by_key.clear(); + self.hashes_by_task.clear(); + } +} + +// Hand-written `fmt::Debug` implementation in order to avoid requiring `V: +// Debug`, since no value is ever actually stored in the map. +impl fmt::Debug for JoinMap { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // format the task keys and abort handles a little nicer by just + // printing the key and task ID pairs, without format the `Key` struct + // itself or the `AbortHandle`, which would just format the task's ID + // again. + struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap, AbortHandle, S>); + impl fmt::Debug for KeySet<'_, K, S> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_map() + .entries(self.0.keys().map(|Key { key, id }| (key, id))) + .finish() + } + } + + f.debug_struct("JoinMap") + // The `tasks_by_key` map is the only one that contains information + // that's really worth formatting for the user, since it contains + // the tasks' keys and IDs. The other fields are basically + // implementation details. + .field("tasks", &KeySet(&self.tasks_by_key)) + .finish() + } +} + +impl Default for JoinMap { + fn default() -> Self { + Self::new() + } +} + +// === impl Key === + +impl Hash for Key { + // Don't include the task ID in the hash. + #[inline] + fn hash(&self, hasher: &mut H) { + self.key.hash(hasher); + } +} + +// Because we override `Hash` for this type, we must also override the +// `PartialEq` impl, so that all instances with the same hash are equal. +impl PartialEq for Key { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.key == other.key + } +} + +impl Eq for Key {} diff --git a/tokio-util/src/task/mod.rs b/tokio-util/src/task/mod.rs index 5aa33df2dc0..7ba8ad9a218 100644 --- a/tokio-util/src/task/mod.rs +++ b/tokio-util/src/task/mod.rs @@ -1,4 +1,10 @@ //! Extra utilities for spawning tasks +#[cfg(tokio_unstable)] +mod join_map; mod spawn_pinned; pub use spawn_pinned::LocalPoolHandle; + +#[cfg(tokio_unstable)] +#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "rt"))))] +pub use join_map::JoinMap; diff --git a/tokio-util/tests/task_join_map.rs b/tokio-util/tests/task_join_map.rs new file mode 100644 index 00000000000..d5f87bfb185 --- /dev/null +++ b/tokio-util/tests/task_join_map.rs @@ -0,0 +1,275 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "rt", tokio_unstable))] + +use tokio::sync::oneshot; +use tokio::time::Duration; +use tokio_util::task::JoinMap; + +use futures::future::FutureExt; + +fn rt() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap() +} + +#[tokio::test(start_paused = true)] +async fn test_with_sleep() { + let mut map = JoinMap::new(); + + for i in 0..10 { + map.spawn(i, async move { i }); + assert_eq!(map.len(), 1 + i); + } + map.detach_all(); + assert_eq!(map.len(), 0); + + assert!(matches!(map.join_one().await, None)); + + for i in 0..10 { + map.spawn(i, async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + i + }); + assert_eq!(map.len(), 1 + i); + } + + let mut seen = [false; 10]; + while let Some((k, res)) = map.join_one().await { + seen[k] = true; + assert_eq!(res.expect("task should have completed successfully"), k); + } + + for was_seen in &seen { + assert!(was_seen); + } + assert!(matches!(map.join_one().await, None)); + + // Do it again. + for i in 0..10 { + map.spawn(i, async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + i + }); + } + + let mut seen = [false; 10]; + while let Some((k, res)) = map.join_one().await { + seen[k] = true; + assert_eq!(res.expect("task should have completed successfully"), k); + } + + for was_seen in &seen { + assert!(was_seen); + } + assert!(matches!(map.join_one().await, None)); +} + +#[tokio::test] +async fn test_abort_on_drop() { + let mut map = JoinMap::new(); + + let mut recvs = Vec::new(); + + for i in 0..16 { + let (send, recv) = oneshot::channel::<()>(); + recvs.push(recv); + + map.spawn(i, async { + // This task will never complete on its own. + futures::future::pending::<()>().await; + drop(send); + }); + } + + drop(map); + + for recv in recvs { + // The task is aborted soon and we will receive an error. + assert!(recv.await.is_err()); + } +} + +#[tokio::test] +async fn alternating() { + let mut map = JoinMap::new(); + + assert_eq!(map.len(), 0); + map.spawn(1, async {}); + assert_eq!(map.len(), 1); + map.spawn(2, async {}); + assert_eq!(map.len(), 2); + + for i in 0..16 { + let (_, res) = map.join_one().await.unwrap(); + assert!(res.is_ok()); + assert_eq!(map.len(), 1); + map.spawn(i, async {}); + assert_eq!(map.len(), 2); + } +} + +#[tokio::test(start_paused = true)] +async fn abort_by_key() { + let mut map = JoinMap::new(); + let mut num_canceled = 0; + let mut num_completed = 0; + for i in 0..16 { + map.spawn(i, async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + }); + } + + for i in 0..16 { + if i % 2 != 0 { + // abort odd-numbered tasks. + map.abort(&i); + } + } + + while let Some((key, res)) = map.join_one().await { + match res { + Ok(()) => { + num_completed += 1; + assert_eq!(key % 2, 0); + assert!(!map.contains_key(&key)); + } + Err(e) => { + num_canceled += 1; + assert!(e.is_cancelled()); + assert_ne!(key % 2, 0); + assert!(!map.contains_key(&key)); + } + } + } + + assert_eq!(num_canceled, 8); + assert_eq!(num_completed, 8); +} + +#[tokio::test(start_paused = true)] +async fn abort_by_predicate() { + let mut map = JoinMap::new(); + let mut num_canceled = 0; + let mut num_completed = 0; + for i in 0..16 { + map.spawn(i, async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + }); + } + + // abort odd-numbered tasks. + map.abort_matching(|key| key % 2 != 0); + + while let Some((key, res)) = map.join_one().await { + match res { + Ok(()) => { + num_completed += 1; + assert_eq!(key % 2, 0); + assert!(!map.contains_key(&key)); + } + Err(e) => { + num_canceled += 1; + assert!(e.is_cancelled()); + assert_ne!(key % 2, 0); + assert!(!map.contains_key(&key)); + } + } + } + + assert_eq!(num_canceled, 8); + assert_eq!(num_completed, 8); +} + +#[test] +fn runtime_gone() { + let mut map = JoinMap::new(); + { + let rt = rt(); + map.spawn_on("key", async { 1 }, rt.handle()); + drop(rt); + } + + let (key, res) = rt().block_on(map.join_one()).unwrap(); + assert_eq!(key, "key"); + assert!(res.unwrap_err().is_cancelled()); +} + +// This ensures that `join_one` works correctly when the coop budget is +// exhausted. +#[tokio::test(flavor = "current_thread")] +async fn join_map_coop() { + // Large enough to trigger coop. + const TASK_NUM: u32 = 1000; + + static SEM: tokio::sync::Semaphore = tokio::sync::Semaphore::const_new(0); + + let mut map = JoinMap::new(); + + for i in 0..TASK_NUM { + map.spawn(i, async move { + SEM.add_permits(1); + i + }); + } + + // Wait for all tasks to complete. + // + // Since this is a `current_thread` runtime, there's no race condition + // between the last permit being added and the task completing. + let _ = SEM.acquire_many(TASK_NUM).await.unwrap(); + + let mut count = 0; + let mut coop_count = 0; + loop { + match map.join_one().now_or_never() { + Some(Some((key, Ok(i)))) => assert_eq!(key, i), + Some(Some((key, Err(err)))) => panic!("failed[{}]: {}", key, err), + None => { + coop_count += 1; + tokio::task::yield_now().await; + continue; + } + Some(None) => break, + } + + count += 1; + } + assert!(coop_count >= 1); + assert_eq!(count, TASK_NUM); +} + +#[tokio::test(start_paused = true)] +async fn abort_all() { + let mut map: JoinMap = JoinMap::new(); + + for i in 0..5 { + map.spawn(i, futures::future::pending()); + } + for i in 5..10 { + map.spawn(i, async { + tokio::time::sleep(Duration::from_secs(1)).await; + }); + } + + // The join map will now have 5 pending tasks and 5 ready tasks. + tokio::time::sleep(Duration::from_secs(2)).await; + + map.abort_all(); + assert_eq!(map.len(), 10); + + let mut count = 0; + let mut seen = [false; 10]; + while let Some((k, res)) = map.join_one().await { + seen[k] = true; + if let Err(err) = res { + assert!(err.is_cancelled()); + } + count += 1; + } + assert_eq!(count, 10); + assert_eq!(map.len(), 0); + for was_seen in &seen { + assert!(was_seen); + } +} diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index 6bda46ef662..369186eb678 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -6,7 +6,7 @@ name = "tokio" # - README.md # - Update CHANGELOG.md. # - Create "v1.0.x" git tag. -version = "1.17.0" +version = "1.18.0" edition = "2018" rust-version = "1.49" authors = ["Tokio Contributors "] diff --git a/tokio/README.md b/tokio/README.md index 1cce34aeeff..47522be3cb7 100644 --- a/tokio/README.md +++ b/tokio/README.md @@ -56,7 +56,7 @@ Make sure you activated the full features of the tokio crate on Cargo.toml: ```toml [dependencies] -tokio = { version = "1.17.0", features = ["full"] } +tokio = { version = "1.18.0", features = ["full"] } ``` Then, on your main.rs: diff --git a/tokio/src/runtime/task/abort.rs b/tokio/src/runtime/task/abort.rs index cad639ca0c8..4977377880d 100644 --- a/tokio/src/runtime/task/abort.rs +++ b/tokio/src/runtime/task/abort.rs @@ -43,8 +43,8 @@ impl AbortHandle { // the `AbortHandle` type is only publicly exposed when `tokio_unstable` is // enabled, but it is still defined for testing purposes. #[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] - pub fn abort(self) { - if let Some(raw) = self.raw { + pub fn abort(&self) { + if let Some(ref raw) = self.raw { raw.remote_abort(); } }