Skip to content

Commit

Permalink
rebase and rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
agayev committed Nov 10, 2022
1 parent 9e3fb16 commit 9a67b14
Show file tree
Hide file tree
Showing 7 changed files with 333 additions and 2 deletions.
12 changes: 12 additions & 0 deletions tokio/src/runtime/context.rs
@@ -1,4 +1,5 @@
use crate::runtime::coop;
use crate::runtime::task::Id;

use std::cell::Cell;

Expand All @@ -17,6 +18,7 @@ struct Context {
/// Handle to the runtime scheduler running on the current thread.
#[cfg(feature = "rt")]
handle: RefCell<Option<scheduler::Handle>>,
current_task_id: Cell<Option<Id>>,

/// Tracks if the current thread is currently driving a runtime.
/// Note, that if this is set to "entered", the current scheduler
Expand All @@ -41,6 +43,7 @@ tokio_thread_local! {
/// accessing drivers, etc...
#[cfg(feature = "rt")]
handle: RefCell::new(None),
current_task_id: Cell::new(None),

/// Tracks if the current thread is currently driving a runtime.
/// Note, that if this is set to "entered", the current scheduler
Expand Down Expand Up @@ -107,6 +110,15 @@ cfg_rt! {

pub(crate) struct DisallowBlockInPlaceGuard(bool);

pub(crate) fn set_current_task_id(id: Option<Id>) -> Option<Id> {
CONTEXT.try_with(|ctx| ctx.current_task_id.replace(id)).unwrap_or(None)
}

#[track_caller]
pub(crate) fn current_task_id() -> Option<Id> {
CONTEXT.try_with(|ctx| ctx.current_task_id.get()).unwrap_or(None)
}

pub(crate) fn try_current() -> Result<scheduler::Handle, TryCurrentError> {
match CONTEXT.try_with(|ctx| ctx.handle.borrow().clone()) {
Ok(Some(handle)) => Ok(handle),
Expand Down
21 changes: 21 additions & 0 deletions tokio/src/runtime/task/core.rs
Expand Up @@ -11,6 +11,7 @@

use crate::future::Future;
use crate::loom::cell::UnsafeCell;
use crate::runtime::context;
use crate::runtime::task::raw::{self, Vtable};
use crate::runtime::task::state::State;
use crate::runtime::task::{Id, Schedule};
Expand Down Expand Up @@ -157,6 +158,24 @@ impl<T: Future> CoreStage<T> {
}
}

/// Set and clear the task id in the context when the future is executed or
/// dropped, or when the output produced by the future is dropped.
pub(crate) struct TaskIdGuard {
parent_task_id: Option<Id>,
}
impl TaskIdGuard {
fn enter(id: Id) -> Self {
TaskIdGuard {
parent_task_id: context::set_current_task_id(Some(id)),
}
}
}
impl Drop for TaskIdGuard {
fn drop(&mut self) {
context::set_current_task_id(self.parent_task_id);
}
}

impl<T: Future, S: Schedule> Core<T, S> {
/// Polls the future.
///
Expand All @@ -183,6 +202,7 @@ impl<T: Future, S: Schedule> Core<T, S> {
// Safety: The caller ensures the future is pinned.
let future = unsafe { Pin::new_unchecked(future) };

let _guard = TaskIdGuard::enter(self.task_id);
future.poll(&mut cx)
})
};
Expand All @@ -202,6 +222,7 @@ impl<T: Future, S: Schedule> Core<T, S> {
pub(super) fn drop_future_or_output(&self) {
// Safety: the caller ensures mutual exclusion to the field.
unsafe {
let _guard = TaskIdGuard::enter(self.task_id);
self.set_stage(Stage::Consumed);
}
}
Expand Down
26 changes: 24 additions & 2 deletions tokio/src/runtime/task/mod.rs
Expand Up @@ -201,10 +201,32 @@ use std::{fmt, mem};
/// [unstable]: crate#unstable-features
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
// TODO(eliza): there's almost certainly no reason not to make this `Copy` as well...
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub struct Id(u64);

/// Returns the `Id` of the task.
///
/// # Panics
///
/// This function panics if called from outside a task or if called from a
/// future passed to `block_on` call.
///
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
#[track_caller]
pub fn id() -> Id {
use crate::runtime::context;
context::current_task_id().expect("Can't get a task id when not inside a task")
}

/// Returns the `Id` of the task if called from inside the task, or None otherwise.
///
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
#[track_caller]
pub fn try_id() -> Option<Id> {
use crate::runtime::context;
context::current_task_id()
}

/// An owned handle to the task, tracked by ref count.
#[repr(transparent)]
pub(crate) struct Task<S: 'static> {
Expand Down
2 changes: 2 additions & 0 deletions tokio/src/task/mod.rs
Expand Up @@ -319,6 +319,8 @@ cfg_rt! {

cfg_unstable! {
pub use crate::runtime::task::Id;
pub use crate::runtime::task::id;
pub use crate::runtime::task::try_id;
}

cfg_trace! {
Expand Down
204 changes: 204 additions & 0 deletions tokio/tests/task_local.rs
Expand Up @@ -117,3 +117,207 @@ async fn task_local_available_on_completion_drop() {
assert_eq!(rx.await.unwrap(), 42);
h.await.unwrap();
}

#[tokio::test(flavor = "current_thread")]
async fn task_id_spawn() {
use tokio::task;

tokio::spawn(async { println!("task id: {}", task::id()) })
.await
.unwrap();
}

#[tokio::test(flavor = "current_thread")]
async fn task_id_spawn_blocking() {
use tokio::task;

task::spawn_blocking(|| println!("task id: {}", task::id()))
.await
.unwrap();
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "current_thread")]
async fn task_id_collision_current_thread() {
use tokio::task;

let handle1 = tokio::spawn(async { task::id() });
let handle2 = tokio::spawn(async { task::id() });

let (id1, id2) = tokio::join!(handle1, handle2);
assert_ne!(id1.unwrap(), id2.unwrap());
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "multi_thread")]
async fn task_id_collision_multi_thread() {
use tokio::task;

let handle1 = tokio::spawn(async { task::id() });
let handle2 = tokio::spawn(async { task::id() });

let (id1, id2) = tokio::join!(handle1, handle2);
assert_ne!(id1.unwrap(), id2.unwrap());
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "current_thread")]
async fn task_ids_match_current_thread() {
use tokio::{sync::oneshot, task};

let (tx, rx) = oneshot::channel();
let handle = tokio::spawn(async {
let id = rx.await.unwrap();
assert_eq!(id, task::id());
});
tx.send(handle.id()).unwrap();
handle.await.unwrap();
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "multi_thread")]
async fn task_ids_match_multi_thread() {
use tokio::{sync::oneshot, task};

let (tx, rx) = oneshot::channel();
let handle = tokio::spawn(async {
let id = rx.await.unwrap();
assert_eq!(id, task::id());
});
tx.send(handle.id()).unwrap();
handle.await.unwrap();
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "multi_thread")]
async fn task_id_future_destructor_completion() {
use tokio::task;

struct MyFuture;
impl Future for MyFuture {
type Output = ();

fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
Poll::Ready(())
}
}
impl Drop for MyFuture {
fn drop(&mut self) {
println!("task id: {}", task::id());
}
}

tokio::spawn(MyFuture).await.unwrap();
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "multi_thread")]
async fn task_id_future_destructor_abort() {
use tokio::task;

struct MyFuture;
impl Future for MyFuture {
type Output = ();

fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
Poll::Pending
}
}
impl Drop for MyFuture {
fn drop(&mut self) {
println!("task id: {}", task::id());
}
}

