Skip to content

Commit

Permalink
rt: add a method to retrieve task id
Browse files Browse the repository at this point in the history
  • Loading branch information
agayev committed Nov 6, 2022
1 parent f464360 commit be9c283
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 3 deletions.
15 changes: 15 additions & 0 deletions tokio/src/runtime/context.rs
@@ -1,4 +1,5 @@
use crate::runtime::coop;
use crate::runtime::task::Id;

use std::cell::Cell;

Expand All @@ -17,6 +18,7 @@ struct Context {
/// Handle to the runtime scheduler running on the current thread.
#[cfg(feature = "rt")]
scheduler: RefCell<Option<scheduler::Handle>>,
current_task_id: Cell<Option<Id>>,

#[cfg(any(feature = "rt", feature = "macros"))]
rng: FastRand,
Expand All @@ -31,6 +33,7 @@ tokio_thread_local! {
Context {
#[cfg(feature = "rt")]
scheduler: RefCell::new(None),
current_task_id: Cell::new(None),

#[cfg(any(feature = "rt", feature = "macros"))]
rng: FastRand::new(RngSeed::new()),
Expand Down Expand Up @@ -85,6 +88,18 @@ cfg_rt! {

pub(crate) struct DisallowBlockInPlaceGuard(bool);

pub(crate) fn set_current_task_id(id: Option<Id>) {
CONTEXT.with(|ctx| ctx.current_task_id.replace(id));
}

#[track_caller]
pub(crate) fn current_task_id() -> Id {
match CONTEXT.try_with(|ctx| ctx.current_task_id.get()) {
Ok(Some(id)) => id,
_ => panic!("can't get a task id when not inside a task"),
}
}

pub(crate) fn try_current() -> Result<scheduler::Handle, TryCurrentError> {
match CONTEXT.try_with(|ctx| ctx.scheduler.borrow().clone()) {
Ok(Some(handle)) => Ok(handle),
Expand Down
20 changes: 19 additions & 1 deletion tokio/src/runtime/task/harness.rs
@@ -1,8 +1,9 @@
use crate::future::Future;
use crate::runtime::context;
use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Trailer};
use crate::runtime::task::state::Snapshot;
use crate::runtime::task::waker::waker_ref;
use crate::runtime::task::{JoinError, Notified, Schedule, Task};
use crate::runtime::task::{Id, JoinError, Notified, Schedule, Task};

use std::mem;
use std::mem::ManuallyDrop;
Expand Down Expand Up @@ -439,10 +440,26 @@ enum PollFuture {
Dealloc,
}

/// Guard that sets and clears the task id in the context during task execution
/// and cancellation.
struct TaskIdGuard {}
impl TaskIdGuard {
fn new(id: Id) -> Self {
context::set_current_task_id(Some(id));
TaskIdGuard {}
}
}
impl Drop for TaskIdGuard {
fn drop(&mut self) {
context::set_current_task_id(None);
}
}

/// Cancels the task and store the appropriate error in the stage field.
fn cancel_task<T: Future>(stage: &CoreStage<T>, id: super::Id) {
// Drop the future from a panic guard.
let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
let _task_id_guard = TaskIdGuard::new(id);
stage.drop_future_or_output();
}));

Expand Down Expand Up @@ -476,6 +493,7 @@ fn poll_future<T: Future, S: Schedule>(
self.core.drop_future_or_output();
}
}
let _task_id_guard = TaskIdGuard::new(id);
let guard = Guard { core };
let res = guard.core.poll(cx);
mem::forget(guard);
Expand Down
17 changes: 15 additions & 2 deletions tokio/src/runtime/task/mod.rs
Expand Up @@ -201,10 +201,23 @@ use std::{fmt, mem};
/// [unstable]: crate#unstable-features
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
// TODO(eliza): there's almost certainly no reason not to make this `Copy` as well...
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct Id(u64);

/// Returns the `Id` of the task.
///
/// # Panics
///
/// This function panics if called from outside a task or if called from a
/// future passed to `block_on` call.
///
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
#[track_caller]
pub fn id() -> Id {
use crate::runtime::context;
context::current_task_id()
}

/// An owned handle to the task, tracked by ref count.
#[repr(transparent)]
pub(crate) struct Task<S: 'static> {
Expand Down
1 change: 1 addition & 0 deletions tokio/src/task/mod.rs
Expand Up @@ -319,6 +319,7 @@ cfg_rt! {

cfg_unstable! {
pub use crate::runtime::task::Id;
pub use crate::runtime::task::id;
}

cfg_trace! {
Expand Down
61 changes: 61 additions & 0 deletions tokio/tests/task_local.rs
Expand Up @@ -116,3 +116,64 @@ async fn task_local_available_on_completion_drop() {
assert_eq!(rx.await.unwrap(), 42);
h.await.unwrap();
}

#[tokio::test(flavor = "current_thread")]
async fn task_id() {
use tokio::task;

let handle = tokio::spawn(async { println!("task id: {}", task::id()) });

handle.await.unwrap();
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "multi_thread")]
async fn task_id_collision_multi_thread() {
use tokio::task;

let handle1 = tokio::spawn(async { task::id() });
let handle2 = tokio::spawn(async { task::id() });

let (id1, id2) = tokio::join!(handle1, handle2);
assert_ne!(id1.unwrap(), id2.unwrap());
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "current_thread")]
async fn task_id_collision_current_thread() {
use tokio::task;

let handle1 = tokio::spawn(async { task::id() });
let handle2 = tokio::spawn(async { task::id() });

let (id1, id2) = tokio::join!(handle1, handle2);
assert_ne!(id1.unwrap(), id2.unwrap());
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "current_thread")]
async fn task_ids_match_current_thread() {
use tokio::{sync::oneshot, task};

let (tx, rx) = oneshot::channel();
let handle = tokio::spawn(async {
let id = rx.await.unwrap();
assert_eq!(id, task::id());
});
tx.send(handle.id()).unwrap();
handle.await.unwrap();
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "multi_thread")]
async fn task_ids_match_multi_thread() {
use tokio::{sync::oneshot, task};

let (tx, rx) = oneshot::channel();
let handle = tokio::spawn(async {
let id = rx.await.unwrap();
assert_eq!(id, task::id());
});
tx.send(handle.id()).unwrap();
handle.await.unwrap();
}
13 changes: 13 additions & 0 deletions tokio/tests/task_panic.rs
Expand Up @@ -120,3 +120,16 @@ fn local_key_get_panic_caller() -> Result<(), Box<dyn Error>> {

Ok(())
}

#[cfg(tokio_unstable)]
#[test]
fn task_id_handle_panic_caller() -> Result<(), Box<dyn Error>> {
let panic_location_file = test_panic(|| {
let _ = task::id();
});

// The panic location should be in this file
assert_eq!(&panic_location_file.unwrap(), file!());

Ok(())
}

0 comments on commit be9c283

Please sign in to comment.