Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LocalSet::enter #4765

Merged
merged 14 commits into from Jul 13, 2022
72 changes: 58 additions & 14 deletions tokio/src/task/local.rs
Expand Up @@ -4,12 +4,13 @@ use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task};
use crate::sync::AtomicWaker;
use crate::util::VecDequeCell;

use std::cell::Cell;
use std::cell::{Cell, RefCell};
use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::task::Poll;

use pin_project_lite::pin_project;
Expand Down Expand Up @@ -215,7 +216,7 @@ cfg_rt! {
tick: Cell<u8>,

/// State available from thread-local.
context: Context,
context: Rc<Context>,

/// This type should not be Send.
_not_send: PhantomData<*const ()>,
Expand Down Expand Up @@ -252,7 +253,7 @@ pin_project! {
}
}

scoped_thread_local!(static CURRENT: Context);
thread_local!(static CURRENT: RefCell<Option<Rc<Context>>> = RefCell::new(None));

cfg_rt! {
/// Spawns a `!Send` future on the local task set.
Expand Down Expand Up @@ -302,10 +303,11 @@ cfg_rt! {
F::Output: 'static
{
CURRENT.with(|maybe_cx| {
let cx = maybe_cx
.expect("`spawn_local` called from outside of a `task::LocalSet`");
match maybe_cx.borrow().as_ref() {
None => panic!("`spawn_local` called from outside of a `task::LocalSet`"),
Some(cx) => cx.spawn(future, name)
}

cx.spawn(future, name)
})
}
}
Expand All @@ -319,23 +321,44 @@ const MAX_TASKS_PER_TICK: usize = 61;
/// How often it check the remote queue first.
const REMOTE_FIRST_INTERVAL: u8 = 31;

/// Context guard for LocalSet
#[allow(missing_debug_implementations)]
pub struct LocalEnterGuard(Option<Rc<Context>>);
gftea marked this conversation as resolved.
Show resolved Hide resolved

impl Drop for LocalEnterGuard {
fn drop(&mut self) {
CURRENT.with(|ctx| {
// *ctx.borrow_mut() = self.0.take();
ctx.replace(self.0.take());
})
}
}

impl LocalSet {
/// Returns a new local task set.
pub fn new() -> LocalSet {
LocalSet {
tick: Cell::new(0),
context: Context {
context: Rc::new(Context {
owned: LocalOwnedTasks::new(),
queue: VecDequeCell::with_capacity(INITIAL_CAPACITY),
shared: Arc::new(Shared {
queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))),
waker: AtomicWaker::new(),
}),
},
}),
_not_send: PhantomData,
}
}

/// Enter current LocalSet context
pub fn enter(&self) -> LocalEnterGuard {
gftea marked this conversation as resolved.
Show resolved Hide resolved
CURRENT.with(|ctx| {
let old = ctx.borrow_mut().replace(self.context.clone());
LocalEnterGuard(old)
})
}

/// Spawns a `!Send` task onto the local task set.
///
/// This task is guaranteed to be run on the current thread.
Expand Down Expand Up @@ -563,7 +586,26 @@ impl LocalSet {
}

fn with<T>(&self, f: impl FnOnce() -> T) -> T {
CURRENT.set(&self.context, f)
// CURRENT.set(&self.context, f)
CURRENT.with(|ctx| {
struct Reset<'a> {
ctx_ref: &'a RefCell<Option<Rc<Context>>>,
val: Option<Rc<Context>>,
}
impl<'a> Drop for Reset<'a> {
fn drop(&mut self) {
self.ctx_ref.replace(self.val.take());
}
}
let old = ctx.borrow_mut().replace(self.context.clone());

let _reset = Reset {
ctx_ref: ctx,
val: old,
};

f()
})
}
}

Expand Down Expand Up @@ -686,7 +728,7 @@ impl<T: Future> Future for RunUntil<'_, T> {
impl Shared {
/// Schedule the provided task on the scheduler.
fn schedule(&self, task: task::Notified<Arc<Self>>) {
CURRENT.with(|maybe_cx| match maybe_cx {
CURRENT.with(|maybe_cx| match maybe_cx.borrow().as_ref() {
Some(cx) if cx.shared.ptr_eq(self) => {
cx.queue.push_back(task);
gftea marked this conversation as resolved.
Show resolved Hide resolved
}
Expand All @@ -712,10 +754,12 @@ impl Shared {

impl task::Schedule for Arc<Shared> {
fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
CURRENT.with(|maybe_cx| {
let cx = maybe_cx.expect("scheduler context missing");
assert!(cx.shared.ptr_eq(self));
cx.owned.remove(task)
CURRENT.with(|maybe_cx| match maybe_cx.borrow().as_ref() {
None => panic!("scheduler context missing"),
Some(cx) => {
assert!(cx.shared.ptr_eq(self));
cx.owned.remove(task)
}
})
}

Expand Down
2 changes: 1 addition & 1 deletion tokio/src/task/mod.rs
Expand Up @@ -297,7 +297,7 @@ cfg_rt! {
}

mod local;
pub use local::{spawn_local, LocalSet};
pub use local::{spawn_local, LocalSet, LocalEnterGuard};

mod task_local;
pub use task_local::LocalKey;
Expand Down
15 changes: 15 additions & 0 deletions tokio/tests/task_local_set.rs
Expand Up @@ -126,6 +126,21 @@ async fn local_threadpool_timer() {
})
.await;
}
#[test]
fn enter_guard_spawn() {
let local = LocalSet::new();
let _guard = local.enter();
// Run the local task set.

let join = task::spawn_local(async { true });
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
local.block_on(&rt, async move {
assert!(join.await.unwrap());
});
}

#[test]
// This will panic, since the thread that calls `block_on` cannot use
Expand Down