Skip to content

Commit

Permalink
task: add LocalSet::enter (#4736) (#4765)
Browse files Browse the repository at this point in the history
  • Loading branch information
gftea committed Jul 13, 2022
1 parent 8e20cfb commit 14fca34
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 30 deletions.
127 changes: 98 additions & 29 deletions tokio/src/task/local.rs
Expand Up @@ -10,6 +10,7 @@ use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::task::Poll;

use pin_project_lite::pin_project;
Expand Down Expand Up @@ -215,7 +216,7 @@ cfg_rt! {
tick: Cell<u8>,

/// State available from thread-local.
context: Context,
context: Rc<Context>,

/// This type should not be Send.
_not_send: PhantomData<*const ()>,
Expand Down Expand Up @@ -260,7 +261,7 @@ pin_project! {
}
}

scoped_thread_local!(static CURRENT: Context);
thread_local!(static CURRENT: Cell<Option<Rc<Context>>> = Cell::new(None));

cfg_rt! {
/// Spawns a `!Send` future on the local task set.
Expand Down Expand Up @@ -310,10 +311,12 @@ cfg_rt! {
F::Output: 'static
{
CURRENT.with(|maybe_cx| {
let cx = maybe_cx
.expect("`spawn_local` called from outside of a `task::LocalSet`");
let ctx = clone_rc(maybe_cx);
match ctx {
None => panic!("`spawn_local` called from outside of a `task::LocalSet`"),
Some(cx) => cx.spawn(future, name)
}

cx.spawn(future, name)
})
}
}
Expand All @@ -327,12 +330,29 @@ const MAX_TASKS_PER_TICK: usize = 61;
/// How often it check the remote queue first.
const REMOTE_FIRST_INTERVAL: u8 = 31;

/// Context guard for LocalSet
pub struct LocalEnterGuard(Option<Rc<Context>>);

impl Drop for LocalEnterGuard {
fn drop(&mut self) {
CURRENT.with(|ctx| {
ctx.replace(self.0.take());
})
}
}

impl fmt::Debug for LocalEnterGuard {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("LocalEnterGuard").finish()
}
}

impl LocalSet {
/// Returns a new local task set.
pub fn new() -> LocalSet {
LocalSet {
tick: Cell::new(0),
context: Context {
context: Rc::new(Context {
owned: LocalOwnedTasks::new(),
queue: VecDequeCell::with_capacity(INITIAL_CAPACITY),
shared: Arc::new(Shared {
Expand All @@ -342,11 +362,24 @@ impl LocalSet {
unhandled_panic: crate::runtime::UnhandledPanic::Ignore,
}),
unhandled_panic: Cell::new(false),
},
}),
_not_send: PhantomData,
}
}

/// Enters the context of this `LocalSet`.
///
/// The [`spawn_local`] method will spawn tasks on the `LocalSet` whose
/// context you are inside.
///
/// [`spawn_local`]: fn@crate::task::spawn_local
pub fn enter(&self) -> LocalEnterGuard {
CURRENT.with(|ctx| {
let old = ctx.replace(Some(self.context.clone()));
LocalEnterGuard(old)
})
}

/// Spawns a `!Send` task onto the local task set.
///
/// This task is guaranteed to be run on the current thread.
Expand Down Expand Up @@ -579,7 +612,25 @@ impl LocalSet {
}

fn with<T>(&self, f: impl FnOnce() -> T) -> T {
CURRENT.set(&self.context, f)
CURRENT.with(|ctx| {
struct Reset<'a> {
ctx_ref: &'a Cell<Option<Rc<Context>>>,
val: Option<Rc<Context>>,
}
impl<'a> Drop for Reset<'a> {
fn drop(&mut self) {
self.ctx_ref.replace(self.val.take());
}
}
let old = ctx.replace(Some(self.context.clone()));

let _reset = Reset {
ctx_ref: ctx,
val: old,
};

f()
})
}
}

Expand Down Expand Up @@ -645,8 +696,9 @@ cfg_unstable! {
/// [`JoinHandle`]: struct@crate::task::JoinHandle
pub fn unhandled_panic(&mut self, behavior: crate::runtime::UnhandledPanic) -> &mut Self {
// TODO: This should be set as a builder
Arc::get_mut(&mut self.context.shared)
.expect("TODO: we shouldn't panic")
Rc::get_mut(&mut self.context)
.and_then(|ctx| Arc::get_mut(&mut ctx.shared))
.expect("Unhandled Panic behavior modified after starting LocalSet")
.unhandled_panic = behavior;
self
}
Expand Down Expand Up @@ -769,23 +821,33 @@ impl<T: Future> Future for RunUntil<'_, T> {
}
}

