task: wake local tasks to the local queue when woken by the same thread (#5095)

Motivation

Currently, when a task spawned on a `LocalSet` is woken by an I/O driver
or time driver running on the same thread as the `LocalSet`, the task is
pushed to the `LocalSet`'s locked remote run queue rather than to its
unsynchronized local run queue. This is unfortunate, as it negates some
of the performance benefits of having an unsynchronized local run queue:
as it stands, a task is only woken to the local queue when it is woken
by another task running on the same `LocalSet`.
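
For concreteness, here is a minimal sketch of the scenario in question
(the setup is illustrative, assuming tokio's `rt` and `time` features;
it is not part of this change):

    use std::rc::Rc;
    use tokio::task::LocalSet;

    fn main() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_time()
            .build()
            .unwrap();
        let local = LocalSet::new();
        local.block_on(&rt, async {
            // `Rc` makes this future `!Send`, so it must run on the `LocalSet`.
            let not_send = Rc::new(());
            tokio::task::spawn_local(async move {
                // The wake after this sleep comes from the time driver, which
                // runs on the *same* thread as the `LocalSet` on a
                // current-thread runtime -- yet, before this change, the wake
                // went through the locked remote queue.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                drop(not_send);
            })
            .await
            .unwrap();
        });
    }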

This occurs because the local queue is only used when the `CURRENT`
thread-local contains a `Context` that matches the one referenced by the
task's `Schedule` instance (an `Arc<Shared>`). When the `LocalSet` is
not being polled, the thread-local is unset, so the `Schedule`
implementation for `Arc<Shared>` cannot access the local run queue.
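
Concretely, the pre-change wake path looked roughly like this
(simplified from the code removed in the diff below; `cx.queue` is the
local queue that used to live on `Context`):

    fn schedule(&self, task: task::Notified<Arc<Shared>>) {
        CURRENT.with(|maybe_cx| match maybe_cx.get() {
            // The unsynchronized local queue was used only while the
            // `LocalSet` was actively being polled on this thread...
            Some(cx) if cx.shared.ptr_eq(self) => cx.queue.push_back(task),
            // ...and *every* other wake -- even one from a driver on the
            // same thread -- fell through to the mutex-protected remote
            // queue.
            _ => {
                if let Some(queue) = self.queue.lock().as_mut() {
                    queue.push_back(task);
                }
                self.waker.wake();
            }
        });
    }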

Solution

This branch fixes the issue by moving the local run queue into `Shared`,
alongside the remote run queue. When an `Arc<Shared>`'s `Schedule` impl
wakes a task and the context stored in the `CURRENT` thread-local is
`None` (indicating we are not currently polling the `LocalSet` on this
thread), we now check whether the current thread's `ThreadId` matches
that of the thread the `LocalSet` was created on, and push the woken
task to the local queue if it does.

Moving the local run queue into `Shared` is somewhat unfortunate, as it
means `Shared` now has a field that must never be accessed from other
threads, which in turn requires an unsafe `Sync` impl for `Shared`.
However, it's the only viable way to wake to the local queue from the
`Schedule` impl for `Arc<Shared>`, so I figured it was worth the
additional unsafe code. I added a debug assertion checking that the
local queue is only accessed from the thread that owns the `LocalSet`.
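
In outline, the new wake path (implemented in full in the diff below)
makes a three-way decision:

    fn schedule(&self, task: task::Notified<Arc<Shared>>) {
        CURRENT.with(|localdata| match localdata.ctx.get() {
            // 1. This thread is currently polling this `LocalSet`: push
            //    straight to the local queue.
            Some(cx) if cx.shared.ptr_eq(self) => unsafe {
                cx.shared.local_queue().push_back(task);
            },
            // 2. Not polling, but we are on the thread that owns the
            //    `LocalSet` (e.g. a driver wake): the local queue is
            //    still safe to use, but the `LocalSet` must be re-woken
            //    so it notices the new task.
            _ if localdata.get_or_insert_id() == self.owner => {
                unsafe { self.local_queue().push_back(task) };
                self.waker.wake();
            }
            // 3. Some other thread: fall back to the locked remote queue.
            _ => { /* push to `self.queue` under the mutex, then wake */ }
        });
    }
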
hawkw committed Oct 13, 2022
1 parent ca9dd72 commit 23a1ccf
Showing 1 changed file with 172 additions and 22 deletions.
194 changes: 172 additions & 22 deletions tokio/src/task/local.rs
@@ -1,5 +1,6 @@
//! Runs `!Send` futures on the current thread.
use crate::loom::sync::{Arc, Mutex};
use crate::loom::thread::{self, ThreadId};
use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task};
use crate::sync::AtomicWaker;
use crate::util::{RcCell, VecDequeCell};
@@ -228,9 +229,6 @@ struct Context {
/// Collection of all active tasks spawned onto this executor.
owned: LocalOwnedTasks<Arc<Shared>>,

/// Local run queue sender and receiver.
queue: VecDequeCell<task::Notified<Arc<Shared>>>,

/// State shared between threads.
shared: Arc<Shared>,

@@ -241,6 +239,19 @@ struct Context {

/// LocalSet state shared between threads.
struct Shared {
/// Local run queue sender and receiver.
///
/// # Safety
///
/// This field must *only* be accessed from the thread that owns the
/// `LocalSet` (i.e., `Thread::current().id() == owner`).
local_queue: VecDequeCell<task::Notified<Arc<Shared>>>,

/// The `ThreadId` of the thread that owns the `LocalSet`.
///
/// Since `LocalSet` is `!Send`, this will never change.
owner: ThreadId,

/// Remote run queue sender.
queue: Mutex<Option<VecDeque<task::Notified<Arc<Shared>>>>>,

@@ -262,10 +273,21 @@ pin_project! {
}

#[cfg(any(loom, tokio_no_const_thread_local))]
thread_local!(static CURRENT: RcCell<Context> = RcCell::new());
thread_local!(static CURRENT: LocalData = LocalData {
thread_id: Cell::new(None),
ctx: RcCell::new(),
});

#[cfg(not(any(loom, tokio_no_const_thread_local)))]
thread_local!(static CURRENT: RcCell<Context> = const { RcCell::new() });
thread_local!(static CURRENT: LocalData = const { LocalData {
thread_id: Cell::new(None),
ctx: RcCell::new(),
} });

struct LocalData {
thread_id: Cell<Option<ThreadId>>,
ctx: RcCell<Context>,
}

cfg_rt! {
/// Spawns a `!Send` future on the local task set.
@@ -314,7 +336,7 @@ cfg_rt! {
where F: Future + 'static,
F::Output: 'static
{
match CURRENT.with(|maybe_cx| maybe_cx.get()) {
match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) {
None => panic!("`spawn_local` called from outside of a `task::LocalSet`"),
Some(cx) => cx.spawn(future, name)
}
@@ -335,7 +357,7 @@ pub struct LocalEnterGuard(Option<Rc<Context>>);

impl Drop for LocalEnterGuard {
fn drop(&mut self) {
CURRENT.with(|ctx| {
CURRENT.with(|LocalData { ctx, .. }| {
ctx.set(self.0.take());
})
}
@@ -354,8 +376,9 @@ impl LocalSet {
tick: Cell::new(0),
context: Rc::new(Context {
owned: LocalOwnedTasks::new(),
queue: VecDequeCell::with_capacity(INITIAL_CAPACITY),
shared: Arc::new(Shared {
local_queue: VecDequeCell::with_capacity(INITIAL_CAPACITY),
owner: thread::current().id(),
queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))),
waker: AtomicWaker::new(),
#[cfg(tokio_unstable)]
@@ -374,7 +397,7 @@ impl LocalSet {
///
/// [`spawn_local`]: fn@crate::task::spawn_local
pub fn enter(&self) -> LocalEnterGuard {
CURRENT.with(|ctx| {
CURRENT.with(|LocalData { ctx, .. }| {
let old = ctx.replace(Some(self.context.clone()));
LocalEnterGuard(old)
})
@@ -597,9 +620,9 @@ impl LocalSet {
.lock()
.as_mut()
.and_then(|queue| queue.pop_front())
.or_else(|| self.context.queue.pop_front())
.or_else(|| self.pop_local())
} else {
self.context.queue.pop_front().or_else(|| {
self.pop_local().or_else(|| {
self.context
.shared
.queue
Expand All @@ -612,8 +635,17 @@ impl LocalSet {
task.map(|task| self.context.owned.assert_owner(task))
}

fn pop_local(&self) -> Option<task::Notified<Arc<Shared>>> {
unsafe {
// Safety: because the `LocalSet` itself is `!Send`, we know we are
// on the same thread if we have access to the `LocalSet`, and can
// therefore access the local run queue.
self.context.shared.local_queue().pop_front()
}
}

fn with<T>(&self, f: impl FnOnce() -> T) -> T {
CURRENT.with(|ctx| {
CURRENT.with(|LocalData { ctx, .. }| {
struct Reset<'a> {
ctx_ref: &'a RcCell<Context>,
val: Option<Rc<Context>>,
@@ -639,7 +671,7 @@ impl LocalSet {
fn with_if_possible<T>(&self, f: impl FnOnce() -> T) -> T {
let mut f = Some(f);

let res = CURRENT.try_with(|ctx| {
let res = CURRENT.try_with(|LocalData { ctx, .. }| {
struct Reset<'a> {
ctx_ref: &'a RcCell<Context>,
val: Option<Rc<Context>>,
@@ -782,7 +814,21 @@ impl Drop for LocalSet {

// We already called shutdown on all tasks above, so there is no
// need to call shutdown.
for task in self.context.queue.take() {

// Safety: note that this *intentionally* bypasses the unsafe
// `Shared::local_queue()` method. This is in order to avoid the
// debug assertion that we are on the thread that owns the
// `LocalSet`, because on some systems (e.g. at least some macOS
// versions), attempting to get the current thread ID can panic due
// to the thread's local data that stores the thread ID being
// dropped *before* the `LocalSet`.
//
// Despite avoiding the assertion here, it is safe for us to access
// the local queue in `Drop`, because the `LocalSet` itself is
// `!Send`, so we can reasonably guarantee that it will not be
// `Drop`ped from another thread.
let local_queue = self.context.shared.local_queue.take();
for task in local_queue {
drop(task);
}

@@ -854,15 +900,48 @@ impl<T: Future> Future for RunUntil<'_, T> {
}

impl Shared {
/// # Safety
///
/// This is safe to call if and ONLY if we are on the thread that owns this
/// `LocalSet`.
unsafe fn local_queue(&self) -> &VecDequeCell<task::Notified<Arc<Self>>> {
debug_assert!(
// if we couldn't get the thread ID because we're dropping the local
// data, skip the assertion --- the `Drop` impl is not going to be
// called from another thread, because `LocalSet` is `!Send`
thread_id().map(|id| id == self.owner).unwrap_or(true),
"`LocalSet`'s local run queue must not be accessed by another thread!"
);
&self.local_queue
}

/// Schedule the provided task on the scheduler.
fn schedule(&self, task: task::Notified<Arc<Self>>) {
CURRENT.with(|maybe_cx| {
match maybe_cx.get() {
Some(cx) if cx.shared.ptr_eq(self) => {
cx.queue.push_back(task);
CURRENT.with(|localdata| {
match localdata.ctx.get() {
Some(cx) if cx.shared.ptr_eq(self) => unsafe {
// Safety: if the current `LocalSet` context points to this
// `LocalSet`, then we are on the thread that owns it.
cx.shared.local_queue().push_back(task);
},

// We are on the thread that owns the `LocalSet`, so we can
// wake to the local queue.
_ if localdata.get_or_insert_id() == self.owner => {
unsafe {
// Safety: we just checked that the thread ID matches
// the localset's owner, so this is safe.
self.local_queue().push_back(task);
}
// We still have to wake the `LocalSet`, because it isn't
// currently being polled.
self.waker.wake();
}

// We are *not* on the thread that owns the `LocalSet`, so we
// have to wake to the remote queue.
_ => {
// First check whether the queue is still there (if not, the
// First, check whether the queue is still there (if not, the
// LocalSet is dropped). Then push to it if so, and if not,
// do nothing.
let mut lock = self.queue.lock();
@@ -882,9 +961,13 @@ impl Shared {
}
}

// This is safe because (and only because) we *pinky pwomise* to never touch the
// local run queue except from the thread that owns the `LocalSet`.
unsafe impl Sync for Shared {}

impl task::Schedule for Arc<Shared> {
fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
CURRENT.with(|maybe_cx| match maybe_cx.get() {
CURRENT.with(|LocalData { ctx, .. }| match ctx.get() {
None => panic!("scheduler context missing"),
Some(cx) => {
assert!(cx.shared.ptr_eq(self));
@@ -909,7 +992,7 @@ impl task::Schedule for Arc<Shared> {
// This hook is only called from within the runtime, so
// `CURRENT` should match with `&self`, i.e. there is no
// opportunity for a nested scheduler to be called.
CURRENT.with(|maybe_cx| match maybe_cx.get() {
CURRENT.with(|LocalData { ctx, .. }| match ctx.get() {
Some(cx) if Arc::ptr_eq(self, &cx.shared) => {
cx.unhandled_panic.set(true);
cx.owned.close_and_shutdown_all();
@@ -922,9 +1005,31 @@ impl task::Schedule for Arc<Shared> {
}
}

#[cfg(test)]
impl LocalData {
fn get_or_insert_id(&self) -> ThreadId {
self.thread_id.get().unwrap_or_else(|| {
let id = thread::current().id();
self.thread_id.set(Some(id));
id
})
}
}

fn thread_id() -> Option<ThreadId> {
CURRENT
.try_with(|localdata| localdata.get_or_insert_id())
.ok()
}

#[cfg(all(test, not(loom)))]
mod tests {
use super::*;

// Does a `LocalSet` running on a current-thread runtime...basically work?
//
// This duplicates a test in `tests/task_local_set.rs`, but because this is
// a lib test, it will run under Miri, so this is necessary to catch stacked
// borrows violations in the `LocalSet` implementation.
#[test]
fn local_current_thread_scheduler() {
let f = async {
@@ -939,4 +1044,49 @@ mod tests {
.expect("rt")
.block_on(f)
}

// Tests that when a task on a `LocalSet` is woken by an io driver on the
// same thread, the task is woken to the localset's local queue rather than
// its remote queue.
//
// This test has to be defined in the `local.rs` file as a lib test, rather
// than in `tests/`, because it makes assertions about the local set's
// internal state.
#[test]
fn wakes_to_local_queue() {
use super::*;
use crate::sync::Notify;
let rt = crate::runtime::Builder::new_current_thread()
.build()
.expect("rt");
rt.block_on(async {
let local = LocalSet::new();
let notify = Arc::new(Notify::new());
let task = local.spawn_local({
let notify = notify.clone();
async move {
notify.notified().await;
}
});
let mut run_until = Box::pin(local.run_until(async move {
task.await.unwrap();
}));

// poll the run until future once
crate::future::poll_fn(|cx| {
let _ = run_until.as_mut().poll(cx);
Poll::Ready(())
})
.await;

notify.notify_one();
let task = unsafe { local.context.shared.local_queue().pop_front() };
// TODO(eliza): it would be nice to be able to assert that this is
// the local task.
assert!(
task.is_some(),
"task should have been notified to the LocalSet's local queue"
);
})
}
}
