Skip to content

Commit

Permalink
addressed issues raised by @hawkw
Browse files Browse the repository at this point in the history
  • Loading branch information
agayev committed Nov 9, 2022
1 parent 370e754 commit 6d190b7
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 11 deletions.
7 changes: 2 additions & 5 deletions tokio/src/runtime/context.rs
Expand Up @@ -109,11 +109,8 @@ cfg_rt! {
}

#[track_caller]
pub(crate) fn current_task_id() -> Id {
match CONTEXT.try_with(|ctx| ctx.current_task_id.get()) {
Ok(Some(id)) => id,
_ => panic!("can't get a task id when not inside a task"),
}
pub(crate) fn current_task_id() -> Option<Id> {
CONTEXT.try_with(|ctx| ctx.current_task_id.get()).unwrap_or(None)
}

pub(crate) fn try_current() -> Result<scheduler::Handle, TryCurrentError> {
Expand Down
2 changes: 1 addition & 1 deletion tokio/src/runtime/task/core.rs
Expand Up @@ -201,7 +201,7 @@ impl<T: Future> CoreStage<T> {
// Safety: the caller ensures mutual exclusion to the field.
unsafe {
let _task_id_guard = self.stage.with_mut(|ptr| match &*ptr {
Stage::Finished(Ok(_), id) => Some(TaskIdGuard::new(*id)),
Stage::Finished(Ok(_), id) => Some(TaskIdGuard::enter(*id)),
_ => None,
});
self.set_stage(Stage::Consumed);
Expand Down
4 changes: 2 additions & 2 deletions tokio/src/runtime/task/harness.rs
Expand Up @@ -460,7 +460,7 @@ enum PollFuture {
fn cancel_task<T: Future>(stage: &CoreStage<T>, id: super::Id) {
// Drop the future from a panic guard.
let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
let _task_id_guard = TaskIdGuard::new(id);
let _task_id_guard = TaskIdGuard::enter(id);
stage.drop_future_or_output();
}));

Expand All @@ -483,7 +483,7 @@ fn poll_future<T: Future, S: Schedule>(
cx: Context<'_>,
) -> Poll<()> {
// Poll the future.
let _task_id_guard = TaskIdGuard::new(id);
let _task_id_guard = TaskIdGuard::enter(id);
let output = panic::catch_unwind(panic::AssertUnwindSafe(|| {
struct Guard<'a, T: Future> {
core: &'a CoreStage<T>,
Expand Down
11 changes: 10 additions & 1 deletion tokio/src/runtime/task/mod.rs
Expand Up @@ -216,6 +216,15 @@ use crate::runtime::context;
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
#[track_caller]
pub fn id() -> Id {
use crate::runtime::context;
context::current_task_id().expect("Can't get a task id when not inside a task")
}

/// Returns the `Id` of the task if called from inside the task, or None otherwise.
///
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
#[track_caller]
pub fn try_id() -> Option<Id> {
use crate::runtime::context;
context::current_task_id()
}
Expand All @@ -224,7 +233,7 @@ pub fn id() -> Id {
/// and cancellation.
pub(crate) struct TaskIdGuard {}
impl TaskIdGuard {
fn new(id: Id) -> Self {
fn enter(id: Id) -> Self {
context::set_current_task_id(Some(id));
TaskIdGuard {}
}
Expand Down
1 change: 1 addition & 0 deletions tokio/src/task/mod.rs
Expand Up @@ -320,6 +320,7 @@ cfg_rt! {
cfg_unstable! {
pub use crate::runtime::task::Id;
pub use crate::runtime::task::id;
pub use crate::runtime::task::try_id;
}

cfg_trace! {
Expand Down
24 changes: 22 additions & 2 deletions tokio/tests/task_local.rs
Expand Up @@ -122,7 +122,7 @@ async fn task_local_available_on_completion_drop() {
async fn task_id_spawn() {
use tokio::task;

task::spawn_blocking(|| println!("task id: {}", task::id()))
tokio::spawn(async { println!("task id: {}", task::id()) })
.await
.unwrap();
}
Expand All @@ -131,7 +131,7 @@ async fn task_id_spawn() {
async fn task_id_spawn_blocking() {
use tokio::task;

tokio::spawn(async { println!("task id: {}", task::id()) })
task::spawn_blocking(|| println!("task id: {}", task::id()))
.await
.unwrap();
}
Expand Down Expand Up @@ -301,3 +301,23 @@ async fn task_id_output_destructor_handle_dropped_after_completion() {
rx.await.unwrap();
drop(handle);
}

#[cfg(tokio_unstable)]
#[test]
fn task_try_id_outside_task() {
use tokio::task;

assert_eq!(None, task::try_id());
}

#[cfg(tokio_unstable)]
#[test]
fn task_try_id_inside_block_on() {
use tokio::runtime::Runtime;
use tokio::task;

let rt = Runtime::new().unwrap();
rt.block_on(async {
assert_eq!(None, task::try_id());
});
}

0 comments on commit 6d190b7

Please sign in to comment.