From b7812c85ca2d051d47cec023b880cbf8cdcbc313 Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Thu, 10 Nov 2022 10:06:25 -0800 Subject: [PATCH] rt: fix `LocalSet` drop in thread local (#5179) `LocalSet` cleans up any tasks that have not yet been completed when it is dropped. Previously, this cleanup process required access to a thread-local. Suppose a `LocalSet` is stored in a thread-local itself. In that case, when it is dropped, there is no guarantee the drop implementation will be able to access the internal `LocalSet` thread-local as it may already have been destroyed. The internal `LocalSet` thread local is mainly used to avoid writing unsafe code. All `LocalState` that cannot be moved across threads is stored in the thread-local and accessed on demand. This patch moves this local-only state into the `LocalSet`'s "shared" struct. Because this struct *is* `Send`, the local-only state is stored in `UnsafeCell`, and callers must ensure not to touch it from other threads. A debug assertion is added to enforce this requirement in tests. Fixes #5162 --- tokio/src/runtime/context.rs | 6 +- tokio/src/runtime/coop.rs | 38 +++--- tokio/src/task/local.rs | 191 +++++++++++++++++++++++-------- tokio/src/util/mod.rs | 3 - tokio/src/util/vec_deque_cell.rs | 53 --------- tokio/tests/task_local_set.rs | 42 +++++++ 6 files changed, 208 insertions(+), 125 deletions(-) delete mode 100644 tokio/src/util/vec_deque_cell.rs diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index 6de657481e0..60bff239dde 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -1,3 +1,4 @@ +use crate::loom::thread::AccessError; use crate::runtime::coop; use std::cell::Cell; @@ -63,12 +64,11 @@ pub(crate) fn thread_rng_n(n: u32) -> u32 { CONTEXT.with(|ctx| ctx.rng.fastrand_n(n)) } -pub(super) fn budget(f: impl FnOnce(&Cell) -> R) -> R { - CONTEXT.with(|ctx| f(&ctx.budget)) +pub(super) fn budget(f: impl FnOnce(&Cell) -> R) -> Result { + CONTEXT.try_with(|ctx| f(&ctx.budget)) } cfg_rt! { - use crate::loom::thread::AccessError; use crate::runtime::TryCurrentError; use std::fmt; diff --git a/tokio/src/runtime/coop.rs b/tokio/src/runtime/coop.rs index d174abb4ab1..0ba137ab67a 100644 --- a/tokio/src/runtime/coop.rs +++ b/tokio/src/runtime/coop.rs @@ -31,8 +31,6 @@ use crate::runtime::context; -use std::cell::Cell; - /// Opaque type tracking the amount of "work" a task may still do before /// yielding back to the scheduler. #[derive(Debug, Copy, Clone)] @@ -79,37 +77,42 @@ pub(crate) fn with_unconstrained(f: impl FnOnce() -> R) -> R { #[inline(always)] fn with_budget(budget: Budget, f: impl FnOnce() -> R) -> R { - struct ResetGuard<'a> { - cell: &'a Cell, + struct ResetGuard { prev: Budget, } - impl<'a> Drop for ResetGuard<'a> { + impl Drop for ResetGuard { fn drop(&mut self) { - self.cell.set(self.prev); + let _ = context::budget(|cell| { + cell.set(self.prev); + }); } } - context::budget(|cell| { + #[allow(unused_variables)] + let maybe_guard = context::budget(|cell| { let prev = cell.get(); - cell.set(budget); - let _guard = ResetGuard { cell, prev }; + ResetGuard { prev } + }); - f() - }) + // The function is called regardless even if the budget is not successfully + // set due to the thread-local being destroyed. + f() } #[inline(always)] pub(crate) fn has_budget_remaining() -> bool { - context::budget(|cell| cell.get().has_remaining()) + // If the current budget cannot be accessed due to the thread-local being + // shutdown, then we assume there is budget remaining. 
+ context::budget(|cell| cell.get().has_remaining()).unwrap_or(true) } cfg_rt_multi_thread! { /// Sets the current task's budget. pub(crate) fn set(budget: Budget) { - context::budget(|cell| cell.set(budget)) + let _ = context::budget(|cell| cell.set(budget)); } } @@ -122,11 +125,12 @@ cfg_rt! { let prev = cell.get(); cell.set(Budget::unconstrained()); prev - }) + }).unwrap_or(Budget::unconstrained()) } } cfg_coop! { + use std::cell::Cell; use std::task::{Context, Poll}; #[must_use] @@ -144,7 +148,7 @@ cfg_coop! { // They are both represented as the remembered budget being unconstrained. let budget = self.0.get(); if !budget.is_unconstrained() { - context::budget(|cell| { + let _ = context::budget(|cell| { cell.set(budget); }); } @@ -176,7 +180,7 @@ cfg_coop! { cx.waker().wake_by_ref(); Poll::Pending } - }) + }).unwrap_or(Poll::Ready(RestoreOnPending(Cell::new(Budget::unconstrained())))) } impl Budget { @@ -209,7 +213,7 @@ mod test { use wasm_bindgen_test::wasm_bindgen_test as test; fn get() -> Budget { - context::budget(|cell| cell.get()) + context::budget(|cell| cell.get()).unwrap_or(Budget::unconstrained()) } #[test] diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 38ed22b4f6b..e4a198bd053 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -1,9 +1,10 @@ //! Runs `!Send` futures on the current thread. +use crate::loom::cell::UnsafeCell; use crate::loom::sync::{Arc, Mutex}; use crate::loom::thread::{self, ThreadId}; use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task}; use crate::sync::AtomicWaker; -use crate::util::{RcCell, VecDequeCell}; +use crate::util::RcCell; use std::cell::Cell; use std::collections::VecDeque; @@ -226,9 +227,6 @@ cfg_rt! { /// State available from the thread-local. struct Context { - /// Collection of all active tasks spawned onto this executor. - owned: LocalOwnedTasks>, - /// State shared between threads. shared: Arc, @@ -239,18 +237,11 @@ struct Context { /// LocalSet state shared between threads. struct Shared { - /// Local run queue sender and receiver. - /// /// # Safety /// /// This field must *only* be accessed from the thread that owns the /// `LocalSet` (i.e., `Thread::current().id() == owner`). - local_queue: VecDequeCell>>, - - /// The `ThreadId` of the thread that owns the `LocalSet`. - /// - /// Since `LocalSet` is `!Send`, this will never change. - owner: ThreadId, + local_state: LocalState, /// Remote run queue sender. queue: Mutex>>>>, @@ -263,6 +254,19 @@ struct Shared { pub(crate) unhandled_panic: crate::runtime::UnhandledPanic, } +/// Tracks the `LocalSet` state that must only be accessed from the thread that +/// created the `LocalSet`. +struct LocalState { + /// The `ThreadId` of the thread that owns the `LocalSet`. + owner: ThreadId, + + /// Local run queue sender and receiver. + local_queue: UnsafeCell>>>, + + /// Collection of all active tasks spawned onto this executor. + owned: LocalOwnedTasks>, +} + pin_project! 
{ #[derive(Debug)] struct RunUntil<'a, F> { @@ -378,10 +382,12 @@ impl LocalSet { LocalSet { tick: Cell::new(0), context: Rc::new(Context { - owned: LocalOwnedTasks::new(), shared: Arc::new(Shared { - local_queue: VecDequeCell::with_capacity(INITIAL_CAPACITY), - owner: thread::current().id(), + local_state: LocalState { + owner: thread_id().expect("cannot create LocalSet during thread shutdown"), + owned: LocalOwnedTasks::new(), + local_queue: UnsafeCell::new(VecDeque::with_capacity(INITIAL_CAPACITY)), + }, queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))), waker: AtomicWaker::new(), #[cfg(tokio_unstable)] @@ -641,7 +647,12 @@ impl LocalSet { }) }; - task.map(|task| self.context.owned.assert_owner(task)) + task.map(|task| unsafe { + // Safety: because the `LocalSet` itself is `!Send`, we know we are + // on the same thread if we have access to the `LocalSet`, and can + // therefore access the local run queue. + self.context.shared.local_state.assert_owner(task) + }) } fn pop_local(&self) -> Option>> { @@ -649,7 +660,7 @@ impl LocalSet { // Safety: because the `LocalSet` itself is `!Send`, we know we are // on the same thread if we have access to the `LocalSet`, and can // therefore access the local run queue. - self.context.shared.local_queue().pop_front() + self.context.shared.local_state.task_pop_front() } } @@ -796,7 +807,10 @@ impl Future for LocalSet { // there are still tasks remaining in the run queue. cx.waker().wake_by_ref(); Poll::Pending - } else if self.context.owned.is_empty() { + + // Safety: called from the thread that owns `LocalSet`. Because + // `LocalSet` is `!Send`, this is safe. + } else if unsafe { self.context.shared.local_state.owned_is_empty() } { // If the scheduler has no remaining futures, we're done! Poll::Ready(()) } else { @@ -819,7 +833,10 @@ impl Drop for LocalSet { self.with_if_possible(|| { // Shut down all tasks in the LocalOwnedTasks and close it to // prevent new tasks from ever being added. - self.context.owned.close_and_shutdown_all(); + unsafe { + // Safety: called from the thread that owns `LocalSet` + self.context.shared.local_state.close_and_shutdown_all(); + } // We already called shutdown on all tasks above, so there is no // need to call shutdown. @@ -836,7 +853,10 @@ impl Drop for LocalSet { // the local queue in `Drop`, because the `LocalSet` itself is // `!Send`, so we can reasonably guarantee that it will not be // `Drop`ped from another thread. 
- let local_queue = self.context.shared.local_queue.take(); + let local_queue = unsafe { + // Safety: called from the thread that owns `LocalSet` + self.context.shared.local_state.take_local_queue() + }; for task in local_queue { drop(task); } @@ -848,7 +868,8 @@ impl Drop for LocalSet { drop(task); } - assert!(self.context.owned.is_empty()); + // Safety: called from the thread that owns `LocalSet` + assert!(unsafe { self.context.shared.local_state.owned_is_empty() }); }); } } @@ -865,7 +886,14 @@ impl Context { let id = crate::runtime::task::Id::next(); let future = crate::util::trace::task(future, "local", name, id.as_u64()); - let (handle, notified) = self.owned.bind(future, self.shared.clone(), id); + // Safety: called from the thread that owns the `LocalSet` + let (handle, notified) = { + self.shared.local_state.assert_called_from_owner_thread(); + self.shared + .local_state + .owned + .bind(future, self.shared.clone(), id) + }; if let Some(notified) = notified { self.shared.schedule(notified); @@ -909,21 +937,6 @@ impl Future for RunUntil<'_, T> { } impl Shared { - /// # Safety - /// - /// This is safe to call if and ONLY if we are on the thread that owns this - /// `LocalSet`. - unsafe fn local_queue(&self) -> &VecDequeCell>> { - debug_assert!( - // if we couldn't get the thread ID because we're dropping the local - // data, skip the assertion --- the `Drop` impl is not going to be - // called from another thread, because `LocalSet` is `!Send` - thread_id().map(|id| id == self.owner).unwrap_or(true), - "`LocalSet`'s local run queue must not be accessed by another thread!" - ); - &self.local_queue - } - /// Schedule the provided task on the scheduler. fn schedule(&self, task: task::Notified>) { CURRENT.with(|localdata| { @@ -931,16 +944,16 @@ impl Shared { Some(cx) if cx.shared.ptr_eq(self) => unsafe { // Safety: if the current `LocalSet` context points to this // `LocalSet`, then we are on the thread that owns it. - cx.shared.local_queue().push_back(task); + cx.shared.local_state.task_push_back(task); }, // We are on the thread that owns the `LocalSet`, so we can // wake to the local queue. - _ if localdata.get_or_insert_id() == self.owner => { + _ if localdata.get_id() == Some(self.local_state.owner) => { unsafe { // Safety: we just checked that the thread ID matches // the localset's owner, so this is safe. - self.local_queue().push_back(task); + self.local_state.task_push_back(task); } // We still have to wake the `LocalSet`, because it isn't // currently being polled. @@ -976,13 +989,8 @@ unsafe impl Sync for Shared {} impl task::Schedule for Arc { fn release(&self, task: &Task) -> Option> { - CURRENT.with(|LocalData { ctx, .. }| match ctx.get() { - None => panic!("scheduler context missing"), - Some(cx) => { - assert!(cx.shared.ptr_eq(self)); - cx.owned.remove(task) - } - }) + // Safety, this is always called from the thread that owns `LocalSet` + unsafe { self.local_state.task_remove(task) } } fn schedule(&self, task: task::Notified) { @@ -1004,7 +1012,8 @@ impl task::Schedule for Arc { CURRENT.with(|LocalData { ctx, .. 
}| match ctx.get() { Some(cx) if Arc::ptr_eq(self, &cx.shared) => { cx.unhandled_panic.set(true); - cx.owned.close_and_shutdown_all(); + // Safety: this is always called from the thread that owns `LocalSet` + unsafe { cx.shared.local_state.close_and_shutdown_all(); } } _ => unreachable!("runtime core not set in CURRENT thread-local"), }) @@ -1014,7 +1023,91 @@ impl task::Schedule for Arc { } } +impl LocalState { + unsafe fn task_pop_front(&self) -> Option>> { + // The caller ensures it is called from the same thread that owns + // the LocalSet. + self.assert_called_from_owner_thread(); + + self.local_queue.with_mut(|ptr| (*ptr).pop_front()) + } + + unsafe fn task_push_back(&self, task: task::Notified>) { + // The caller ensures it is called from the same thread that owns + // the LocalSet. + self.assert_called_from_owner_thread(); + + self.local_queue.with_mut(|ptr| (*ptr).push_back(task)) + } + + unsafe fn take_local_queue(&self) -> VecDeque>> { + // The caller ensures it is called from the same thread that owns + // the LocalSet. + self.assert_called_from_owner_thread(); + + self.local_queue.with_mut(|ptr| std::mem::take(&mut (*ptr))) + } + + unsafe fn task_remove(&self, task: &Task>) -> Option>> { + // The caller ensures it is called from the same thread that owns + // the LocalSet. + self.assert_called_from_owner_thread(); + + self.owned.remove(task) + } + + /// Returns true if the `LocalSet` does not have any spawned tasks + unsafe fn owned_is_empty(&self) -> bool { + // The caller ensures it is called from the same thread that owns + // the LocalSet. + self.assert_called_from_owner_thread(); + + self.owned.is_empty() + } + + unsafe fn assert_owner( + &self, + task: task::Notified>, + ) -> task::LocalNotified> { + // The caller ensures it is called from the same thread that owns + // the LocalSet. + self.assert_called_from_owner_thread(); + + self.owned.assert_owner(task) + } + + unsafe fn close_and_shutdown_all(&self) { + // The caller ensures it is called from the same thread that owns + // the LocalSet. + self.assert_called_from_owner_thread(); + + self.owned.close_and_shutdown_all() + } + + #[track_caller] + fn assert_called_from_owner_thread(&self) { + // FreeBSD has some weirdness around thread-local destruction. + // TODO: remove this hack when thread id is cleaned up + #[cfg(not(any(target_os = "openbsd", target_os = "freebsd")))] + debug_assert!( + // if we couldn't get the thread ID because we're dropping the local + // data, skip the assertion --- the `Drop` impl is not going to be + // called from another thread, because `LocalSet` is `!Send` + thread_id().map(|id| id == self.owner).unwrap_or(true), + "`LocalSet`'s local run queue must not be accessed by another thread!" + ); + } +} + +// This is `Send` because it is stored in `Shared`. It is up to the caller to +// ensure they are on the same thread that owns the `LocalSet`. +unsafe impl Send for LocalState {} + impl LocalData { + fn get_id(&self) -> Option { + self.thread_id.get() + } + fn get_or_insert_id(&self) -> ThreadId { self.thread_id.get().unwrap_or_else(|| { let id = thread::current().id(); @@ -1089,7 +1182,7 @@ mod tests { .await; notify.notify_one(); - let task = unsafe { local.context.shared.local_queue().pop_front() }; + let task = unsafe { local.context.shared.local_state.task_pop_front() }; // TODO(eliza): it would be nice to be able to assert that this is // the local task. 
assert!( diff --git a/tokio/src/util/mod.rs b/tokio/src/util/mod.rs index 3948ed84a0c..245e64de6b4 100644 --- a/tokio/src/util/mod.rs +++ b/tokio/src/util/mod.rs @@ -59,9 +59,6 @@ cfg_rt! { mod sync_wrapper; pub(crate) use sync_wrapper::SyncWrapper; - mod vec_deque_cell; - pub(crate) use vec_deque_cell::VecDequeCell; - mod rc_cell; pub(crate) use rc_cell::RcCell; } diff --git a/tokio/src/util/vec_deque_cell.rs b/tokio/src/util/vec_deque_cell.rs deleted file mode 100644 index b4e124c1519..00000000000 --- a/tokio/src/util/vec_deque_cell.rs +++ /dev/null @@ -1,53 +0,0 @@ -use crate::loom::cell::UnsafeCell; - -use std::collections::VecDeque; -use std::marker::PhantomData; - -/// This type is like VecDeque, except that it is not Sync and can be modified -/// through immutable references. -pub(crate) struct VecDequeCell { - inner: UnsafeCell>, - _not_sync: PhantomData<*const ()>, -} - -// This is Send for the same reasons that RefCell> is Send. -unsafe impl Send for VecDequeCell {} - -impl VecDequeCell { - pub(crate) fn with_capacity(cap: usize) -> Self { - Self { - inner: UnsafeCell::new(VecDeque::with_capacity(cap)), - _not_sync: PhantomData, - } - } - - /// Safety: This method may not be called recursively. - #[inline] - unsafe fn with_inner(&self, f: F) -> R - where - F: FnOnce(&mut VecDeque) -> R, - { - // safety: This type is not Sync, so concurrent calls of this method - // cannot happen. Furthermore, the caller guarantees that the method is - // not called recursively. Finally, this is the only place that can - // create mutable references to the inner VecDeque. This ensures that - // any mutable references created here are exclusive. - self.inner.with_mut(|ptr| f(&mut *ptr)) - } - - pub(crate) fn pop_front(&self) -> Option { - unsafe { self.with_inner(VecDeque::pop_front) } - } - - pub(crate) fn push_back(&self, item: T) { - unsafe { - self.with_inner(|inner| inner.push_back(item)); - } - } - - /// Replaces the inner VecDeque with an empty VecDeque and return the current - /// contents. - pub(crate) fn take(&self) -> VecDeque { - unsafe { self.with_inner(|inner| std::mem::take(inner)) } - } -} diff --git a/tokio/tests/task_local_set.rs b/tokio/tests/task_local_set.rs index 271afb8f5cf..1d3a8153381 100644 --- a/tokio/tests/task_local_set.rs +++ b/tokio/tests/task_local_set.rs @@ -566,6 +566,48 @@ async fn spawn_wakes_localset() { } } +#[test] +fn store_local_set_in_thread_local_with_runtime() { + use tokio::runtime::Runtime; + + thread_local! { + static CURRENT: RtAndLocalSet = RtAndLocalSet::new(); + } + + struct RtAndLocalSet { + rt: Runtime, + local: LocalSet, + } + + impl RtAndLocalSet { + fn new() -> RtAndLocalSet { + RtAndLocalSet { + rt: tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(), + local: LocalSet::new(), + } + } + + async fn inner_method(&self) { + self.local + .run_until(async move { + tokio::task::spawn_local(async {}); + }) + .await + } + + fn method(&self) { + self.rt.block_on(self.inner_method()); + } + } + + CURRENT.with(|f| { + f.method(); + }); +} + #[cfg(tokio_unstable)] mod unstable { use tokio::runtime::UnhandledPanic;
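To illustrate the core technique the patch describes (thread-confined state kept inside a `Send` struct behind an `UnsafeCell`, with a `ThreadId` debug assertion guarding every access), the following is a minimal standalone Rust sketch appended after the patch excerpt. It is not part of the diff above; the `ThreadConfined` type, its `with_mut` accessor, and the `main` usage are hypothetical names invented for the example, loosely mirroring what `LocalState` does in `tokio/src/task/local.rs`.

use std::cell::UnsafeCell;
use std::collections::VecDeque;
use std::thread::{self, ThreadId};

/// State that must only be touched from the thread that created it, even
/// though the struct holding it is shared, so that a `Drop` impl can reach it
/// without going through a thread-local (the situation the patch fixes for
/// `LocalSet`).
struct ThreadConfined<T> {
    /// The id of the owning thread, recorded at construction.
    owner: ThreadId,
    /// The confined value. `UnsafeCell` makes the wrapper `!Sync` and pushes
    /// the aliasing obligations onto the `unsafe` accessor below.
    value: UnsafeCell<T>,
}

// Sending the wrapper to another thread is allowed; *using* it there is not.
// This mirrors `unsafe impl Send for LocalState {}` in the patch: callers
// uphold the owner-thread contract, which is why `with_mut` is `unsafe`.
unsafe impl<T> Send for ThreadConfined<T> {}

impl<T> ThreadConfined<T> {
    fn new(value: T) -> Self {
        ThreadConfined {
            owner: thread::current().id(),
            value: UnsafeCell::new(value),
        }
    }

    /// Safety: must be called from the thread recorded in `self.owner`.
    unsafe fn with_mut<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
        // Debug-only guard against accidental cross-thread access, analogous
        // to `assert_called_from_owner_thread` in the patch.
        debug_assert_eq!(
            thread::current().id(),
            self.owner,
            "thread-confined state accessed from a foreign thread"
        );
        // Safety: the caller guarantees we are on the owning thread, and the
        // wrapper is `!Sync`, so no other reference to the value is live.
        f(unsafe { &mut *self.value.get() })
    }
}

fn main() {
    // A miniature stand-in for the `LocalSet` local run queue.
    let queue = ThreadConfined::new(VecDeque::from([1, 2, 3]));

    // On the owning thread the accessor behaves like a scoped mutable borrow.
    let front = unsafe { queue.with_mut(|q| q.pop_front()) };
    assert_eq!(front, Some(1));

    // No thread-local is consulted on the drop path, which is what lets a
    // value like this (or a `LocalSet`) stored in a thread-local be destroyed
    // cleanly during thread shutdown.
    drop(queue);
}

One simplification worth noting: the real `assert_called_from_owner_thread` in the diff is more forgiving than this sketch. It goes through the fallible `thread_id()` helper and skips the check when the thread id can no longer be read (and is compiled out entirely on FreeBSD and OpenBSD), precisely so the `Drop` path keeps working while thread-locals are being torn down.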