From 22862739dddd49a94065aa7a917cde2dc8a3f6bc Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Wed, 30 Nov 2022 14:21:08 -0800 Subject: [PATCH] rt: yield_now defers task until after driver poll (#5223) Previously, calling `task::yield_now().await` would yield the current task to the scheduler, but the scheduler would poll it again before polling the resource drivers. This behavior can result in starving the resource drivers. This patch creates a queue tracking yielded tasks. The scheduler notifies those tasks **after** polling the resource drivers. Refs: #5209 --- .github/workflows/loom.yml | 1 + tokio/src/runtime/context.rs | 30 ++++++- tokio/src/runtime/defer.rs | 27 +++++++ tokio/src/runtime/mod.rs | 3 + tokio/src/runtime/park.rs | 21 +++++ tokio/src/runtime/scheduler/current_thread.rs | 18 ++++- .../runtime/scheduler/multi_thread/worker.rs | 16 +++- tokio/src/runtime/tests/loom_yield.rs | 37 +++++++++ tokio/src/runtime/tests/mod.rs | 1 + tokio/src/task/yield_now.rs | 14 +++- tokio/tests/rt_common.rs | 79 ++++++++++++++++++- 11 files changed, 240 insertions(+), 7 deletions(-) create mode 100644 tokio/src/runtime/defer.rs create mode 100644 tokio/src/runtime/tests/loom_yield.rs diff --git a/.github/workflows/loom.yml b/.github/workflows/loom.yml index d9bc4d161e7..bd92a1e45a4 100644 --- a/.github/workflows/loom.yml +++ b/.github/workflows/loom.yml @@ -45,4 +45,5 @@ jobs: env: RUSTFLAGS: --cfg loom --cfg tokio_unstable -Dwarnings LOOM_MAX_PREEMPTIONS: 2 + LOOM_MAX_BRANCHES: 10000 SCOPE: ${{ matrix.scope }} diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index 0e7b636af05..4f30d3374a9 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -7,8 +7,7 @@ use std::cell::Cell; use crate::util::rand::{FastRand, RngSeed}; cfg_rt! { - use crate::runtime::scheduler; - use crate::runtime::task::Id; + use crate::runtime::{scheduler, task::Id, Defer}; use std::cell::RefCell; use std::marker::PhantomData; @@ -19,6 +18,7 @@ struct Context { /// Handle to the runtime scheduler running on the current thread. #[cfg(feature = "rt")] handle: RefCell>, + #[cfg(feature = "rt")] current_task_id: Cell>, @@ -30,6 +30,11 @@ struct Context { #[cfg(feature = "rt")] runtime: Cell, + /// Yielded task wakers are stored here and notified after resource drivers + /// are polled. + #[cfg(feature = "rt")] + defer: RefCell>, + #[cfg(any(feature = "rt", feature = "macros"))] rng: FastRand, @@ -56,6 +61,9 @@ tokio_thread_local! { #[cfg(feature = "rt")] runtime: Cell::new(EnterRuntime::NotEntered), + #[cfg(feature = "rt")] + defer: RefCell::new(None), + #[cfg(any(feature = "rt", feature = "macros"))] rng: FastRand::new(RngSeed::new()), @@ -159,7 +167,12 @@ cfg_rt! { if c.runtime.get().is_entered() { None } else { + // Set the entered flag c.runtime.set(EnterRuntime::Entered { allow_block_in_place }); + + // Initialize queue to track yielded tasks + *c.defer.borrow_mut() = Some(Defer::new()); + Some(EnterRuntimeGuard { blocking: BlockingRegionGuard::new(), handle: c.set_current(handle), @@ -201,6 +214,14 @@ cfg_rt! { DisallowBlockInPlaceGuard(reset) } + pub(crate) fn with_defer(f: impl FnOnce(&mut Defer) -> R) -> Option { + CONTEXT.with(|c| { + let mut defer = c.defer.borrow_mut(); + + defer.as_mut().map(f) + }) + } + impl Context { fn set_current(&self, handle: &scheduler::Handle) -> SetCurrentGuard { let rng_seed = handle.seed_generator().next_seed(); @@ -235,6 +256,7 @@ cfg_rt! { CONTEXT.with(|c| { assert!(c.runtime.get().is_entered()); c.runtime.set(EnterRuntime::NotEntered); + *c.defer.borrow_mut() = None; }); } } @@ -286,6 +308,10 @@ cfg_rt! { return Err(()); } + // Wake any yielded tasks before parking in order to avoid + // blocking. + with_defer(|defer| defer.wake()); + park.park_timeout(when - now); } } diff --git a/tokio/src/runtime/defer.rs b/tokio/src/runtime/defer.rs new file mode 100644 index 00000000000..4078512de7a --- /dev/null +++ b/tokio/src/runtime/defer.rs @@ -0,0 +1,27 @@ +use std::task::Waker; + +pub(crate) struct Defer { + deferred: Vec, +} + +impl Defer { + pub(crate) fn new() -> Defer { + Defer { + deferred: Default::default(), + } + } + + pub(crate) fn defer(&mut self, waker: Waker) { + self.deferred.push(waker); + } + + pub(crate) fn is_empty(&self) -> bool { + self.deferred.is_empty() + } + + pub(crate) fn wake(&mut self) { + for waker in self.deferred.drain(..) { + waker.wake(); + } + } +} diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 6e527801ec6..45b79b0ac81 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -228,6 +228,9 @@ cfg_rt! { pub use crate::util::rand::RngSeed; } + mod defer; + pub(crate) use defer::Defer; + mod handle; pub use handle::{EnterGuard, Handle, TryCurrentError}; diff --git a/tokio/src/runtime/park.rs b/tokio/src/runtime/park.rs index 9dfbfae0951..dc86f42a966 100644 --- a/tokio/src/runtime/park.rs +++ b/tokio/src/runtime/park.rs @@ -32,6 +32,12 @@ tokio_thread_local! { static CURRENT_PARKER: ParkThread = ParkThread::new(); } +// Bit of a hack, but it is only for loom +#[cfg(loom)] +tokio_thread_local! { + static CURRENT_THREAD_PARK_COUNT: AtomicUsize = AtomicUsize::new(0); +} + // ==== impl ParkThread ==== impl ParkThread { @@ -51,10 +57,15 @@ impl ParkThread { } pub(crate) fn park(&mut self) { + #[cfg(loom)] + CURRENT_THREAD_PARK_COUNT.with(|count| count.fetch_add(1, SeqCst)); self.inner.park(); } pub(crate) fn park_timeout(&mut self, duration: Duration) { + #[cfg(loom)] + CURRENT_THREAD_PARK_COUNT.with(|count| count.fetch_add(1, SeqCst)); + // Wasm doesn't have threads, so just sleep. #[cfg(not(tokio_wasm))] self.inner.park_timeout(duration); @@ -273,6 +284,11 @@ impl CachedParkThread { return Ok(v); } + // Wake any yielded tasks before parking in order to avoid + // blocking. + #[cfg(feature = "rt")] + crate::runtime::context::with_defer(|defer| defer.wake()); + self.park(); } } @@ -330,3 +346,8 @@ unsafe fn wake_by_ref(raw: *const ()) { // We don't actually own a reference to the unparker mem::forget(unparker); } + +#[cfg(loom)] +pub(crate) fn current_thread_park_count() -> usize { + CURRENT_THREAD_PARK_COUNT.with(|count| count.load(SeqCst)) +} diff --git a/tokio/src/runtime/scheduler/current_thread.rs b/tokio/src/runtime/scheduler/current_thread.rs index d874448c55e..375e47c412b 100644 --- a/tokio/src/runtime/scheduler/current_thread.rs +++ b/tokio/src/runtime/scheduler/current_thread.rs @@ -3,7 +3,7 @@ use crate::loom::sync::atomic::AtomicBool; use crate::loom::sync::{Arc, Mutex}; use crate::runtime::driver::{self, Driver}; use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task}; -use crate::runtime::{blocking, scheduler, Config}; +use crate::runtime::{blocking, context, scheduler, Config}; use crate::runtime::{MetricsBatch, SchedulerMetrics, WorkerMetrics}; use crate::sync::notify::Notify; use crate::util::atomic_cell::AtomicCell; @@ -267,6 +267,14 @@ impl Core { } } +fn did_defer_tasks() -> bool { + context::with_defer(|deferred| !deferred.is_empty()).unwrap() +} + +fn wake_deferred_tasks() { + context::with_defer(|deferred| deferred.wake()); +} + // ===== impl Context ===== impl Context { @@ -299,6 +307,7 @@ impl Context { let (c, _) = self.enter(core, || { driver.park(&handle.driver); + wake_deferred_tasks(); }); core = c; @@ -324,6 +333,7 @@ impl Context { core.metrics.submit(&handle.shared.worker_metrics); let (mut core, _) = self.enter(core, || { driver.park_timeout(&handle.driver, Duration::from_millis(0)); + wake_deferred_tasks(); }); core.driver = Some(driver); @@ -557,7 +567,11 @@ impl CoreGuard<'_> { let task = match entry { Some(entry) => entry, None => { - core = context.park(core, handle); + core = if did_defer_tasks() { + context.park_yield(core, handle) + } else { + context.park(core, handle) + }; // Try polling the `block_on` future next continue 'outer; diff --git a/tokio/src/runtime/scheduler/multi_thread/worker.rs b/tokio/src/runtime/scheduler/multi_thread/worker.rs index ee0765230c2..ce6c313b105 100644 --- a/tokio/src/runtime/scheduler/multi_thread/worker.rs +++ b/tokio/src/runtime/scheduler/multi_thread/worker.rs @@ -412,7 +412,11 @@ impl Context { core = self.run_task(task, core)?; } else { // Wait for work - core = self.park(core); + core = if did_defer_tasks() { + self.park_timeout(core, Some(Duration::from_millis(0))) + } else { + self.park(core) + }; } } @@ -535,6 +539,8 @@ impl Context { park.park(&self.worker.handle.driver); } + wake_deferred_tasks(); + // Remove `core` from context core = self.core.borrow_mut().take().expect("core missing"); @@ -853,6 +859,14 @@ impl Handle { } } +fn did_defer_tasks() -> bool { + context::with_defer(|deferred| !deferred.is_empty()).unwrap() +} + +fn wake_deferred_tasks() { + context::with_defer(|deferred| deferred.wake()); +} + cfg_metrics! { impl Shared { pub(super) fn injection_queue_depth(&self) -> usize { diff --git a/tokio/src/runtime/tests/loom_yield.rs b/tokio/src/runtime/tests/loom_yield.rs new file mode 100644 index 00000000000..ba506e5a408 --- /dev/null +++ b/tokio/src/runtime/tests/loom_yield.rs @@ -0,0 +1,37 @@ +use crate::runtime::park; +use crate::runtime::tests::loom_oneshot as oneshot; +use crate::runtime::{self, Runtime}; + +#[test] +fn yield_calls_park_before_scheduling_again() { + // Don't need to check all permutations + let mut loom = loom::model::Builder::default(); + loom.max_permutations = Some(1); + loom.check(|| { + let rt = mk_runtime(2); + let (tx, rx) = oneshot::channel::<()>(); + + rt.spawn(async { + let tid = loom::thread::current().id(); + let park_count = park::current_thread_park_count(); + + crate::task::yield_now().await; + + if tid == loom::thread::current().id() { + let new_park_count = park::current_thread_park_count(); + assert_eq!(park_count + 1, new_park_count); + } + + tx.send(()); + }); + + rx.recv(); + }); +} + +fn mk_runtime(num_threads: usize) -> Runtime { + runtime::Builder::new_multi_thread() + .worker_threads(num_threads) + .build() + .unwrap() +} diff --git a/tokio/src/runtime/tests/mod.rs b/tokio/src/runtime/tests/mod.rs index b4b8cb45844..1c67dfefb32 100644 --- a/tokio/src/runtime/tests/mod.rs +++ b/tokio/src/runtime/tests/mod.rs @@ -41,6 +41,7 @@ cfg_loom! { mod loom_queue; mod loom_shutdown_join; mod loom_join_set; + mod loom_yield; } cfg_not_loom! { diff --git a/tokio/src/task/yield_now.rs b/tokio/src/task/yield_now.rs index 148e3dc0c87..7b61dd86bf0 100644 --- a/tokio/src/task/yield_now.rs +++ b/tokio/src/task/yield_now.rs @@ -1,3 +1,5 @@ +use crate::runtime::context; + use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; @@ -49,7 +51,17 @@ pub async fn yield_now() { } self.yielded = true; - cx.waker().wake_by_ref(); + + let defer = context::with_defer(|rt| { + rt.defer(cx.waker().clone()); + }); + + if defer.is_none() { + // Not currently in a runtime, just notify ourselves + // immediately. + cx.waker().wake_by_ref(); + } + Poll::Pending } } diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs index 433eb2444b8..ef0c2a222dc 100644 --- a/tokio/tests/rt_common.rs +++ b/tokio/tests/rt_common.rs @@ -9,6 +9,9 @@ macro_rules! rt_test { mod current_thread_scheduler { $($t)* + #[cfg(not(target_os="wasi"))] + const NUM_WORKERS: usize = 1; + fn rt() -> Arc { tokio::runtime::Builder::new_current_thread() .enable_all() @@ -22,6 +25,8 @@ macro_rules! rt_test { mod threaded_scheduler_4_threads { $($t)* + const NUM_WORKERS: usize = 4; + fn rt() -> Arc { tokio::runtime::Builder::new_multi_thread() .worker_threads(4) @@ -36,6 +41,8 @@ macro_rules! rt_test { mod threaded_scheduler_1_thread { $($t)* + const NUM_WORKERS: usize = 1; + fn rt() -> Arc { tokio::runtime::Builder::new_multi_thread() .worker_threads(1) @@ -652,7 +659,12 @@ rt_test! { for _ in 0..100 { rt.spawn(async { loop { - tokio::task::yield_now().await; + // Don't use Tokio's `yield_now()` to avoid special defer + // logic. + let _: () = futures::future::poll_fn(|cx| { + cx.waker().wake_by_ref(); + std::task::Poll::Pending + }).await; } }); } @@ -680,6 +692,71 @@ rt_test! { }); } + /// Tests that yielded tasks are not scheduled until **after** resource + /// drivers are polled. + /// + /// Note: we may have to delete this test as it is not necessarily reliable. + /// The OS does not guarantee when I/O events are delivered, so there may be + /// more yields than anticipated. + #[test] + #[cfg(not(target_os="wasi"))] + fn yield_defers_until_park() { + use std::sync::atomic::{AtomicBool, Ordering::SeqCst}; + use std::sync::Barrier; + + let rt = rt(); + + let flag = Arc::new(AtomicBool::new(false)); + let barrier = Arc::new(Barrier::new(NUM_WORKERS)); + + rt.block_on(async { + // Make sure other workers cannot steal tasks + #[allow(clippy::reversed_empty_ranges)] + for _ in 0..(NUM_WORKERS-1) { + let flag = flag.clone(); + let barrier = barrier.clone(); + + tokio::spawn(async move { + barrier.wait(); + + while !flag.load(SeqCst) { + std::thread::sleep(std::time::Duration::from_millis(1)); + } + }); + } + + barrier.wait(); + + tokio::spawn(async move { + // Create a TCP litener + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::join!( + async { + // Done blocking intentionally + let _socket = std::net::TcpStream::connect(addr).unwrap(); + + // Yield until connected + let mut cnt = 0; + while !flag.load(SeqCst){ + tokio::task::yield_now().await; + cnt += 1; + + if cnt >= 10 { + panic!("yielded too many times; TODO: delete this test?"); + } + } + }, + async { + let _ = listener.accept().await.unwrap(); + flag.store(true, SeqCst); + } + ); + }).await.unwrap(); + }); + } + #[cfg(not(target_os="wasi"))] // Wasi does not support threads #[test] fn client_server_block_on() {