fn clone_rc<T>(rc: &Cell<Option<Rc<T>>>) -> Option<Rc<T>> {
let value = rc.take();
let cloned = value.clone();
rc.set(value);
cloned
}

impl Shared {
/// Schedule the provided task on the scheduler.
fn schedule(&self, task: task::Notified<Arc<Self>>) {
CURRENT.with(|maybe_cx| match maybe_cx {
Some(cx) if cx.shared.ptr_eq(self) => {
cx.queue.push_back(task);
}
_ => {
// First check whether the queue is still there (if not, the
// LocalSet is dropped). Then push to it if so, and if not,
// do nothing.
let mut lock = self.queue.lock();

if let Some(queue) = lock.as_mut() {
queue.push_back(task);
drop(lock);
self.waker.wake();
CURRENT.with(|maybe_cx| {
let ctx = clone_rc(maybe_cx);
match ctx {
Some(cx) if cx.shared.ptr_eq(self) => {
cx.queue.push_back(task);
}
_ => {
// First check whether the queue is still there (if not, the
// LocalSet is dropped). Then push to it if so, and if not,
// do nothing.
let mut lock = self.queue.lock();

if let Some(queue) = lock.as_mut() {
queue.push_back(task);
drop(lock);
self.waker.wake();
}
}
}
});
Expand All @@ -799,9 +861,14 @@ impl Shared {
impl task::Schedule for Arc<Shared> {
fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
CURRENT.with(|maybe_cx| {
let cx = maybe_cx.expect("scheduler context missing");
assert!(cx.shared.ptr_eq(self));
cx.owned.remove(task)
let ctx = clone_rc(maybe_cx);
match ctx {
None => panic!("scheduler context missing"),
Some(cx) => {
assert!(cx.shared.ptr_eq(self));
cx.owned.remove(task)
}
}
})
}

Expand All @@ -821,13 +888,15 @@ impl task::Schedule for Arc<Shared> {
// This hook is only called from within the runtime, so
// `CURRENT` should match with `&self`, i.e. there is no
// opportunity for a nested scheduler to be called.
CURRENT.with(|maybe_cx| match maybe_cx {
CURRENT.with(|maybe_cx| {
let ctx = clone_rc(maybe_cx);
match ctx {
Some(cx) if Arc::ptr_eq(self, &cx.shared) => {
cx.unhandled_panic.set(true);
cx.owned.close_and_shutdown_all();
}
_ => unreachable!("runtime core not set in CURRENT thread-local"),
})
}})
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion tokio/src/task/mod.rs
Expand Up @@ -299,7 +299,7 @@ cfg_rt! {
}

mod local;
pub use local::{spawn_local, LocalSet};
pub use local::{spawn_local, LocalSet, LocalEnterGuard};

mod task_local;
pub use task_local::LocalKey;
Expand Down
15 changes: 15 additions & 0 deletions tokio/tests/task_local_set.rs
Expand Up @@ -135,6 +135,21 @@ async fn local_threadpool_timer() {
})
.await;
}
#[test]
fn enter_guard_spawn() {
let local = LocalSet::new();
let _guard = local.enter();
// Run the local task set.

let join = task::spawn_local(async { true });
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
local.block_on(&rt, async move {
assert!(join.await.unwrap());
});
}

#[cfg(not(target_os = "wasi"))] // Wasi doesn't support panic recovery
#[test]
Expand Down

0 comments on commit 14fca34

Please sign in to comment.