diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index 90a32d98164..6b4da18f7b4 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -91,7 +91,7 @@ stats = [] [dependencies] tokio-macros = { version = "1.7.0", path = "../tokio-macros", optional = true } -pin-project-lite = "0.2.0" +pin-project-lite = "0.2.7" # Everything else is optional... bytes = { version = "1.0.0", optional = true } diff --git a/tokio/src/task/task_local.rs b/tokio/src/task/task_local.rs index 949bbca3eee..5114c3edc70 100644 --- a/tokio/src/task/task_local.rs +++ b/tokio/src/task/task_local.rs @@ -3,6 +3,7 @@ use std::cell::RefCell; use std::error::Error; use std::future::Future; use std::marker::PhantomPinned; +use std::mem::ManuallyDrop; use std::pin::Pin; use std::task::{Context, Poll}; use std::{fmt, thread}; @@ -123,7 +124,7 @@ impl LocalKey { TaskLocalFuture { local: self, slot: Some(value), - future: f, + future: ManuallyDrop::new(f), _pinned: PhantomPinned, } } @@ -152,7 +153,7 @@ impl LocalKey { let scope = TaskLocalFuture { local: self, slot: Some(value), - future: (), + future: ManuallyDrop::new(()), _pinned: PhantomPinned, }; crate::pin!(scope); @@ -208,6 +209,19 @@ impl fmt::Debug for LocalKey { } } +struct TaskLocalGuard<'a, T: 'static> { + local: &'static LocalKey, + slot: &'a mut Option, + prev: Option, +} + +impl Drop for TaskLocalGuard<'_, T> { + fn drop(&mut self) { + let value = self.local.inner.with(|c| c.replace(self.prev.take())); + *self.slot = value; + } +} + pin_project! { /// A future that sets a value `T` of a task local for the future `F` during /// its execution. @@ -237,39 +251,52 @@ pin_project! { local: &'static LocalKey, slot: Option, #[pin] - future: F, + future: ManuallyDrop, #[pin] _pinned: PhantomPinned, } -} -impl TaskLocalFuture { - fn with_task) -> R, R>(self: Pin<&mut Self>, f: F2) -> R { - struct Guard<'a, T: 'static> { - local: &'static LocalKey, - slot: &'a mut Option, - prev: Option, - } + impl PinnedDrop for TaskLocalFuture { + fn drop(this: Pin<&mut Self>) { + let project = this.project(); + let val = project.slot.take(); + + let prev = project.local.inner.with(|c| c.replace(val)); - impl Drop for Guard<'_, T> { - fn drop(&mut self) { - let value = self.local.inner.with(|c| c.replace(self.prev.take())); - *self.slot = value; + let _guard = TaskLocalGuard { + prev, + slot: project.slot, + local: *project.local, + }; + + unsafe { + drop(project.future.map_unchecked_mut(|fut| { + ManuallyDrop::drop(fut); + fut + })); } } + } +} +impl TaskLocalFuture { + fn with_task) -> R, R>(self: Pin<&mut Self>, f: F2) -> R { let project = self.project(); let val = project.slot.take(); let prev = project.local.inner.with(|c| c.replace(val)); - let _guard = Guard { + let _guard = TaskLocalGuard { prev, slot: project.slot, local: *project.local, }; - f(project.future) + unsafe { + use std::ops::DerefMut; + let fut = project.future.map_unchecked_mut(|f| f.deref_mut()); + f(fut) + } } } diff --git a/tokio/tests/task_local.rs b/tokio/tests/task_local.rs index 811d63ea0f8..706f2b9d974 100644 --- a/tokio/tests/task_local.rs +++ b/tokio/tests/task_local.rs @@ -31,3 +31,34 @@ async fn local() { j2.await.unwrap(); j3.await.unwrap(); } + +tokio::task_local! { + static KEY: u32; +} + +struct Guard(u32); +impl Drop for Guard { + fn drop(&mut self) { + assert_eq!(KEY.try_with(|x| *x), Ok(self.0)); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn task_local_available_on_drop() { + let (tx, rx) = tokio::sync::oneshot::channel(); + + let h = tokio::spawn(KEY.scope(42, async move { + let _g = Guard(42); + let _ = tx.send(()); + std::future::pending::<()>().await; + })); + + rx.await.unwrap(); + + h.abort(); + + let err = h.await.unwrap_err(); + if err.is_panic() { + panic!("{}", err); + } +}