Skip to content

Commit

Permalink
task: add task IDs (#4630)
Browse files Browse the repository at this point in the history
## Motivation

PR #4538 adds a prototype implementation of a `JoinMap` API in
`tokio::task`. In [this comment][1] on that PR, @carllerche pointed out
that a much simpler `JoinMap` type could be implemented outside of
`tokio` (either in `tokio-util` or in user code) if we just modified
`JoinSet` to return a task ID type when spawning new tasks, and when
tasks complete. This seems like a better approach for the following
reasons:

* A `JoinMap`-like type need not become a permanent part of `tokio`'s
  stable API
* Task IDs seem like something that could be generally useful outside of
  a `JoinMap` implementation

## Solution

This branch adds a `tokio::task::Id` type that uniquely identifies a
task relative to all other spawned tasks. Task IDs are assigned
sequentially based on an atomic `usize` counter of spawned tasks.

In addition, I modified `JoinSet` to add a `join_with_id` method that
behaves identically to `join_one` but also returns an ID. This can be
used to implement a `JoinMap` type.

Note that because `join_with_id` must return a task ID regardless of
whether the task completes successfully or returns a `JoinError`, I've
also changed `JoinError` to carry the ID of the task that errored, and 
added a `JoinError::id` method for accessing it. Alternatively, we could
have done one of the following:

* have `join_with_id` return `Option<(Id, Result<T, JoinError>)>`, which
  would be inconsistent with the return type of `join_one` (which we've
  [already bikeshedded over once][2]...)
* have `join_with_id` return `Result<Option<(Id, T)>, (Id, JoinError)>`,
  which just feels gross.

I thought adding the task ID to `JoinError` was the nicest option, and
is potentially useful for other stuff as well, so it's probably a good API to
have anyway.

[1]: #4538 (comment)
[2]: #4335 (comment)

Closes #4538

Signed-off-by: Eliza Weisman <eliza@buoyant.io>
  • Loading branch information
hawkw committed Apr 25, 2022
1 parent b4d82c3 commit 1d3f123
Show file tree
Hide file tree
Showing 22 changed files with 252 additions and 78 deletions.
2 changes: 1 addition & 1 deletion tokio/Cargo.toml
Expand Up @@ -65,7 +65,7 @@ process = [
"winapi/threadpoollegacyapiset",
]
# Includes basic task execution capabilities
rt = []
rt = ["once_cell"]
rt-multi-thread = [
"num_cpus",
"rt",
Expand Down
4 changes: 2 additions & 2 deletions tokio/src/runtime/basic_scheduler.rs
Expand Up @@ -370,12 +370,12 @@ impl Context {

impl Spawner {
/// Spawns a future onto the basic scheduler
pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
pub(crate) fn spawn<F>(&self, future: F, id: super::task::Id) -> JoinHandle<F::Output>
where
F: crate::future::Future + Send + 'static,
F::Output: Send + 'static,
{
let (handle, notified) = self.shared.owned.bind(future, self.shared.clone());
let (handle, notified) = self.shared.owned.bind(future, self.shared.clone(), id);

if let Some(notified) = notified {
self.shared.schedule(notified);
Expand Down
13 changes: 8 additions & 5 deletions tokio/src/runtime/handle.rs
Expand Up @@ -175,9 +175,10 @@ impl Handle {
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let id = crate::runtime::task::Id::next();
#[cfg(all(tokio_unstable, feature = "tracing"))]
let future = crate::util::trace::task(future, "task", None);
self.spawner.spawn(future)
let future = crate::util::trace::task(future, "task", None, id.as_u64());
self.spawner.spawn(future, id)
}

/// Runs the provided function on an executor dedicated to blocking.
Expand Down Expand Up @@ -285,7 +286,8 @@ impl Handle {
#[track_caller]
pub fn block_on<F: Future>(&self, future: F) -> F::Output {
#[cfg(all(tokio_unstable, feature = "tracing"))]
let future = crate::util::trace::task(future, "block_on", None);
let future =
crate::util::trace::task(future, "block_on", None, super::task::Id::next().as_u64());

// Enter the **runtime** context. This configures spawning, the current I/O driver, ...
let _rt_enter = self.enter();
Expand Down Expand Up @@ -388,7 +390,7 @@ impl HandleInner {
R: Send + 'static,
{
let fut = BlockingTask::new(func);

let id = super::task::Id::next();
#[cfg(all(tokio_unstable, feature = "tracing"))]
let fut = {
use tracing::Instrument;
Expand All @@ -398,6 +400,7 @@ impl HandleInner {
"runtime.spawn",
kind = %"blocking",
task.name = %name.unwrap_or_default(),
task.id = id.as_u64(),
"fn" = %std::any::type_name::<F>(),
spawn.location = %format_args!("{}:{}:{}", location.file(), location.line(), location.column()),
);
Expand All @@ -407,7 +410,7 @@ impl HandleInner {
#[cfg(not(all(tokio_unstable, feature = "tracing")))]
let _ = name;

let (task, handle) = task::unowned(fut, NoopSchedule);
let (task, handle) = task::unowned(fut, NoopSchedule, id);
let spawned = self
.blocking_spawner
.spawn(blocking::Task::new(task, is_mandatory), rt);
Expand Down
2 changes: 1 addition & 1 deletion tokio/src/runtime/mod.rs
Expand Up @@ -467,7 +467,7 @@ cfg_rt! {
#[track_caller]
pub fn block_on<F: Future>(&self, future: F) -> F::Output {
#[cfg(all(tokio_unstable, feature = "tracing"))]
let future = crate::util::trace::task(future, "block_on", None);
let future = crate::util::trace::task(future, "block_on", None, task::Id::next().as_u64());

let _enter = self.enter();

Expand Down
7 changes: 4 additions & 3 deletions tokio/src/runtime/spawner.rs
@@ -1,4 +1,5 @@
use crate::future::Future;
use crate::runtime::task::Id;
use crate::runtime::{basic_scheduler, HandleInner};
use crate::task::JoinHandle;

Expand All @@ -23,15 +24,15 @@ impl Spawner {
}
}

pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
pub(crate) fn spawn<F>(&self, future: F, id: Id) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
match self {
Spawner::Basic(spawner) => spawner.spawn(future),
Spawner::Basic(spawner) => spawner.spawn(future, id),
#[cfg(feature = "rt-multi-thread")]
Spawner::ThreadPool(spawner) => spawner.spawn(future),
Spawner::ThreadPool(spawner) => spawner.spawn(future, id),
}
}

Expand Down
26 changes: 22 additions & 4 deletions tokio/src/runtime/task/abort.rs
@@ -1,4 +1,4 @@
use crate::runtime::task::RawTask;
use crate::runtime::task::{Id, RawTask};
use std::fmt;
use std::panic::{RefUnwindSafe, UnwindSafe};

Expand All @@ -21,11 +21,12 @@ use std::panic::{RefUnwindSafe, UnwindSafe};
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
pub struct AbortHandle {
raw: Option<RawTask>,
id: Id,
}

impl AbortHandle {
pub(super) fn new(raw: Option<RawTask>) -> Self {
Self { raw }
pub(super) fn new(raw: Option<RawTask>, id: Id) -> Self {
Self { raw, id }
}

/// Abort the task associated with the handle.
Expand All @@ -47,6 +48,21 @@ impl AbortHandle {
raw.remote_abort();
}
}

/// Returns a [task ID] that uniquely identifies this task relative to other
/// currently spawned tasks.
///
/// The returned value is a clone of the ID this handle was constructed
/// with; it is read from the handle itself rather than from the raw task.
///
/// **Note**: This is an [unstable API][unstable]. The public API of this type
/// may break in 1.x releases. See [the documentation on unstable
/// features][unstable] for details.
///
/// [task ID]: crate::task::Id
/// [unstable]: crate#unstable-features
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub fn id(&self) -> super::Id {
self.id.clone()
}
}

unsafe impl Send for AbortHandle {}
Expand All @@ -57,7 +73,9 @@ impl RefUnwindSafe for AbortHandle {}

impl fmt::Debug for AbortHandle {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("AbortHandle").finish()
fmt.debug_struct("AbortHandle")
.field("id", &self.id)
.finish()
}
}

Expand Down
8 changes: 6 additions & 2 deletions tokio/src/runtime/task/core.rs
Expand Up @@ -13,7 +13,7 @@ use crate::future::Future;
use crate::loom::cell::UnsafeCell;
use crate::runtime::task::raw::{self, Vtable};
use crate::runtime::task::state::State;
use crate::runtime::task::Schedule;
use crate::runtime::task::{Id, Schedule};
use crate::util::linked_list;

use std::pin::Pin;
Expand Down Expand Up @@ -49,6 +49,9 @@ pub(super) struct Core<T: Future, S> {

/// Either the future or the output.
pub(super) stage: CoreStage<T>,

/// The task's ID, used for populating `JoinError`s.
pub(super) task_id: Id,
}

/// Crate public as this is also needed by the pool.
Expand Down Expand Up @@ -102,7 +105,7 @@ pub(super) enum Stage<T: Future> {
impl<T: Future, S: Schedule> Cell<T, S> {
/// Allocates a new task cell, containing the header, trailer, and core
/// structures.
pub(super) fn new(future: T, scheduler: S, state: State) -> Box<Cell<T, S>> {
pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box<Cell<T, S>> {
#[cfg(all(tokio_unstable, feature = "tracing"))]
let id = future.id();
Box::new(Cell {
Expand All @@ -120,6 +123,7 @@ impl<T: Future, S: Schedule> Cell<T, S> {
stage: CoreStage {
stage: UnsafeCell::new(Stage::Running(future)),
},
task_id,
},
trailer: Trailer {
waker: UnsafeCell::new(None),
Expand Down
32 changes: 25 additions & 7 deletions tokio/src/runtime/task/error.rs
Expand Up @@ -2,12 +2,13 @@ use std::any::Any;
use std::fmt;
use std::io;

use super::Id;
use crate::util::SyncWrapper;

cfg_rt! {
/// Task failed to execute to completion.
pub struct JoinError {
repr: Repr,
id: Id,
}
}

Expand All @@ -17,15 +18,17 @@ enum Repr {
}

impl JoinError {
pub(crate) fn cancelled() -> JoinError {
pub(crate) fn cancelled(id: Id) -> JoinError {
JoinError {
repr: Repr::Cancelled,
id,
}
}

pub(crate) fn panic(err: Box<dyn Any + Send + 'static>) -> JoinError {
pub(crate) fn panic(id: Id, err: Box<dyn Any + Send + 'static>) -> JoinError {
JoinError {
repr: Repr::Panic(SyncWrapper::new(err)),
id,
}
}

Expand Down Expand Up @@ -111,22 +114,37 @@ impl JoinError {
_ => Err(self),
}
}

/// Returns a [task ID] that identifies the task which errored relative to
/// other currently spawned tasks.
///
/// The ID is the one supplied when this error was constructed, so it is
/// available for both cancelled and panicked tasks.
///
/// **Note**: This is an [unstable API][unstable]. The public API of this type
/// may break in 1.x releases. See [the documentation on unstable
/// features][unstable] for details.
///
/// [task ID]: crate::task::Id
/// [unstable]: crate#unstable-features
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub fn id(&self) -> Id {
self.id.clone()
}
}

impl fmt::Display for JoinError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.repr {
Repr::Cancelled => write!(fmt, "cancelled"),
Repr::Panic(_) => write!(fmt, "panic"),
Repr::Cancelled => write!(fmt, "task {} was cancelled", self.id),
Repr::Panic(_) => write!(fmt, "task {} panicked", self.id),
}
}
}

impl fmt::Debug for JoinError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.repr {
Repr::Cancelled => write!(fmt, "JoinError::Cancelled"),
Repr::Panic(_) => write!(fmt, "JoinError::Panic(...)"),
Repr::Cancelled => write!(fmt, "JoinError::Cancelled({:?})", self.id),
Repr::Panic(_) => write!(fmt, "JoinError::Panic({:?}, ...)", self.id),
}
}
}
Expand Down
23 changes: 13 additions & 10 deletions tokio/src/runtime/task/harness.rs
Expand Up @@ -100,7 +100,8 @@ where
let header_ptr = self.header_ptr();
let waker_ref = waker_ref::<T, S>(&header_ptr);
let cx = Context::from_waker(&*waker_ref);
let res = poll_future(&self.core().stage, cx);
let core = self.core();
let res = poll_future(&core.stage, core.task_id.clone(), cx);

if res == Poll::Ready(()) {
// The future completed. Move on to complete the task.
Expand All @@ -114,14 +115,15 @@ where
TransitionToIdle::Cancelled => {
// The transition to idle failed because the task was
// cancelled during the poll.

cancel_task(&self.core().stage);
let core = self.core();
cancel_task(&core.stage, core.task_id.clone());
PollFuture::Complete
}
}
}
TransitionToRunning::Cancelled => {
cancel_task(&self.core().stage);
let core = self.core();
cancel_task(&core.stage, core.task_id.clone());
PollFuture::Complete
}
TransitionToRunning::Failed => PollFuture::Done,
Expand All @@ -144,7 +146,8 @@ where

// By transitioning the lifecycle to `Running`, we have permission to
// drop the future.
cancel_task(&self.core().stage);
let core = self.core();
cancel_task(&core.stage, core.task_id.clone());
self.complete();
}

Expand Down Expand Up @@ -432,25 +435,25 @@ enum PollFuture {
}

/// Cancels the task and stores the appropriate error in the stage field.
fn cancel_task<T: Future>(stage: &CoreStage<T>) {
fn cancel_task<T: Future>(stage: &CoreStage<T>, id: super::Id) {
// Drop the future from a panic guard.
let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
stage.drop_future_or_output();
}));

match res {
Ok(()) => {
stage.store_output(Err(JoinError::cancelled()));
stage.store_output(Err(JoinError::cancelled(id)));
}
Err(panic) => {
stage.store_output(Err(JoinError::panic(panic)));
stage.store_output(Err(JoinError::panic(id, panic)));
}
}
}

/// Polls the future. If the future completes, the output is written to the
/// stage field.
fn poll_future<T: Future>(core: &CoreStage<T>, cx: Context<'_>) -> Poll<()> {
fn poll_future<T: Future>(core: &CoreStage<T>, id: super::Id, cx: Context<'_>) -> Poll<()> {
// Poll the future.
let output = panic::catch_unwind(panic::AssertUnwindSafe(|| {
struct Guard<'a, T: Future> {
Expand All @@ -473,7 +476,7 @@ fn poll_future<T: Future>(core: &CoreStage<T>, cx: Context<'_>) -> Poll<()> {
let output = match output {
Ok(Poll::Pending) => return Poll::Pending,
Ok(Poll::Ready(output)) => Ok(output),
Err(panic) => Err(JoinError::panic(panic)),
Err(panic) => Err(JoinError::panic(id, panic)),
};

// Catch and ignore panics if the future panics on drop.
Expand Down

0 comments on commit 1d3f123

Please sign in to comment.