Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LocalSet::enter #4765

Merged
merged 14 commits into from Jul 13, 2022
127 changes: 98 additions & 29 deletions tokio/src/task/local.rs
Expand Up @@ -10,6 +10,7 @@ use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::task::Poll;

use pin_project_lite::pin_project;
Expand Down Expand Up @@ -215,7 +216,7 @@ cfg_rt! {
tick: Cell<u8>,

/// State available from thread-local.
context: Context,
context: Rc<Context>,

/// This type should not be Send.
_not_send: PhantomData<*const ()>,
Expand Down Expand Up @@ -260,7 +261,7 @@ pin_project! {
}
}

scoped_thread_local!(static CURRENT: Context);
thread_local!(static CURRENT: Cell<Option<Rc<Context>>> = Cell::new(None));
gftea marked this conversation as resolved.
Show resolved Hide resolved

cfg_rt! {
/// Spawns a `!Send` future on the local task set.
Expand Down Expand Up @@ -310,10 +311,12 @@ cfg_rt! {
F::Output: 'static
{
CURRENT.with(|maybe_cx| {
let cx = maybe_cx
.expect("`spawn_local` called from outside of a `task::LocalSet`");
let ctx = clone_rc(maybe_cx);
match ctx {
None => panic!("`spawn_local` called from outside of a `task::LocalSet`"),
Some(cx) => cx.spawn(future, name)
}

cx.spawn(future, name)
})
}
}
Expand All @@ -327,12 +330,29 @@ const MAX_TASKS_PER_TICK: usize = 61;
/// How often it check the remote queue first.
const REMOTE_FIRST_INTERVAL: u8 = 31;

/// Context guard for LocalSet
pub struct LocalEnterGuard(Option<Rc<Context>>);

impl Drop for LocalEnterGuard {
fn drop(&mut self) {
CURRENT.with(|ctx| {
ctx.replace(self.0.take());
})
}
}

impl fmt::Debug for LocalEnterGuard {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("LocalEnterGuard").finish()
}
}

impl LocalSet {
/// Returns a new local task set.
pub fn new() -> LocalSet {
LocalSet {
tick: Cell::new(0),
context: Context {
context: Rc::new(Context {
owned: LocalOwnedTasks::new(),
queue: VecDequeCell::with_capacity(INITIAL_CAPACITY),
shared: Arc::new(Shared {
Expand All @@ -342,11 +362,24 @@ impl LocalSet {
unhandled_panic: crate::runtime::UnhandledPanic::Ignore,
}),
unhandled_panic: Cell::new(false),
},
}),
_not_send: PhantomData,
}
}

/// Enters the context of this `LocalSet`.
///
/// The [`spawn_local`] method will spawn tasks on the `LocalSet` whose
/// context you are inside.
///
/// [`spawn_local`]: fn@crate::task::spawn_local
pub fn enter(&self) -> LocalEnterGuard {
CURRENT.with(|ctx| {
let old = ctx.replace(Some(self.context.clone()));
LocalEnterGuard(old)
})
}

/// Spawns a `!Send` task onto the local task set.
///
/// This task is guaranteed to be run on the current thread.
Expand Down Expand Up @@ -579,7 +612,25 @@ impl LocalSet {
}

fn with<T>(&self, f: impl FnOnce() -> T) -> T {
CURRENT.set(&self.context, f)
CURRENT.with(|ctx| {
struct Reset<'a> {
ctx_ref: &'a Cell<Option<Rc<Context>>>,
val: Option<Rc<Context>>,
}
impl<'a> Drop for Reset<'a> {
fn drop(&mut self) {
self.ctx_ref.replace(self.val.take());
}
}
let old = ctx.replace(Some(self.context.clone()));

let _reset = Reset {
ctx_ref: ctx,
val: old,
};

f()
})
}
}

