From 6d190b7845c5ed41cdbd981e60ad5bbcdeae9a85 Mon Sep 17 00:00:00 2001 From: Abutalib Aghayev Date: Wed, 9 Nov 2022 15:42:10 -0500 Subject: [PATCH] addressed issues raised by @hawkw --- tokio/src/runtime/context.rs | 7 ++----- tokio/src/runtime/task/core.rs | 2 +- tokio/src/runtime/task/harness.rs | 4 ++-- tokio/src/runtime/task/mod.rs | 11 ++++++++++- tokio/src/task/mod.rs | 1 + tokio/tests/task_local.rs | 24 ++++++++++++++++++++++-- 6 files changed, 38 insertions(+), 11 deletions(-) diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index 14951027ee5..3c1fa530a61 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -109,11 +109,8 @@ cfg_rt! { } #[track_caller] - pub(crate) fn current_task_id() -> Id { - match CONTEXT.try_with(|ctx| ctx.current_task_id.get()) { - Ok(Some(id)) => id, - _ => panic!("can't get a task id when not inside a task"), - } + pub(crate) fn current_task_id() -> Option { + CONTEXT.try_with(|ctx| ctx.current_task_id.get()).unwrap_or(None) } pub(crate) fn try_current() -> Result { diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs index 1d0fcbc687e..39fbdb03882 100644 --- a/tokio/src/runtime/task/core.rs +++ b/tokio/src/runtime/task/core.rs @@ -201,7 +201,7 @@ impl CoreStage { // Safety: the caller ensures mutual exclusion to the field. unsafe { let _task_id_guard = self.stage.with_mut(|ptr| match &*ptr { - Stage::Finished(Ok(_), id) => Some(TaskIdGuard::new(*id)), + Stage::Finished(Ok(_), id) => Some(TaskIdGuard::enter(*id)), _ => None, }); self.set_stage(Stage::Consumed); diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 5c386eb616f..126ec10e359 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -460,7 +460,7 @@ enum PollFuture { fn cancel_task(stage: &CoreStage, id: super::Id) { // Drop the future from a panic guard. let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { - let _task_id_guard = TaskIdGuard::new(id); + let _task_id_guard = TaskIdGuard::enter(id); stage.drop_future_or_output(); })); @@ -483,7 +483,7 @@ fn poll_future( cx: Context<'_>, ) -> Poll<()> { // Poll the future. - let _task_id_guard = TaskIdGuard::new(id); + let _task_id_guard = TaskIdGuard::enter(id); let output = panic::catch_unwind(panic::AssertUnwindSafe(|| { struct Guard<'a, T: Future> { core: &'a CoreStage, diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs index cb9e61fb1eb..522060994ab 100644 --- a/tokio/src/runtime/task/mod.rs +++ b/tokio/src/runtime/task/mod.rs @@ -216,6 +216,15 @@ use crate::runtime::context; #[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] #[track_caller] pub fn id() -> Id { + use crate::runtime::context; + context::current_task_id().expect("Can't get a task id when not inside a task") +} + +/// Returns the `Id` of the task if called from inside the task, or None otherwise. +/// +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[track_caller] +pub fn try_id() -> Option { use crate::runtime::context; context::current_task_id() } @@ -224,7 +233,7 @@ pub fn id() -> Id { /// and cancellation. pub(crate) struct TaskIdGuard {} impl TaskIdGuard { - fn new(id: Id) -> Self { + fn enter(id: Id) -> Self { context::set_current_task_id(Some(id)); TaskIdGuard {} } diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index a8d69060706..978525406d4 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -320,6 +320,7 @@ cfg_rt! { cfg_unstable! { pub use crate::runtime::task::Id; pub use crate::runtime::task::id; + pub use crate::runtime::task::try_id; } cfg_trace! { diff --git a/tokio/tests/task_local.rs b/tokio/tests/task_local.rs index d87f1760bf2..454dc7000dc 100644 --- a/tokio/tests/task_local.rs +++ b/tokio/tests/task_local.rs @@ -122,7 +122,7 @@ async fn task_local_available_on_completion_drop() { async fn task_id_spawn() { use tokio::task; - task::spawn_blocking(|| println!("task id: {}", task::id())) + tokio::spawn(async { println!("task id: {}", task::id()) }) .await .unwrap(); } @@ -131,7 +131,7 @@ async fn task_id_spawn() { async fn task_id_spawn_blocking() { use tokio::task; - tokio::spawn(async { println!("task id: {}", task::id()) }) + task::spawn_blocking(|| println!("task id: {}", task::id())) .await .unwrap(); } @@ -301,3 +301,23 @@ async fn task_id_output_destructor_handle_dropped_after_completion() { rx.await.unwrap(); drop(handle); } + +#[cfg(tokio_unstable)] +#[test] +fn task_try_id_outside_task() { + use tokio::task; + + assert_eq!(None, task::try_id()); +} + +#[cfg(tokio_unstable)] +#[test] +fn task_try_id_inside_block_on() { + use tokio::runtime::Runtime; + use tokio::task; + + let rt = Runtime::new().unwrap(); + rt.block_on(async { + assert_eq!(None, task::try_id()); + }); +}