From 9a67b14646c667c4940957a6e6a60971503fc5f2 Mon Sep 17 00:00:00 2001 From: Abutalib Aghayev Date: Thu, 10 Nov 2022 11:58:20 -0500 Subject: [PATCH] rebase and rewrite --- tokio/src/runtime/context.rs | 12 ++ tokio/src/runtime/task/core.rs | 21 ++++ tokio/src/runtime/task/mod.rs | 26 ++++- tokio/src/task/mod.rs | 2 + tokio/tests/task_local.rs | 204 +++++++++++++++++++++++++++++++++ tokio/tests/task_local_set.rs | 39 +++++++ tokio/tests/task_panic.rs | 31 +++++ 7 files changed, 333 insertions(+), 2 deletions(-) diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index 6de657481e0..5f6f99e120f 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -1,4 +1,5 @@ use crate::runtime::coop; +use crate::runtime::task::Id; use std::cell::Cell; @@ -17,6 +18,7 @@ struct Context { /// Handle to the runtime scheduler running on the current thread. #[cfg(feature = "rt")] handle: RefCell>, + current_task_id: Cell>, /// Tracks if the current thread is currently driving a runtime. /// Note, that if this is set to "entered", the current scheduler @@ -41,6 +43,7 @@ tokio_thread_local! { /// accessing drivers, etc... #[cfg(feature = "rt")] handle: RefCell::new(None), + current_task_id: Cell::new(None), /// Tracks if the current thread is currently driving a runtime. /// Note, that if this is set to "entered", the current scheduler @@ -107,6 +110,15 @@ cfg_rt! { pub(crate) struct DisallowBlockInPlaceGuard(bool); + pub(crate) fn set_current_task_id(id: Option) -> Option { + CONTEXT.try_with(|ctx| ctx.current_task_id.replace(id)).unwrap_or(None) + } + + #[track_caller] + pub(crate) fn current_task_id() -> Option { + CONTEXT.try_with(|ctx| ctx.current_task_id.get()).unwrap_or(None) + } + pub(crate) fn try_current() -> Result { match CONTEXT.try_with(|ctx| ctx.handle.borrow().clone()) { Ok(Some(handle)) => Ok(handle), diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs index 3e07d7c97fd..69093894980 100644 --- a/tokio/src/runtime/task/core.rs +++ b/tokio/src/runtime/task/core.rs @@ -11,6 +11,7 @@ use crate::future::Future; use crate::loom::cell::UnsafeCell; +use crate::runtime::context; use crate::runtime::task::raw::{self, Vtable}; use crate::runtime::task::state::State; use crate::runtime::task::{Id, Schedule}; @@ -157,6 +158,24 @@ impl CoreStage { } } +/// Set and clear the task id in the context when the future is executed or +/// dropped, or when the output produced by the future is dropped. +pub(crate) struct TaskIdGuard { + parent_task_id: Option, +} +impl TaskIdGuard { + fn enter(id: Id) -> Self { + TaskIdGuard { + parent_task_id: context::set_current_task_id(Some(id)), + } + } +} +impl Drop for TaskIdGuard { + fn drop(&mut self) { + context::set_current_task_id(self.parent_task_id); + } +} + impl Core { /// Polls the future. /// @@ -183,6 +202,7 @@ impl Core { // Safety: The caller ensures the future is pinned. let future = unsafe { Pin::new_unchecked(future) }; + let _guard = TaskIdGuard::enter(self.task_id); future.poll(&mut cx) }) }; @@ -202,6 +222,7 @@ impl Core { pub(super) fn drop_future_or_output(&self) { // Safety: the caller ensures mutual exclusion to the field. unsafe { + let _guard = TaskIdGuard::enter(self.task_id); self.set_stage(Stage::Consumed); } } diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs index 3d5b1cbf373..e90e01819af 100644 --- a/tokio/src/runtime/task/mod.rs +++ b/tokio/src/runtime/task/mod.rs @@ -201,10 +201,32 @@ use std::{fmt, mem}; /// [unstable]: crate#unstable-features #[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] #[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] -// TODO(eliza): there's almost certainly no reason not to make this `Copy` as well... -#[derive(Clone, Debug, Hash, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] pub struct Id(u64); +/// Returns the `Id` of the task. +/// +/// # Panics +/// +/// This function panics if called from outside a task or if called from a +/// future passed to `block_on` call. +/// +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[track_caller] +pub fn id() -> Id { + use crate::runtime::context; + context::current_task_id().expect("Can't get a task id when not inside a task") +} + +/// Returns the `Id` of the task if called from inside the task, or None otherwise. +/// +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[track_caller] +pub fn try_id() -> Option { + use crate::runtime::context; + context::current_task_id() +} + /// An owned handle to the task, tracked by ref count. #[repr(transparent)] pub(crate) struct Task { diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index f1683f7e07f..978525406d4 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -319,6 +319,8 @@ cfg_rt! { cfg_unstable! { pub use crate::runtime::task::Id; + pub use crate::runtime::task::id; + pub use crate::runtime::task::try_id; } cfg_trace! { diff --git a/tokio/tests/task_local.rs b/tokio/tests/task_local.rs index 949a40c2a4d..3b421caddf8 100644 --- a/tokio/tests/task_local.rs +++ b/tokio/tests/task_local.rs @@ -117,3 +117,207 @@ async fn task_local_available_on_completion_drop() { assert_eq!(rx.await.unwrap(), 42); h.await.unwrap(); } + +#[tokio::test(flavor = "current_thread")] +async fn task_id_spawn() { + use tokio::task; + + tokio::spawn(async { println!("task id: {}", task::id()) }) + .await + .unwrap(); +} + +#[tokio::test(flavor = "current_thread")] +async fn task_id_spawn_blocking() { + use tokio::task; + + task::spawn_blocking(|| println!("task id: {}", task::id())) + .await + .unwrap(); +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "current_thread")] +async fn task_id_collision_current_thread() { + use tokio::task; + + let handle1 = tokio::spawn(async { task::id() }); + let handle2 = tokio::spawn(async { task::id() }); + + let (id1, id2) = tokio::join!(handle1, handle2); + assert_ne!(id1.unwrap(), id2.unwrap()); +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "multi_thread")] +async fn task_id_collision_multi_thread() { + use tokio::task; + + let handle1 = tokio::spawn(async { task::id() }); + let handle2 = tokio::spawn(async { task::id() }); + + let (id1, id2) = tokio::join!(handle1, handle2); + assert_ne!(id1.unwrap(), id2.unwrap()); +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "current_thread")] +async fn task_ids_match_current_thread() { + use tokio::{sync::oneshot, task}; + + let (tx, rx) = oneshot::channel(); + let handle = tokio::spawn(async { + let id = rx.await.unwrap(); + assert_eq!(id, task::id()); + }); + tx.send(handle.id()).unwrap(); + handle.await.unwrap(); +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "multi_thread")] +async fn task_ids_match_multi_thread() { + use tokio::{sync::oneshot, task}; + + let (tx, rx) = oneshot::channel(); + let handle = tokio::spawn(async { + let id = rx.await.unwrap(); + assert_eq!(id, task::id()); + }); + tx.send(handle.id()).unwrap(); + handle.await.unwrap(); +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "multi_thread")] +async fn task_id_future_destructor_completion() { + use tokio::task; + + struct MyFuture; + impl Future for MyFuture { + type Output = (); + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } + } + impl Drop for MyFuture { + fn drop(&mut self) { + println!("task id: {}", task::id()); + } + } + + tokio::spawn(MyFuture).await.unwrap(); +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "multi_thread")] +async fn task_id_future_destructor_abort() { + use tokio::task; + + struct MyFuture; + impl Future for MyFuture { + type Output = (); + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> { + Poll::Pending + } + } + impl Drop for MyFuture { + fn drop(&mut self) { + println!("task id: {}", task::id()); + } + } + + tokio::spawn(MyFuture).abort(); +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "current_thread")] +async fn task_id_output_destructor_handle_dropped_before_completion() { + use tokio::task; + + struct MyOutput; + impl Drop for MyOutput { + fn drop(&mut self) { + println!("task id: {}", task::id()); + } + } + + struct MyFuture { + tx: Option>, + } + impl Future for MyFuture { + type Output = MyOutput; + + fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + let _ = self.tx.take().unwrap().send(()); + Poll::Ready(MyOutput) + } + } + impl Drop for MyFuture { + fn drop(&mut self) { + println!("task id: {}", task::id()); + } + } + + let (tx, rx) = oneshot::channel(); + let handle = tokio::spawn(MyFuture { tx: Some(tx) }); + drop(handle); + rx.await.unwrap(); +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "current_thread")] +async fn task_id_output_destructor_handle_dropped_after_completion() { + use tokio::task; + + struct MyOutput; + impl Drop for MyOutput { + fn drop(&mut self) { + println!("task id: {}", task::id()); + } + } + + struct MyFuture { + tx: Option>, + } + impl Future for MyFuture { + type Output = MyOutput; + + fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + let _ = self.tx.take().unwrap().send(()); + Poll::Ready(MyOutput) + } + } + impl Drop for MyFuture { + fn drop(&mut self) { + println!("task id: {}", task::id()); + } + } + + let (tx, rx) = oneshot::channel(); + let handle = tokio::spawn(MyFuture { tx: Some(tx) }); + rx.await.unwrap(); + drop(handle); +} + +#[cfg(tokio_unstable)] +#[test] +fn task_try_id_outside_task() { + use tokio::task; + + assert_eq!(None, task::try_id()); +} + +#[cfg(tokio_unstable)] +#[test] +fn task_try_id_inside_block_on() { + use tokio::runtime::Runtime; + use tokio::task; + + let rt = Runtime::new().unwrap(); + rt.block_on(async { + assert_eq!(None, task::try_id()); + }); +} diff --git a/tokio/tests/task_local_set.rs b/tokio/tests/task_local_set.rs index 271afb8f5cf..80545bcf706 100644 --- a/tokio/tests/task_local_set.rs +++ b/tokio/tests/task_local_set.rs @@ -566,6 +566,45 @@ async fn spawn_wakes_localset() { } } +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "current_thread")] +async fn task_id_spawn_local() { + LocalSet::new() + .run_until(async { + task::spawn_local(async { println!("task id: {}", task::id()) }) + .await + .unwrap(); + }) + .await +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "current_thread")] +async fn task_id_nested_spawn_local() { + LocalSet::new() + .run_until(async { + task::spawn_local(async { + let outer_id = task::id(); + println!("outer id is {}", outer_id); + LocalSet::new() + .run_until(async { + task::spawn_local(async move { + println!("inner id is {}", task::id()); + assert_ne!(outer_id, task::id()); + }) + .await + .unwrap(); + }) + .await; + println!("back in the outer task"); + assert_eq!(outer_id, task::id()); + }) + .await + .unwrap(); + }) + .await; +} + #[cfg(tokio_unstable)] mod unstable { use tokio::runtime::UnhandledPanic; diff --git a/tokio/tests/task_panic.rs b/tokio/tests/task_panic.rs index e4cedce2798..f0ee8621101 100644 --- a/tokio/tests/task_panic.rs +++ b/tokio/tests/task_panic.rs @@ -121,3 +121,34 @@ fn local_key_get_panic_caller() -> Result<(), Box> { Ok(()) } + +#[cfg(tokio_unstable)] +#[test] +fn task_id_outside_task_panic_caller() -> Result<(), Box> { + let panic_location_file = test_panic(|| { + let _ = task::id(); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +} + +#[cfg(tokio_unstable)] +#[test] +fn task_id_inside_block_on_panic_caller() -> Result<(), Box> { + use tokio::runtime::Runtime; + + let panic_location_file = test_panic(|| { + let rt = Runtime::new().unwrap(); + rt.block_on(async { + task::id(); + }); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +}