Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

task: define task_local in underlying future Drop #4604

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion tokio/Cargo.toml
Expand Up @@ -91,7 +91,7 @@ stats = []
[dependencies]
tokio-macros = { version = "1.7.0", path = "../tokio-macros", optional = true }

pin-project-lite = "0.2.0"
pin-project-lite = "0.2.7"

# Everything else is optional...
bytes = { version = "1.0.0", optional = true }
Expand Down
61 changes: 44 additions & 17 deletions tokio/src/task/task_local.rs
Expand Up @@ -3,6 +3,7 @@ use std::cell::RefCell;
use std::error::Error;
use std::future::Future;
use std::marker::PhantomPinned;
use std::mem::ManuallyDrop;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, thread};
Expand Down Expand Up @@ -123,7 +124,7 @@ impl<T: 'static> LocalKey<T> {
TaskLocalFuture {
local: self,
slot: Some(value),
future: f,
future: ManuallyDrop::new(f),
_pinned: PhantomPinned,
}
}
Expand Down Expand Up @@ -152,7 +153,7 @@ impl<T: 'static> LocalKey<T> {
let scope = TaskLocalFuture {
local: self,
slot: Some(value),
future: (),
future: ManuallyDrop::new(()),
_pinned: PhantomPinned,
};
crate::pin!(scope);
Expand Down Expand Up @@ -208,6 +209,19 @@ impl<T: 'static> fmt::Debug for LocalKey<T> {
}
}

struct TaskLocalGuard<'a, T: 'static> {
local: &'static LocalKey<T>,
slot: &'a mut Option<T>,
prev: Option<T>,
}

impl<T> Drop for TaskLocalGuard<'_, T> {
fn drop(&mut self) {
let value = self.local.inner.with(|c| c.replace(self.prev.take()));
*self.slot = value;
}
}

pin_project! {
/// A future that sets a value `T` of a task local for the future `F` during
/// its execution.
Expand Down Expand Up @@ -237,39 +251,52 @@ pin_project! {
local: &'static LocalKey<T>,
slot: Option<T>,
#[pin]
future: F,
future: ManuallyDrop<F>,
#[pin]
_pinned: PhantomPinned,
}
}

impl<T: 'static, F> TaskLocalFuture<T, F> {
fn with_task<F2: FnOnce(Pin<&mut F>) -> R, R>(self: Pin<&mut Self>, f: F2) -> R {
struct Guard<'a, T: 'static> {
local: &'static LocalKey<T>,
slot: &'a mut Option<T>,
prev: Option<T>,
}
impl<T: 'static, F> PinnedDrop for TaskLocalFuture<T, F> {
fn drop(this: Pin<&mut Self>) {
let project = this.project();
let val = project.slot.take();

let prev = project.local.inner.with(|c| c.replace(val));

impl<T> Drop for Guard<'_, T> {
fn drop(&mut self) {
let value = self.local.inner.with(|c| c.replace(self.prev.take()));
*self.slot = value;
let _guard = TaskLocalGuard {
prev,
slot: project.slot,
local: *project.local,
};

unsafe {
drop(project.future.map_unchecked_mut(|fut| {
ManuallyDrop::drop(fut);
fut
}));
}
}
}
}

impl<T: 'static, F> TaskLocalFuture<T, F> {
fn with_task<F2: FnOnce(Pin<&mut F>) -> R, R>(self: Pin<&mut Self>, f: F2) -> R {
let project = self.project();
let val = project.slot.take();

let prev = project.local.inner.with(|c| c.replace(val));

let _guard = Guard {
let _guard = TaskLocalGuard {
prev,
slot: project.slot,
local: *project.local,
};

f(project.future)
unsafe {
use std::ops::DerefMut;
let fut = project.future.map_unchecked_mut(|f| f.deref_mut());
f(fut)
}
}
}

Expand Down
31 changes: 31 additions & 0 deletions tokio/tests/task_local.rs
Expand Up @@ -31,3 +31,34 @@ async fn local() {
j2.await.unwrap();
j3.await.unwrap();
}

tokio::task_local! {
static KEY: u32;
}

struct Guard(u32);
impl Drop for Guard {
fn drop(&mut self) {
assert_eq!(KEY.try_with(|x| *x), Ok(self.0));
}
}

#[tokio::test(flavor = "multi_thread")]
async fn task_local_available_on_drop() {
let (tx, rx) = tokio::sync::oneshot::channel();

let h = tokio::spawn(KEY.scope(42, async move {
let _g = Guard(42);
let _ = tx.send(());
std::future::pending::<()>().await;
}));

rx.await.unwrap();

h.abort();

let err = h.await.unwrap_err();
if err.is_panic() {
panic!("{}", err);
}
}