From 1d3f12304e29dbb8fe3ef4089d01d5acb9edaae1 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Mon, 25 Apr 2022 10:31:19 -0700 Subject: [PATCH] task: add task IDs (#4630) ## Motivation PR #4538 adds a prototype implementation of a `JoinMap` API in `tokio::task`. In [this comment][1] on that PR, @carllerche pointed out that a much simpler `JoinMap` type could be implemented outside of `tokio` (either in `tokio-util` or in user code) if we just modified `JoinSet` to return a task ID type when spawning new tasks, and when tasks complete. This seems like a better approach for the following reasons: * A `JoinMap`-like type need not become a permanent part of `tokio`'s stable API * Task IDs seem like something that could be generally useful outside of a `JoinMap` implementation ## Solution This branch adds a `tokio::task::Id` type that uniquely identifies a task relative to all other spawned tasks. Task IDs are assigned sequentially based on an atomic `usize` counter of spawned tasks. In addition, I modified `JoinSet` to add a `join_with_id` method that behaves identically to `join_one` but also returns an ID. This can be used to implement a `JoinMap` type. Note that because `join_with_id` must return a task ID regardless of whether the task completes successfully or returns a `JoinError`, I've also changed `JoinError` to carry the ID of the task that errored, and added a `JoinError::id` method for accessing it. Alternatively, we could have done one of the following: * have `join_with_id` return `Option<(Id, Result)>`, which would be inconsistent with the return type of `join_one` (which we've [already bikeshedded over once][2]...) * have `join_with_id` return `Result, (Id, JoinError)>>`, which just feels gross. I thought adding the task ID to `JoinError` was the nicest option, and is potentially useful for other stuff as well, so it's probably a good API to have anyway. [1]: https://github.com/tokio-rs/tokio/pull/4538#issuecomment-1065614755 [2]: https://github.com/tokio-rs/tokio/pull/4335#discussion_r773377901 Closes #4538 Signed-off-by: Eliza Weisman --- tokio/Cargo.toml | 2 +- tokio/src/runtime/basic_scheduler.rs | 4 +- tokio/src/runtime/handle.rs | 13 +++-- tokio/src/runtime/mod.rs | 2 +- tokio/src/runtime/spawner.rs | 7 ++- tokio/src/runtime/task/abort.rs | 26 +++++++-- tokio/src/runtime/task/core.rs | 8 ++- tokio/src/runtime/task/error.rs | 32 ++++++++--- tokio/src/runtime/task/harness.rs | 23 ++++---- tokio/src/runtime/task/join.rs | 27 +++++++-- tokio/src/runtime/task/list.rs | 6 +- tokio/src/runtime/task/mod.rs | 75 +++++++++++++++++++++++-- tokio/src/runtime/task/raw.rs | 6 +- tokio/src/runtime/tests/mod.rs | 6 +- tokio/src/runtime/tests/task.rs | 12 +++- tokio/src/runtime/thread_pool/mod.rs | 6 +- tokio/src/runtime/thread_pool/worker.rs | 8 ++- tokio/src/task/join_set.rs | 39 ++++++++++--- tokio/src/task/local.rs | 13 +++-- tokio/src/task/mod.rs | 2 +- tokio/src/task/spawn.rs | 8 ++- tokio/src/util/trace.rs | 5 +- 22 files changed, 252 insertions(+), 78 deletions(-) diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index 69ec3197e46..6bda46ef662 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -65,7 +65,7 @@ process = [ "winapi/threadpoollegacyapiset", ] # Includes basic task execution capabilities -rt = [] +rt = ["once_cell"] rt-multi-thread = [ "num_cpus", "rt", diff --git a/tokio/src/runtime/basic_scheduler.rs b/tokio/src/runtime/basic_scheduler.rs index acebd0ab480..cb48e98ef22 100644 --- a/tokio/src/runtime/basic_scheduler.rs +++ b/tokio/src/runtime/basic_scheduler.rs @@ -370,12 +370,12 @@ impl Context { impl Spawner { /// Spawns a future onto the basic scheduler - pub(crate) fn spawn(&self, future: F) -> JoinHandle + pub(crate) fn spawn(&self, future: F, id: super::task::Id) -> JoinHandle where F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - let (handle, notified) = self.shared.owned.bind(future, self.shared.clone()); + let (handle, notified) = self.shared.owned.bind(future, self.shared.clone(), id); if let Some(notified) = notified { self.shared.schedule(notified); diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 180c3ab859e..9d4a35e5e48 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -175,9 +175,10 @@ impl Handle { F: Future + Send + 'static, F::Output: Send + 'static, { + let id = crate::runtime::task::Id::next(); #[cfg(all(tokio_unstable, feature = "tracing"))] - let future = crate::util::trace::task(future, "task", None); - self.spawner.spawn(future) + let future = crate::util::trace::task(future, "task", None, id.as_u64()); + self.spawner.spawn(future, id) } /// Runs the provided function on an executor dedicated to blocking. @@ -285,7 +286,8 @@ impl Handle { #[track_caller] pub fn block_on(&self, future: F) -> F::Output { #[cfg(all(tokio_unstable, feature = "tracing"))] - let future = crate::util::trace::task(future, "block_on", None); + let future = + crate::util::trace::task(future, "block_on", None, super::task::Id::next().as_u64()); // Enter the **runtime** context. This configures spawning, the current I/O driver, ... let _rt_enter = self.enter(); @@ -388,7 +390,7 @@ impl HandleInner { R: Send + 'static, { let fut = BlockingTask::new(func); - + let id = super::task::Id::next(); #[cfg(all(tokio_unstable, feature = "tracing"))] let fut = { use tracing::Instrument; @@ -398,6 +400,7 @@ impl HandleInner { "runtime.spawn", kind = %"blocking", task.name = %name.unwrap_or_default(), + task.id = id.as_u64(), "fn" = %std::any::type_name::(), spawn.location = %format_args!("{}:{}:{}", location.file(), location.line(), location.column()), ); @@ -407,7 +410,7 @@ impl HandleInner { #[cfg(not(all(tokio_unstable, feature = "tracing")))] let _ = name; - let (task, handle) = task::unowned(fut, NoopSchedule); + let (task, handle) = task::unowned(fut, NoopSchedule, id); let spawned = self .blocking_spawner .spawn(blocking::Task::new(task, is_mandatory), rt); diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index d7f54360236..bd428525d00 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -467,7 +467,7 @@ cfg_rt! { #[track_caller] pub fn block_on(&self, future: F) -> F::Output { #[cfg(all(tokio_unstable, feature = "tracing"))] - let future = crate::util::trace::task(future, "block_on", None); + let future = crate::util::trace::task(future, "block_on", None, task::Id::next().as_u64()); let _enter = self.enter(); diff --git a/tokio/src/runtime/spawner.rs b/tokio/src/runtime/spawner.rs index 1dba8e3cef5..fb4d7f91845 100644 --- a/tokio/src/runtime/spawner.rs +++ b/tokio/src/runtime/spawner.rs @@ -1,4 +1,5 @@ use crate::future::Future; +use crate::runtime::task::Id; use crate::runtime::{basic_scheduler, HandleInner}; use crate::task::JoinHandle; @@ -23,15 +24,15 @@ impl Spawner { } } - pub(crate) fn spawn(&self, future: F) -> JoinHandle + pub(crate) fn spawn(&self, future: F, id: Id) -> JoinHandle where F: Future + Send + 'static, F::Output: Send + 'static, { match self { - Spawner::Basic(spawner) => spawner.spawn(future), + Spawner::Basic(spawner) => spawner.spawn(future, id), #[cfg(feature = "rt-multi-thread")] - Spawner::ThreadPool(spawner) => spawner.spawn(future), + Spawner::ThreadPool(spawner) => spawner.spawn(future, id), } } diff --git a/tokio/src/runtime/task/abort.rs b/tokio/src/runtime/task/abort.rs index 6ed7ff1b7f2..cad639ca0c8 100644 --- a/tokio/src/runtime/task/abort.rs +++ b/tokio/src/runtime/task/abort.rs @@ -1,4 +1,4 @@ -use crate::runtime::task::RawTask; +use crate::runtime::task::{Id, RawTask}; use std::fmt; use std::panic::{RefUnwindSafe, UnwindSafe}; @@ -21,11 +21,12 @@ use std::panic::{RefUnwindSafe, UnwindSafe}; #[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] pub struct AbortHandle { raw: Option, + id: Id, } impl AbortHandle { - pub(super) fn new(raw: Option) -> Self { - Self { raw } + pub(super) fn new(raw: Option, id: Id) -> Self { + Self { raw, id } } /// Abort the task associated with the handle. @@ -47,6 +48,21 @@ impl AbortHandle { raw.remote_abort(); } } + + /// Returns a [task ID] that uniquely identifies this task relative to other + /// currently spawned tasks. + /// + /// **Note**: This is an [unstable API][unstable]. The public API of this type + /// may break in 1.x releases. See [the documentation on unstable + /// features][unstable] for details. + /// + /// [task ID]: crate::task::Id + /// [unstable]: crate#unstable-features + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + pub fn id(&self) -> super::Id { + self.id.clone() + } } unsafe impl Send for AbortHandle {} @@ -57,7 +73,9 @@ impl RefUnwindSafe for AbortHandle {} impl fmt::Debug for AbortHandle { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("AbortHandle").finish() + fmt.debug_struct("AbortHandle") + .field("id", &self.id) + .finish() } } diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs index 776e8341f37..548c56da3d4 100644 --- a/tokio/src/runtime/task/core.rs +++ b/tokio/src/runtime/task/core.rs @@ -13,7 +13,7 @@ use crate::future::Future; use crate::loom::cell::UnsafeCell; use crate::runtime::task::raw::{self, Vtable}; use crate::runtime::task::state::State; -use crate::runtime::task::Schedule; +use crate::runtime::task::{Id, Schedule}; use crate::util::linked_list; use std::pin::Pin; @@ -49,6 +49,9 @@ pub(super) struct Core { /// Either the future or the output. pub(super) stage: CoreStage, + + /// The task's ID, used for populating `JoinError`s. + pub(super) task_id: Id, } /// Crate public as this is also needed by the pool. @@ -102,7 +105,7 @@ pub(super) enum Stage { impl Cell { /// Allocates a new task cell, containing the header, trailer, and core /// structures. - pub(super) fn new(future: T, scheduler: S, state: State) -> Box> { + pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box> { #[cfg(all(tokio_unstable, feature = "tracing"))] let id = future.id(); Box::new(Cell { @@ -120,6 +123,7 @@ impl Cell { stage: CoreStage { stage: UnsafeCell::new(Stage::Running(future)), }, + task_id, }, trailer: Trailer { waker: UnsafeCell::new(None), diff --git a/tokio/src/runtime/task/error.rs b/tokio/src/runtime/task/error.rs index 1a8129b2b6f..22b688aa221 100644 --- a/tokio/src/runtime/task/error.rs +++ b/tokio/src/runtime/task/error.rs @@ -2,12 +2,13 @@ use std::any::Any; use std::fmt; use std::io; +use super::Id; use crate::util::SyncWrapper; - cfg_rt! { /// Task failed to execute to completion. pub struct JoinError { repr: Repr, + id: Id, } } @@ -17,15 +18,17 @@ enum Repr { } impl JoinError { - pub(crate) fn cancelled() -> JoinError { + pub(crate) fn cancelled(id: Id) -> JoinError { JoinError { repr: Repr::Cancelled, + id, } } - pub(crate) fn panic(err: Box) -> JoinError { + pub(crate) fn panic(id: Id, err: Box) -> JoinError { JoinError { repr: Repr::Panic(SyncWrapper::new(err)), + id, } } @@ -111,13 +114,28 @@ impl JoinError { _ => Err(self), } } + + /// Returns a [task ID] that identifies the task which errored relative to + /// other currently spawned tasks. + /// + /// **Note**: This is an [unstable API][unstable]. The public API of this type + /// may break in 1.x releases. See [the documentation on unstable + /// features][unstable] for details. + /// + /// [task ID]: crate::task::Id + /// [unstable]: crate#unstable-features + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + pub fn id(&self) -> Id { + self.id.clone() + } } impl fmt::Display for JoinError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.repr { - Repr::Cancelled => write!(fmt, "cancelled"), - Repr::Panic(_) => write!(fmt, "panic"), + Repr::Cancelled => write!(fmt, "task {} was cancelled", self.id), + Repr::Panic(_) => write!(fmt, "task {} panicked", self.id), } } } @@ -125,8 +143,8 @@ impl fmt::Display for JoinError { impl fmt::Debug for JoinError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.repr { - Repr::Cancelled => write!(fmt, "JoinError::Cancelled"), - Repr::Panic(_) => write!(fmt, "JoinError::Panic(...)"), + Repr::Cancelled => write!(fmt, "JoinError::Cancelled({:?})", self.id), + Repr::Panic(_) => write!(fmt, "JoinError::Panic({:?}, ...)", self.id), } } } diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 261dccea415..1d3ababfb17 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -100,7 +100,8 @@ where let header_ptr = self.header_ptr(); let waker_ref = waker_ref::(&header_ptr); let cx = Context::from_waker(&*waker_ref); - let res = poll_future(&self.core().stage, cx); + let core = self.core(); + let res = poll_future(&core.stage, core.task_id.clone(), cx); if res == Poll::Ready(()) { // The future completed. Move on to complete the task. @@ -114,14 +115,15 @@ where TransitionToIdle::Cancelled => { // The transition to idle failed because the task was // cancelled during the poll. - - cancel_task(&self.core().stage); + let core = self.core(); + cancel_task(&core.stage, core.task_id.clone()); PollFuture::Complete } } } TransitionToRunning::Cancelled => { - cancel_task(&self.core().stage); + let core = self.core(); + cancel_task(&core.stage, core.task_id.clone()); PollFuture::Complete } TransitionToRunning::Failed => PollFuture::Done, @@ -144,7 +146,8 @@ where // By transitioning the lifecycle to `Running`, we have permission to // drop the future. - cancel_task(&self.core().stage); + let core = self.core(); + cancel_task(&core.stage, core.task_id.clone()); self.complete(); } @@ -432,7 +435,7 @@ enum PollFuture { } /// Cancels the task and store the appropriate error in the stage field. -fn cancel_task(stage: &CoreStage) { +fn cancel_task(stage: &CoreStage, id: super::Id) { // Drop the future from a panic guard. let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { stage.drop_future_or_output(); @@ -440,17 +443,17 @@ fn cancel_task(stage: &CoreStage) { match res { Ok(()) => { - stage.store_output(Err(JoinError::cancelled())); + stage.store_output(Err(JoinError::cancelled(id))); } Err(panic) => { - stage.store_output(Err(JoinError::panic(panic))); + stage.store_output(Err(JoinError::panic(id, panic))); } } } /// Polls the future. If the future completes, the output is written to the /// stage field. -fn poll_future(core: &CoreStage, cx: Context<'_>) -> Poll<()> { +fn poll_future(core: &CoreStage, id: super::Id, cx: Context<'_>) -> Poll<()> { // Poll the future. let output = panic::catch_unwind(panic::AssertUnwindSafe(|| { struct Guard<'a, T: Future> { @@ -473,7 +476,7 @@ fn poll_future(core: &CoreStage, cx: Context<'_>) -> Poll<()> { let output = match output { Ok(Poll::Pending) => return Poll::Pending, Ok(Poll::Ready(output)) => Ok(output), - Err(panic) => Err(JoinError::panic(panic)), + Err(panic) => Err(JoinError::panic(id, panic)), }; // Catch and ignore panics if the future panics on drop. diff --git a/tokio/src/runtime/task/join.rs b/tokio/src/runtime/task/join.rs index b7846348e34..86580c84b59 100644 --- a/tokio/src/runtime/task/join.rs +++ b/tokio/src/runtime/task/join.rs @@ -1,4 +1,4 @@ -use crate::runtime::task::RawTask; +use crate::runtime::task::{Id, RawTask}; use std::fmt; use std::future::Future; @@ -144,6 +144,7 @@ cfg_rt! { /// [`JoinError`]: crate::task::JoinError pub struct JoinHandle { raw: Option, + id: Id, _p: PhantomData, } } @@ -155,9 +156,10 @@ impl UnwindSafe for JoinHandle {} impl RefUnwindSafe for JoinHandle {} impl JoinHandle { - pub(super) fn new(raw: RawTask) -> JoinHandle { + pub(super) fn new(raw: RawTask, id: Id) -> JoinHandle { JoinHandle { raw: Some(raw), + id, _p: PhantomData, } } @@ -218,7 +220,22 @@ impl JoinHandle { raw.ref_inc(); raw }); - super::AbortHandle::new(raw) + super::AbortHandle::new(raw, self.id.clone()) + } + + /// Returns a [task ID] that uniquely identifies this task relative to other + /// currently spawned tasks. + /// + /// **Note**: This is an [unstable API][unstable]. The public API of this type + /// may break in 1.x releases. See [the documentation on unstable + /// features][unstable] for details. + /// + /// [task ID]: crate::task::Id + /// [unstable]: crate#unstable-features + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + pub fn id(&self) -> super::Id { + self.id.clone() } } @@ -280,6 +297,8 @@ where T: fmt::Debug, { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("JoinHandle").finish() + fmt.debug_struct("JoinHandle") + .field("id", &self.id) + .finish() } } diff --git a/tokio/src/runtime/task/list.rs b/tokio/src/runtime/task/list.rs index 7758f8db7aa..7a1dff0bbfc 100644 --- a/tokio/src/runtime/task/list.rs +++ b/tokio/src/runtime/task/list.rs @@ -84,13 +84,14 @@ impl OwnedTasks { &self, task: T, scheduler: S, + id: super::Id, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + Send + 'static, T::Output: Send + 'static, { - let (task, notified, join) = super::new_task(task, scheduler); + let (task, notified, join) = super::new_task(task, scheduler, id); unsafe { // safety: We just created the task, so we have exclusive access @@ -187,13 +188,14 @@ impl LocalOwnedTasks { &self, task: T, scheduler: S, + id: super::Id, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + 'static, T::Output: 'static, { - let (task, notified, join) = super::new_task(task, scheduler); + let (task, notified, join) = super::new_task(task, scheduler, id); unsafe { // safety: We just created the task, so we have exclusive access diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs index c7a85c5a173..316b4a49675 100644 --- a/tokio/src/runtime/task/mod.rs +++ b/tokio/src/runtime/task/mod.rs @@ -184,6 +184,27 @@ use std::marker::PhantomData; use std::ptr::NonNull; use std::{fmt, mem}; +/// An opaque ID that uniquely identifies a task relative to all other currently +/// running tasks. +/// +/// # Notes +/// +/// - Task IDs are unique relative to other *currently running* tasks. When a +/// task completes, the same ID may be used for another task. +/// - Task IDs are *not* sequential, and do not indicate the order in which +/// tasks are spawned, what runtime a task is spawned on, or any other data. +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// [unstable]: crate#unstable-features +#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +// TODO(eliza): there's almost certainly no reason not to make this `Copy` as well... +#[derive(Clone, Debug, Hash, Eq, PartialEq)] +pub struct Id(u64); + /// An owned handle to the task, tracked by ref count. #[repr(transparent)] pub(crate) struct Task { @@ -250,14 +271,15 @@ cfg_rt! { /// notification. fn new_task( task: T, - scheduler: S + scheduler: S, + id: Id, ) -> (Task, Notified, JoinHandle) where S: Schedule, T: Future + 'static, T::Output: 'static, { - let raw = RawTask::new::(task, scheduler); + let raw = RawTask::new::(task, scheduler, id.clone()); let task = Task { raw, _p: PhantomData, @@ -266,7 +288,7 @@ cfg_rt! { raw, _p: PhantomData, }); - let join = JoinHandle::new(raw); + let join = JoinHandle::new(raw, id); (task, notified, join) } @@ -275,13 +297,13 @@ cfg_rt! { /// only when the task is not going to be stored in an `OwnedTasks` list. /// /// Currently only blocking tasks use this method. - pub(crate) fn unowned(task: T, scheduler: S) -> (UnownedTask, JoinHandle) + pub(crate) fn unowned(task: T, scheduler: S, id: Id) -> (UnownedTask, JoinHandle) where S: Schedule, T: Send + Future + 'static, T::Output: Send + 'static, { - let (task, notified, join) = new_task(task, scheduler); + let (task, notified, join) = new_task(task, scheduler, id); // This transfers the ref-count of task and notified into an UnownedTask. // This is valid because an UnownedTask holds two ref-counts. @@ -450,3 +472,46 @@ unsafe impl linked_list::Link for Task { NonNull::from(target.as_ref().owned.with_mut(|ptr| &mut *ptr)) } } + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl Id { + // When 64-bit atomics are available, use a static `AtomicU64` counter to + // generate task IDs. + // + // Note(eliza): we _could_ just use `crate::loom::AtomicU64`, which switches + // between an atomic and mutex-based implementation here, rather than having + // two separate functions for targets with and without 64-bit atomics. + // However, because we can't use the mutex-based implementation in a static + // initializer directly, the 32-bit impl also has to use a `OnceCell`, and I + // thought it was nicer to avoid the `OnceCell` overhead on 64-bit + // platforms... + cfg_has_atomic_u64! { + pub(crate) fn next() -> Self { + use std::sync::atomic::{AtomicU64, Ordering::Relaxed}; + static NEXT_ID: AtomicU64 = AtomicU64::new(1); + Self(NEXT_ID.fetch_add(1, Relaxed)) + } + } + + cfg_not_has_atomic_u64! { + pub(crate) fn next() -> Self { + use once_cell::sync::Lazy; + use crate::loom::sync::Mutex; + + static NEXT_ID: Lazy> = Lazy::new(|| Mutex::new(1)); + let mut lock = NEXT_ID.lock(); + let id = *lock; + *lock += 1; + Self(id) + } + } + + pub(crate) fn as_u64(&self) -> u64 { + self.0 + } +} diff --git a/tokio/src/runtime/task/raw.rs b/tokio/src/runtime/task/raw.rs index 95569f947e5..5555298a4d4 100644 --- a/tokio/src/runtime/task/raw.rs +++ b/tokio/src/runtime/task/raw.rs @@ -1,5 +1,5 @@ use crate::future::Future; -use crate::runtime::task::{Cell, Harness, Header, Schedule, State}; +use crate::runtime::task::{Cell, Harness, Header, Id, Schedule, State}; use std::ptr::NonNull; use std::task::{Poll, Waker}; @@ -52,12 +52,12 @@ pub(super) fn vtable() -> &'static Vtable { } impl RawTask { - pub(super) fn new(task: T, scheduler: S) -> RawTask + pub(super) fn new(task: T, scheduler: S, id: Id) -> RawTask where T: Future, S: Schedule, { - let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new())); + let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new(), id)); let ptr = unsafe { NonNull::new_unchecked(ptr as *mut Header) }; RawTask { ptr } diff --git a/tokio/src/runtime/tests/mod.rs b/tokio/src/runtime/tests/mod.rs index 4b49698a86a..08724d43ee4 100644 --- a/tokio/src/runtime/tests/mod.rs +++ b/tokio/src/runtime/tests/mod.rs @@ -2,7 +2,7 @@ use self::unowned_wrapper::unowned; mod unowned_wrapper { use crate::runtime::blocking::NoopSchedule; - use crate::runtime::task::{JoinHandle, Notified}; + use crate::runtime::task::{Id, JoinHandle, Notified}; #[cfg(all(tokio_unstable, feature = "tracing"))] pub(crate) fn unowned(task: T) -> (Notified, JoinHandle) @@ -13,7 +13,7 @@ mod unowned_wrapper { use tracing::Instrument; let span = tracing::trace_span!("test_span"); let task = task.instrument(span); - let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next()); (task.into_notified(), handle) } @@ -23,7 +23,7 @@ mod unowned_wrapper { T: std::future::Future + Send + 'static, T::Output: Send + 'static, { - let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next()); (task.into_notified(), handle) } } diff --git a/tokio/src/runtime/tests/task.rs b/tokio/src/runtime/tests/task.rs index 622f5661784..173e5b0b23f 100644 --- a/tokio/src/runtime/tests/task.rs +++ b/tokio/src/runtime/tests/task.rs @@ -1,5 +1,5 @@ use crate::runtime::blocking::NoopSchedule; -use crate::runtime::task::{self, unowned, JoinHandle, OwnedTasks, Schedule, Task}; +use crate::runtime::task::{self, unowned, Id, JoinHandle, OwnedTasks, Schedule, Task}; use crate::util::TryLock; use std::collections::VecDeque; @@ -55,6 +55,7 @@ fn create_drop1() { unreachable!() }, NoopSchedule, + Id::next(), ); drop(notified); handle.assert_not_dropped(); @@ -71,6 +72,7 @@ fn create_drop2() { unreachable!() }, NoopSchedule, + Id::next(), ); drop(join); handle.assert_not_dropped(); @@ -87,6 +89,7 @@ fn drop_abort_handle1() { unreachable!() }, NoopSchedule, + Id::next(), ); let abort = join.abort_handle(); drop(join); @@ -106,6 +109,7 @@ fn drop_abort_handle2() { unreachable!() }, NoopSchedule, + Id::next(), ); let abort = join.abort_handle(); drop(notified); @@ -126,6 +130,7 @@ fn create_shutdown1() { unreachable!() }, NoopSchedule, + Id::next(), ); drop(join); handle.assert_not_dropped(); @@ -142,6 +147,7 @@ fn create_shutdown2() { unreachable!() }, NoopSchedule, + Id::next(), ); handle.assert_not_dropped(); notified.shutdown(); @@ -151,7 +157,7 @@ fn create_shutdown2() { #[test] fn unowned_poll() { - let (task, _) = unowned(async {}, NoopSchedule); + let (task, _) = unowned(async {}, NoopSchedule, Id::next()); task.run(); } @@ -266,7 +272,7 @@ impl Runtime { T: 'static + Send + Future, T::Output: 'static + Send, { - let (handle, notified) = self.0.owned.bind(future, self.clone()); + let (handle, notified) = self.0.owned.bind(future, self.clone(), Id::next()); if let Some(notified) = notified { self.schedule(notified); diff --git a/tokio/src/runtime/thread_pool/mod.rs b/tokio/src/runtime/thread_pool/mod.rs index 76346c686e7..ef6b5775ca2 100644 --- a/tokio/src/runtime/thread_pool/mod.rs +++ b/tokio/src/runtime/thread_pool/mod.rs @@ -14,7 +14,7 @@ pub(crate) use worker::Launch; pub(crate) use worker::block_in_place; use crate::loom::sync::Arc; -use crate::runtime::task::JoinHandle; +use crate::runtime::task::{self, JoinHandle}; use crate::runtime::{Callback, Driver, HandleInner}; use std::fmt; @@ -98,12 +98,12 @@ impl Drop for ThreadPool { impl Spawner { /// Spawns a future onto the thread pool - pub(crate) fn spawn(&self, future: F) -> JoinHandle + pub(crate) fn spawn(&self, future: F, id: task::Id) -> JoinHandle where F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - worker::Shared::bind_new_task(&self.shared, future) + worker::Shared::bind_new_task(&self.shared, future, id) } pub(crate) fn shutdown(&mut self) { diff --git a/tokio/src/runtime/thread_pool/worker.rs b/tokio/src/runtime/thread_pool/worker.rs index 9b456570d6e..3d58767f308 100644 --- a/tokio/src/runtime/thread_pool/worker.rs +++ b/tokio/src/runtime/thread_pool/worker.rs @@ -723,12 +723,16 @@ impl Shared { &self.handle_inner } - pub(super) fn bind_new_task(me: &Arc, future: T) -> JoinHandle + pub(super) fn bind_new_task( + me: &Arc, + future: T, + id: crate::runtime::task::Id, + ) -> JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, { - let (handle, notified) = me.owned.bind(future, me.clone()); + let (handle, notified) = me.owned.bind(future, me.clone(), id); if let Some(notified) = notified { me.schedule(notified, false); diff --git a/tokio/src/task/join_set.rs b/tokio/src/task/join_set.rs index 996ae1a9219..d036fcc3cda 100644 --- a/tokio/src/task/join_set.rs +++ b/tokio/src/task/join_set.rs @@ -4,7 +4,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; use crate::runtime::Handle; -use crate::task::{AbortHandle, JoinError, JoinHandle, LocalSet}; +use crate::task::{AbortHandle, Id, JoinError, JoinHandle, LocalSet}; use crate::util::IdleNotifiedSet; /// A collection of tasks spawned on a Tokio runtime. @@ -155,6 +155,24 @@ impl JoinSet { /// statement and some other branch completes first, it is guaranteed that no tasks were /// removed from this `JoinSet`. pub async fn join_one(&mut self) -> Result, JoinError> { + crate::future::poll_fn(|cx| self.poll_join_one(cx)) + .await + .map(|opt| opt.map(|(_, res)| res)) + } + + /// Waits until one of the tasks in the set completes and returns its + /// output, along with the [task ID] of the completed task. + /// + /// Returns `None` if the set is empty. + /// + /// # Cancel Safety + /// + /// This method is cancel safe. If `join_one_with_id` is used as the event in a `tokio::select!` + /// statement and some other branch completes first, it is guaranteed that no tasks were + /// removed from this `JoinSet`. + /// + /// [task ID]: crate::task::Id + pub async fn join_one_with_id(&mut self) -> Result, JoinError> { crate::future::poll_fn(|cx| self.poll_join_one(cx)).await } @@ -191,8 +209,8 @@ impl JoinSet { /// Polls for one of the tasks in the set to complete. /// - /// If this returns `Poll::Ready(Ok(Some(_)))` or `Poll::Ready(Err(_))`, then the task that - /// completed is removed from the set. + /// If this returns `Poll::Ready(Some((_, Ok(_))))` or `Poll::Ready(Some((_, + /// Err(_)))`, then the task that completed is removed from the set. /// /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled /// to receive a wakeup when a task in the `JoinSet` completes. Note that on multiple calls to @@ -205,17 +223,19 @@ impl JoinSet { /// /// * `Poll::Pending` if the `JoinSet` is not empty but there is no task whose output is /// available right now. - /// * `Poll::Ready(Ok(Some(value)))` if one of the tasks in this `JoinSet` has completed. The - /// `value` is the return value of one of the tasks that completed. + /// * `Poll::Ready(Ok(Some((id, value)))` if one of the tasks in this `JoinSet` has completed. The + /// `value` is the return value of one of the tasks that completed, while + /// `id` is the [task ID] of that task. /// * `Poll::Ready(Err(err))` if one of the tasks in this `JoinSet` has panicked or been - /// aborted. + /// aborted. The `err` is the `JoinError` from the panicked/aborted task. /// * `Poll::Ready(Ok(None))` if the `JoinSet` is empty. /// /// Note that this method may return `Poll::Pending` even if one of the tasks has completed. /// This can happen if the [coop budget] is reached. /// /// [coop budget]: crate::task#cooperative-scheduling - fn poll_join_one(&mut self, cx: &mut Context<'_>) -> Poll, JoinError>> { + /// [task ID]: crate::task::Id + fn poll_join_one(&mut self, cx: &mut Context<'_>) -> Poll, JoinError>> { // The call to `pop_notified` moves the entry to the `idle` list. It is moved back to // the `notified` list if the waker is notified in the `poll` call below. let mut entry = match self.inner.pop_notified(cx.waker()) { @@ -233,7 +253,10 @@ impl JoinSet { let res = entry.with_value_and_context(|jh, ctx| Pin::new(jh).poll(ctx)); if let Poll::Ready(res) = res { - entry.remove(); + let entry = entry.remove(); + // If the task succeeded, add the task ID to the output. Otherwise, the + // `JoinError` will already have the task's ID. + let res = res.map(|output| (entry.id(), output)); Poll::Ready(Some(res).transpose()) } else { // A JoinHandle generally won't emit a wakeup without being ready unless diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 2dbd9706047..32e376872f4 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -301,12 +301,13 @@ cfg_rt! { where F: Future + 'static, F::Output: 'static { - let future = crate::util::trace::task(future, "local", name); + let id = crate::runtime::task::Id::next(); + let future = crate::util::trace::task(future, "local", name, id.as_u64()); CURRENT.with(|maybe_cx| { let cx = maybe_cx .expect("`spawn_local` called from outside of a `task::LocalSet`"); - let (handle, notified) = cx.owned.bind(future, cx.shared.clone()); + let (handle, notified) = cx.owned.bind(future, cx.shared.clone(), id); if let Some(notified) = notified { cx.shared.schedule(notified); @@ -385,9 +386,13 @@ impl LocalSet { F: Future + 'static, F::Output: 'static, { - let future = crate::util::trace::task(future, "local", None); + let id = crate::runtime::task::Id::next(); + let future = crate::util::trace::task(future, "local", None, id.as_u64()); - let (handle, notified) = self.context.owned.bind(future, self.context.shared.clone()); + let (handle, notified) = self + .context + .owned + .bind(future, self.context.shared.clone(), id); if let Some(notified) = notified { self.context.shared.schedule(notified); diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index bf5530b8339..cebc269bb40 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -303,7 +303,7 @@ cfg_rt! { cfg_unstable! { mod join_set; pub use join_set::JoinSet; - pub use crate::runtime::task::AbortHandle; + pub use crate::runtime::task::{Id, AbortHandle}; } cfg_trace! { diff --git a/tokio/src/task/spawn.rs b/tokio/src/task/spawn.rs index a9d736674c0..5a60f9d66e6 100644 --- a/tokio/src/task/spawn.rs +++ b/tokio/src/task/spawn.rs @@ -142,8 +142,10 @@ cfg_rt! { T: Future + Send + 'static, T::Output: Send + 'static, { - let spawn_handle = crate::runtime::context::spawn_handle().expect(CONTEXT_MISSING_ERROR); - let task = crate::util::trace::task(future, "task", name); - spawn_handle.spawn(task) + use crate::runtime::{task, context}; + let id = task::Id::next(); + let spawn_handle = context::spawn_handle().expect(CONTEXT_MISSING_ERROR); + let task = crate::util::trace::task(future, "task", name, id.as_u64()); + spawn_handle.spawn(task, id) } } diff --git a/tokio/src/util/trace.rs b/tokio/src/util/trace.rs index 6080e2358ae..76e8a6cbf55 100644 --- a/tokio/src/util/trace.rs +++ b/tokio/src/util/trace.rs @@ -10,7 +10,7 @@ cfg_trace! { #[inline] #[track_caller] - pub(crate) fn task(task: F, kind: &'static str, name: Option<&str>) -> Instrumented { + pub(crate) fn task(task: F, kind: &'static str, name: Option<&str>, id: u64) -> Instrumented { use tracing::instrument::Instrument; let location = std::panic::Location::caller(); let span = tracing::trace_span!( @@ -18,6 +18,7 @@ cfg_trace! { "runtime.spawn", %kind, task.name = %name.unwrap_or_default(), + task.id = id, loc.file = location.file(), loc.line = location.line(), loc.col = location.column(), @@ -91,7 +92,7 @@ cfg_time! { cfg_not_trace! { cfg_rt! { #[inline] - pub(crate) fn task(task: F, _: &'static str, _name: Option<&str>) -> F { + pub(crate) fn task(task: F, _: &'static str, _name: Option<&str>, _: u64) -> F { // nop task }