Expand Down Expand Up @@ -645,8 +696,9 @@ cfg_unstable! {
/// [`JoinHandle`]: struct@crate::task::JoinHandle
pub fn unhandled_panic(&mut self, behavior: crate::runtime::UnhandledPanic) -> &mut Self {
// TODO: This should be set as a builder
Arc::get_mut(&mut self.context.shared)
.expect("TODO: we shouldn't panic")
Rc::get_mut(&mut self.context)
.and_then(|ctx| Arc::get_mut(&mut ctx.shared))
.expect("Unhandled Panic behavior modified after starting LocalSet")
.unhandled_panic = behavior;
self
}
Expand Down Expand Up @@ -769,23 +821,33 @@ impl<T: Future> Future for RunUntil<'_, T> {
}
}

fn clone_rc<T>(rc: &Cell<Option<Rc<T>>>) -> Option<Rc<T>> {
let value = rc.take();
let cloned = value.clone();
rc.set(value);
cloned
}

impl Shared {
/// Schedule the provided task on the scheduler.
fn schedule(&self, task: task::Notified<Arc<Self>>) {
CURRENT.with(|maybe_cx| match maybe_cx {
Some(cx) if cx.shared.ptr_eq(self) => {
cx.queue.push_back(task);
}
_ => {
// First check whether the queue is still there (if not, the
// LocalSet is dropped). Then push to it if so, and if not,
// do nothing.
let mut lock = self.queue.lock();

if let Some(queue) = lock.as_mut() {
queue.push_back(task);
drop(lock);
self.waker.wake();
CURRENT.with(|maybe_cx| {
let ctx = clone_rc(maybe_cx);
match ctx {
Some(cx) if cx.shared.ptr_eq(self) => {
cx.queue.push_back(task);
}
_ => {
// First check whether the queue is still there (if not, the
// LocalSet is dropped). Then push to it if so, and if not,
// do nothing.
let mut lock = self.queue.lock();

if let Some(queue) = lock.as_mut() {
queue.push_back(task);
drop(lock);
self.waker.wake();
}
}
}
});
Expand All @@ -799,9 +861,14 @@ impl Shared {
impl task::Schedule for Arc<Shared> {
fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
CURRENT.with(|maybe_cx| {
let cx = maybe_cx.expect("scheduler context missing");
assert!(cx.shared.ptr_eq(self));
cx.owned.remove(task)
let ctx = clone_rc(maybe_cx);
match ctx {
None => panic!("scheduler context missing"),
Some(cx) => {
assert!(cx.shared.ptr_eq(self));
cx.owned.remove(task)
}
}
})
}

Expand All @@ -821,13 +888,15 @@ impl task::Schedule for Arc<Shared> {
// This hook is only called from within the runtime, so
// `CURRENT` should match with `&self`, i.e. there is no
// opportunity for a nested scheduler to be called.
CURRENT.with(|maybe_cx| match maybe_cx {
CURRENT.with(|maybe_cx| {
let ctx = clone_rc(maybe_cx);
match ctx {
Some(cx) if Arc::ptr_eq(self, &cx.shared) => {
cx.unhandled_panic.set(true);
cx.owned.close_and_shutdown_all();
}
_ => unreachable!("runtime core not set in CURRENT thread-local"),
})
}})
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion tokio/src/task/mod.rs
Expand Up @@ -297,7 +297,7 @@ cfg_rt! {
}

mod local;
pub use local::{spawn_local, LocalSet};
pub use local::{spawn_local, LocalSet, LocalEnterGuard};

mod task_local;
pub use task_local::LocalKey;
Expand Down
15 changes: 15 additions & 0 deletions tokio/tests/task_local_set.rs
Expand Up @@ -126,6 +126,21 @@ async fn local_threadpool_timer() {
})
.await;
}
#[test]
fn enter_guard_spawn() {
let local = LocalSet::new();
let _guard = local.enter();
// Run the local task set.

let join = task::spawn_local(async { true });
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
local.block_on(&rt, async move {
assert!(join.await.unwrap());
});
}

#[test]
// This will panic, since the thread that calls `block_on` cannot use
Expand Down