From ea414c3ba451313193d6e219f7ea029d55ae4fd3 Mon Sep 17 00:00:00 2001 From: gftea Date: Fri, 17 Jun 2022 23:08:29 +0200 Subject: [PATCH] Use thread_local to reference LocalSet's context instead of scoped_thread_local (#4764) --- tokio/src/task/local.rs | 73 +++++++++++++++++++++++++++++------------ 1 file changed, 52 insertions(+), 21 deletions(-) diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 2fa01197180..7334fbeda64 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -4,12 +4,13 @@ use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task}; use crate::sync::AtomicWaker; use crate::util::VecDequeCell; -use std::cell::Cell; +use std::cell::{Cell, RefCell}; use std::collections::VecDeque; use std::fmt; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; +use std::rc::Rc; use std::task::Poll; use pin_project_lite::pin_project; @@ -215,7 +216,7 @@ cfg_rt! { tick: Cell, /// State available from thread-local. - context: Context, + context: Rc, /// This type should not be Send. _not_send: PhantomData<*const ()>, @@ -252,7 +253,7 @@ pin_project! { } } -scoped_thread_local!(static CURRENT: Context); +thread_local!(static CURRENT: RefCell>> = RefCell::new(None)); cfg_rt! { /// Spawns a `!Send` future on the local task set. @@ -302,10 +303,11 @@ cfg_rt! { F::Output: 'static { CURRENT.with(|maybe_cx| { - let cx = maybe_cx - .expect("`spawn_local` called from outside of a `task::LocalSet`"); + match maybe_cx.borrow().as_ref() { + None => panic!("`spawn_local` called from outside of a `task::LocalSet`"), + Some(cx) => cx.spawn(future, name) + } - cx.spawn(future, name) }) } } @@ -319,9 +321,17 @@ const MAX_TASKS_PER_TICK: usize = 61; /// How often it check the remote queue first. const REMOTE_FIRST_INTERVAL: u8 = 31; -#[derive(Debug)] -pub struct LocalEnterGuard<'a> { - _guard: &'a LocalSet, +/// Context guard for LocalSet +#[allow(missing_debug_implementations)] +pub struct LocalEnterGuard(Option>); + +impl Drop for LocalEnterGuard { + fn drop(&mut self) { + CURRENT.with(|ctx| { + // *ctx.borrow_mut() = self.0.take(); + ctx.replace(self.0.take()); + }) + } } impl LocalSet { @@ -329,23 +339,23 @@ impl LocalSet { pub fn new() -> LocalSet { LocalSet { tick: Cell::new(0), - context: Context { + context: Rc::new(Context { owned: LocalOwnedTasks::new(), queue: VecDequeCell::with_capacity(INITIAL_CAPACITY), shared: Arc::new(Shared { queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))), waker: AtomicWaker::new(), }), - }, + }), _not_send: PhantomData, } } /// Enter current LocalSet context - pub fn enter(&self) -> LocalEnterGuard<'_> { - CURRENT.inner.with(|c| { - c.set(&self.context as *const _ as *const ()); - LocalEnterGuard { _guard: &self } + pub fn enter(&self) -> LocalEnterGuard { + CURRENT.with(|ctx| { + let old = ctx.borrow_mut().replace(self.context.clone()); + LocalEnterGuard(old) }) } @@ -576,7 +586,26 @@ impl LocalSet { } fn with(&self, f: impl FnOnce() -> T) -> T { - CURRENT.set(&self.context, f) + // CURRENT.set(&self.context, f) + CURRENT.with(|ctx| { + struct Reset<'a> { + ctx_ref: &'a RefCell>>, + val: Option>, + } + impl<'a> Drop for Reset<'a> { + fn drop(&mut self) { + self.ctx_ref.replace(self.val.take()); + } + } + let old = ctx.borrow_mut().replace(self.context.clone()); + + let _reset = Reset { + ctx_ref: ctx, + val: old, + }; + + f() + }) } } @@ -699,7 +728,7 @@ impl Future for RunUntil<'_, T> { impl Shared { /// Schedule the provided task on the scheduler. fn schedule(&self, task: task::Notified>) { - CURRENT.with(|maybe_cx| match maybe_cx { + CURRENT.with(|maybe_cx| match maybe_cx.borrow().as_ref() { Some(cx) if cx.shared.ptr_eq(self) => { cx.queue.push_back(task); } @@ -725,10 +754,12 @@ impl Shared { impl task::Schedule for Arc { fn release(&self, task: &Task) -> Option> { - CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("scheduler context missing"); - assert!(cx.shared.ptr_eq(self)); - cx.owned.remove(task) + CURRENT.with(|maybe_cx| match maybe_cx.borrow().as_ref() { + None => panic!("scheduler context missing"), + Some(cx) => { + assert!(cx.shared.ptr_eq(self)); + cx.owned.remove(task) + } }) }