tokio::spawn(MyFuture).abort();
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "current_thread")]
async fn task_id_output_destructor_handle_dropped_before_completion() {
use tokio::task;

struct MyOutput;
impl Drop for MyOutput {
fn drop(&mut self) {
println!("task id: {}", task::id());
}
}

struct MyFuture {
tx: Option<oneshot::Sender<()>>,
}
impl Future for MyFuture {
type Output = MyOutput;

fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
let _ = self.tx.take().unwrap().send(());
Poll::Ready(MyOutput)
}
}
impl Drop for MyFuture {
fn drop(&mut self) {
println!("task id: {}", task::id());
}
}

let (tx, rx) = oneshot::channel();
let handle = tokio::spawn(MyFuture { tx: Some(tx) });
drop(handle);
rx.await.unwrap();
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "current_thread")]
async fn task_id_output_destructor_handle_dropped_after_completion() {
use tokio::task;

struct MyOutput;
impl Drop for MyOutput {
fn drop(&mut self) {
println!("task id: {}", task::id());
}
}

struct MyFuture {
tx: Option<oneshot::Sender<()>>,
}
impl Future for MyFuture {
type Output = MyOutput;

fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
let _ = self.tx.take().unwrap().send(());
Poll::Ready(MyOutput)
}
}
impl Drop for MyFuture {
fn drop(&mut self) {
println!("task id: {}", task::id());
}
}

let (tx, rx) = oneshot::channel();
let handle = tokio::spawn(MyFuture { tx: Some(tx) });
rx.await.unwrap();
drop(handle);
}

#[cfg(tokio_unstable)]
#[test]
fn task_try_id_outside_task() {
use tokio::task;

assert_eq!(None, task::try_id());
}

#[cfg(tokio_unstable)]
#[test]
fn task_try_id_inside_block_on() {
use tokio::runtime::Runtime;
use tokio::task;

let rt = Runtime::new().unwrap();
rt.block_on(async {
assert_eq!(None, task::try_id());
});
}
39 changes: 39 additions & 0 deletions tokio/tests/task_local_set.rs
Expand Up @@ -566,6 +566,45 @@ async fn spawn_wakes_localset() {
}
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "current_thread")]
async fn task_id_spawn_local() {
LocalSet::new()
.run_until(async {
task::spawn_local(async { println!("task id: {}", task::id()) })
.await
.unwrap();
})
.await
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "current_thread")]
async fn task_id_nested_spawn_local() {
LocalSet::new()
.run_until(async {
task::spawn_local(async {
let outer_id = task::id();
println!("outer id is {}", outer_id);
LocalSet::new()
.run_until(async {
task::spawn_local(async move {
println!("inner id is {}", task::id());
assert_ne!(outer_id, task::id());
})
.await
.unwrap();
})
.await;
println!("back in the outer task");
assert_eq!(outer_id, task::id());
})
.await
.unwrap();
})
.await;
}

#[cfg(tokio_unstable)]
mod unstable {
use tokio::runtime::UnhandledPanic;
Expand Down

0 comments on commit 9a67b14

Please sign in to comment.