diff --git a/.github/workflows/loom.yml b/.github/workflows/loom.yml
index d9bc4d161e7..bd92a1e45a4 100644
--- a/.github/workflows/loom.yml
+++ b/.github/workflows/loom.yml
@@ -45,4 +45,5 @@ jobs:
         env:
           RUSTFLAGS: --cfg loom --cfg tokio_unstable -Dwarnings
           LOOM_MAX_PREEMPTIONS: 2
+          LOOM_MAX_BRANCHES: 10000
           SCOPE: ${{ matrix.scope }}
diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs
index 0e7b636af05..4f30d3374a9 100644
--- a/tokio/src/runtime/context.rs
+++ b/tokio/src/runtime/context.rs
@@ -7,8 +7,7 @@ use std::cell::Cell;
 use crate::util::rand::{FastRand, RngSeed};
 
 cfg_rt! {
-    use crate::runtime::scheduler;
-    use crate::runtime::task::Id;
+    use crate::runtime::{scheduler, task::Id, Defer};
 
     use std::cell::RefCell;
     use std::marker::PhantomData;
@@ -19,6 +18,7 @@ struct Context {
     /// Handle to the runtime scheduler running on the current thread.
     #[cfg(feature = "rt")]
     handle: RefCell<Option<scheduler::Handle>>,
+
     #[cfg(feature = "rt")]
     current_task_id: Cell<Option<Id>>,
 
@@ -30,6 +30,11 @@ struct Context {
     #[cfg(feature = "rt")]
     runtime: Cell<EnterRuntime>,
 
+    /// Yielded task wakers are stored here and notified after resource drivers
+    /// are polled.
+    #[cfg(feature = "rt")]
+    defer: RefCell<Option<Defer>>,
+
     #[cfg(any(feature = "rt", feature = "macros"))]
     rng: FastRand,
 
@@ -56,6 +61,9 @@ tokio_thread_local! {
             #[cfg(feature = "rt")]
             runtime: Cell::new(EnterRuntime::NotEntered),
 
+            #[cfg(feature = "rt")]
+            defer: RefCell::new(None),
+
             #[cfg(any(feature = "rt", feature = "macros"))]
             rng: FastRand::new(RngSeed::new()),
 
@@ -159,7 +167,12 @@ cfg_rt! {
            if c.runtime.get().is_entered() {
                None
            } else {
+               // Set the entered flag
                c.runtime.set(EnterRuntime::Entered { allow_block_in_place });
+
+               // Initialize queue to track yielded tasks
+               *c.defer.borrow_mut() = Some(Defer::new());
+
                Some(EnterRuntimeGuard {
                    blocking: BlockingRegionGuard::new(),
                    handle: c.set_current(handle),
@@ -201,6 +214,14 @@ cfg_rt! {
         DisallowBlockInPlaceGuard(reset)
     }
 
+    pub(crate) fn with_defer<R>(f: impl FnOnce(&mut Defer) -> R) -> Option<R> {
+        CONTEXT.with(|c| {
+            let mut defer = c.defer.borrow_mut();
+
+            defer.as_mut().map(f)
+        })
+    }
+
     impl Context {
         fn set_current(&self, handle: &scheduler::Handle) -> SetCurrentGuard {
             let rng_seed = handle.seed_generator().next_seed();
@@ -235,6 +256,7 @@ cfg_rt! {
             CONTEXT.with(|c| {
                 assert!(c.runtime.get().is_entered());
                 c.runtime.set(EnterRuntime::NotEntered);
+                *c.defer.borrow_mut() = None;
             });
         }
     }
@@ -286,6 +308,10 @@ cfg_rt! {
                     return Err(());
                 }
 
+                // Wake any yielded tasks before parking in order to avoid
+                // blocking.
+                with_defer(|defer| defer.wake());
+
                 park.park_timeout(when - now);
             }
         }
diff --git a/tokio/src/runtime/defer.rs b/tokio/src/runtime/defer.rs
new file mode 100644
index 00000000000..4078512de7a
--- /dev/null
+++ b/tokio/src/runtime/defer.rs
@@ -0,0 +1,27 @@
+use std::task::Waker;
+
+pub(crate) struct Defer {
+    deferred: Vec<Waker>,
+}
+
+impl Defer {
+    pub(crate) fn new() -> Defer {
+        Defer {
+            deferred: Default::default(),
+        }
+    }
+
+    pub(crate) fn defer(&mut self, waker: Waker) {
+        self.deferred.push(waker);
+    }
+
+    pub(crate) fn is_empty(&self) -> bool {
+        self.deferred.is_empty()
+    }
+
+    pub(crate) fn wake(&mut self) {
+        for waker in self.deferred.drain(..) {
+            waker.wake();
+        }
+    }
+}
diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs
index 6e527801ec6..45b79b0ac81 100644
--- a/tokio/src/runtime/mod.rs
+++ b/tokio/src/runtime/mod.rs
@@ -228,6 +228,9 @@ cfg_rt! {
         pub use crate::util::rand::RngSeed;
     }
 
+    mod defer;
+    pub(crate) use defer::Defer;
+
     mod handle;
     pub use handle::{EnterGuard, Handle, TryCurrentError};
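The heart of the change is the new `Defer` queue and the `with_defer` accessor above: `yield_now()` no longer wakes its task on the spot, it hands the waker to the current thread's `Defer` queue, and that queue is drained only once the scheduler has finished its pass and the resource drivers have had a chance to run. A minimal standalone sketch of that pattern (the `DeferQueue` and `PrintWake` names are illustrative stand-ins, not Tokio internals):

```rust
use std::sync::Arc;
use std::task::{Wake, Waker};

/// Stand-in for the `Defer` queue above: wakers collected while tasks run,
/// drained only when the scheduler is about to poll the drivers / park.
struct DeferQueue {
    deferred: Vec<Waker>,
}

impl DeferQueue {
    fn defer(&mut self, waker: Waker) {
        self.deferred.push(waker);
    }

    fn wake_all(&mut self) {
        for waker in self.deferred.drain(..) {
            waker.wake();
        }
    }
}

/// A waker that just records that it fired.
struct PrintWake(&'static str);

impl Wake for PrintWake {
    fn wake(self: Arc<Self>) {
        println!("woke {}", self.0);
    }
}

fn main() {
    let mut queue = DeferQueue { deferred: Vec::new() };

    // Two tasks "yield": their wakers are queued instead of firing right away.
    queue.defer(Waker::from(Arc::new(PrintWake("task-a"))));
    queue.defer(Waker::from(Arc::new(PrintWake("task-b"))));

    // ... the scheduler would poll the I/O and time drivers here ...

    // Only now do the yielded tasks become runnable again.
    queue.wake_all();
}
```

In the patch itself the queue lives in the thread-local `Context`, is created when the runtime is entered, and is cleared again when the runtime context is exited, which is why `with_defer` returns `None` outside a runtime.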
diff --git a/tokio/src/runtime/park.rs b/tokio/src/runtime/park.rs
index 9dfbfae0951..dc86f42a966 100644
--- a/tokio/src/runtime/park.rs
+++ b/tokio/src/runtime/park.rs
@@ -32,6 +32,12 @@ tokio_thread_local! {
     static CURRENT_PARKER: ParkThread = ParkThread::new();
 }
 
+// Bit of a hack, but it is only for loom
+#[cfg(loom)]
+tokio_thread_local! {
+    static CURRENT_THREAD_PARK_COUNT: AtomicUsize = AtomicUsize::new(0);
+}
+
 // ==== impl ParkThread ====
 
 impl ParkThread {
@@ -51,10 +57,15 @@ impl ParkThread {
     }
 
     pub(crate) fn park(&mut self) {
+        #[cfg(loom)]
+        CURRENT_THREAD_PARK_COUNT.with(|count| count.fetch_add(1, SeqCst));
         self.inner.park();
     }
 
     pub(crate) fn park_timeout(&mut self, duration: Duration) {
+        #[cfg(loom)]
+        CURRENT_THREAD_PARK_COUNT.with(|count| count.fetch_add(1, SeqCst));
+
         // Wasm doesn't have threads, so just sleep.
         #[cfg(not(tokio_wasm))]
         self.inner.park_timeout(duration);
@@ -273,6 +284,11 @@ impl CachedParkThread {
                 return Ok(v);
             }
 
+            // Wake any yielded tasks before parking in order to avoid
+            // blocking.
+            #[cfg(feature = "rt")]
+            crate::runtime::context::with_defer(|defer| defer.wake());
+
             self.park();
         }
     }
@@ -330,3 +346,8 @@ unsafe fn wake_by_ref(raw: *const ()) {
     // We don't actually own a reference to the unparker
     mem::forget(unparker);
 }
+
+#[cfg(loom)]
+pub(crate) fn current_thread_park_count() -> usize {
+    CURRENT_THREAD_PARK_COUNT.with(|count| count.load(SeqCst))
+}
diff --git a/tokio/src/runtime/scheduler/current_thread.rs b/tokio/src/runtime/scheduler/current_thread.rs
index d874448c55e..375e47c412b 100644
--- a/tokio/src/runtime/scheduler/current_thread.rs
+++ b/tokio/src/runtime/scheduler/current_thread.rs
@@ -3,7 +3,7 @@ use crate::loom::sync::atomic::AtomicBool;
 use crate::loom::sync::{Arc, Mutex};
 use crate::runtime::driver::{self, Driver};
 use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task};
-use crate::runtime::{blocking, scheduler, Config};
+use crate::runtime::{blocking, context, scheduler, Config};
 use crate::runtime::{MetricsBatch, SchedulerMetrics, WorkerMetrics};
 use crate::sync::notify::Notify;
 use crate::util::atomic_cell::AtomicCell;
@@ -267,6 +267,14 @@ impl Core {
     }
 }
 
+fn did_defer_tasks() -> bool {
+    context::with_defer(|deferred| !deferred.is_empty()).unwrap()
+}
+
+fn wake_deferred_tasks() {
+    context::with_defer(|deferred| deferred.wake());
+}
+
 // ===== impl Context =====
 
 impl Context {
@@ -299,6 +307,7 @@ impl Context {
 
             let (c, _) = self.enter(core, || {
                 driver.park(&handle.driver);
+                wake_deferred_tasks();
             });
 
             core = c;
@@ -324,6 +333,7 @@ impl Context {
         core.metrics.submit(&handle.shared.worker_metrics);
         let (mut core, _) = self.enter(core, || {
             driver.park_timeout(&handle.driver, Duration::from_millis(0));
+            wake_deferred_tasks();
         });
 
         core.driver = Some(driver);
@@ -557,7 +567,11 @@ impl CoreGuard<'_> {
                     let task = match entry {
                         Some(entry) => entry,
                         None => {
-                            core = context.park(core, handle);
+                            core = if did_defer_tasks() {
+                                context.park_yield(core, handle)
+                            } else {
+                                context.park(core, handle)
+                            };
 
                             // Try polling the `block_on` future next
                             continue 'outer;
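Both schedulers now make the same decision before parking, shown above for the current-thread scheduler and below for the multi-threaded one: if any task yielded during the pass, park with a zero-duration timeout so the drivers are still polled but the thread does not block, and in either case release the deferred wakers only after the driver poll. A rough sketch of that control flow, where the `Scheduler` type and the `poll_driver` method are stand-ins invented for this illustration rather than Tokio's internal API:

```rust
use std::task::Waker;
use std::time::Duration;

// Illustrative scheduler shape only.
struct Scheduler {
    deferred: Vec<Waker>,
}

impl Scheduler {
    fn park(&mut self) {
        if !self.deferred.is_empty() {
            // A task yielded this iteration: still give the I/O and time
            // drivers a chance to run, but do not block the thread, so the
            // yielded task is picked up again promptly.
            self.poll_driver(Some(Duration::from_millis(0)));
        } else {
            // Nothing was deferred: it is safe to block until an event arrives.
            self.poll_driver(None);
        }

        // In both paths, yielded tasks are only woken *after* the drivers
        // have been polled.
        for waker in self.deferred.drain(..) {
            waker.wake();
        }
    }

    fn poll_driver(&mut self, _timeout: Option<Duration>) {
        // Placeholder for driver polling (epoll/kqueue, timers, ...).
    }
}

fn main() {
    let mut scheduler = Scheduler { deferred: Vec::new() };

    // No deferred wakers here, so this park may block on the driver.
    scheduler.park();
}
```

Blocking while wakers are still deferred would delay a task that is already runnable behind an unrelated I/O or timer event, which is why the `block_on` paths in `context.rs` and `park.rs` above also drain the queue right before parking.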
diff --git a/tokio/src/runtime/scheduler/multi_thread/worker.rs b/tokio/src/runtime/scheduler/multi_thread/worker.rs
index ee0765230c2..ce6c313b105 100644
--- a/tokio/src/runtime/scheduler/multi_thread/worker.rs
+++ b/tokio/src/runtime/scheduler/multi_thread/worker.rs
@@ -412,7 +412,11 @@ impl Context {
                 core = self.run_task(task, core)?;
             } else {
                 // Wait for work
-                core = self.park(core);
+                core = if did_defer_tasks() {
+                    self.park_timeout(core, Some(Duration::from_millis(0)))
+                } else {
+                    self.park(core)
+                };
             }
         }
 
@@ -535,6 +539,8 @@ impl Context {
             park.park(&self.worker.handle.driver);
         }
 
+        wake_deferred_tasks();
+
         // Remove `core` from context
         core = self.core.borrow_mut().take().expect("core missing");
 
@@ -853,6 +859,14 @@ impl Handle {
     }
 }
 
+fn did_defer_tasks() -> bool {
+    context::with_defer(|deferred| !deferred.is_empty()).unwrap()
+}
+
+fn wake_deferred_tasks() {
+    context::with_defer(|deferred| deferred.wake());
+}
+
 cfg_metrics! {
     impl Shared {
         pub(super) fn injection_queue_depth(&self) -> usize {
diff --git a/tokio/src/runtime/tests/loom_yield.rs b/tokio/src/runtime/tests/loom_yield.rs
new file mode 100644
index 00000000000..ba506e5a408
--- /dev/null
+++ b/tokio/src/runtime/tests/loom_yield.rs
@@ -0,0 +1,37 @@
+use crate::runtime::park;
+use crate::runtime::tests::loom_oneshot as oneshot;
+use crate::runtime::{self, Runtime};
+
+#[test]
+fn yield_calls_park_before_scheduling_again() {
+    // Don't need to check all permutations
+    let mut loom = loom::model::Builder::default();
+    loom.max_permutations = Some(1);
+    loom.check(|| {
+        let rt = mk_runtime(2);
+        let (tx, rx) = oneshot::channel::<()>();
+
+        rt.spawn(async {
+            let tid = loom::thread::current().id();
+            let park_count = park::current_thread_park_count();
+
+            crate::task::yield_now().await;
+
+            if tid == loom::thread::current().id() {
+                let new_park_count = park::current_thread_park_count();
+                assert_eq!(park_count + 1, new_park_count);
+            }
+
+            tx.send(());
+        });
+
+        rx.recv();
+    });
+}
+
+fn mk_runtime(num_threads: usize) -> Runtime {
+    runtime::Builder::new_multi_thread()
+        .worker_threads(num_threads)
+        .build()
+        .unwrap()
+}
diff --git a/tokio/src/runtime/tests/mod.rs b/tokio/src/runtime/tests/mod.rs
index b4b8cb45844..1c67dfefb32 100644
--- a/tokio/src/runtime/tests/mod.rs
+++ b/tokio/src/runtime/tests/mod.rs
@@ -41,6 +41,7 @@ cfg_loom! {
     mod loom_queue;
     mod loom_shutdown_join;
     mod loom_join_set;
+    mod loom_yield;
 }
 
 cfg_not_loom! {
diff --git a/tokio/src/task/yield_now.rs b/tokio/src/task/yield_now.rs
index 148e3dc0c87..7b61dd86bf0 100644
--- a/tokio/src/task/yield_now.rs
+++ b/tokio/src/task/yield_now.rs
@@ -1,3 +1,5 @@
+use crate::runtime::context;
+
 use std::future::Future;
 use std::pin::Pin;
 use std::task::{Context, Poll};
@@ -49,7 +51,17 @@ pub async fn yield_now() {
             }
 
             self.yielded = true;
-            cx.waker().wake_by_ref();
+
+            let defer = context::with_defer(|rt| {
+                rt.defer(cx.waker().clone());
+            });
+
+            if defer.is_none() {
+                // Not currently in a runtime, just notify ourselves
+                // immediately.
+                cx.waker().wake_by_ref();
+            }
+
             Poll::Pending
         }
     }
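For callers, `yield_now()` keeps its signature and is still just an `.await` point; what changes is when the yielded task becomes runnable again, which is now roughly "after the scheduler finishes its pass and the resource drivers are polled" rather than immediately. A small usage example, using only the public `tokio` API:

```rust
#[tokio::main]
async fn main() {
    let handle = tokio::spawn(async {
        // Let other tasks and the I/O driver make progress before resuming.
        // With this patch the wakeup is deferred until the scheduler has
        // polled the resource drivers, instead of firing immediately.
        tokio::task::yield_now().await;
        println!("resumed");
    });

    handle.await.unwrap();
}
```

This is also why the busy-loop test in `rt_common.rs` below switches from `yield_now()` to a raw `poll_fn` self-wake: that test relies on the old wake-immediately behavior.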
diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs
index 433eb2444b8..ef0c2a222dc 100644
--- a/tokio/tests/rt_common.rs
+++ b/tokio/tests/rt_common.rs
@@ -9,6 +9,9 @@ macro_rules! rt_test {
         mod current_thread_scheduler {
             $($t)*
 
+            #[cfg(not(target_os="wasi"))]
+            const NUM_WORKERS: usize = 1;
+
             fn rt() -> Arc<Runtime> {
                 tokio::runtime::Builder::new_current_thread()
                     .enable_all()
@@ -22,6 +25,8 @@ macro_rules! rt_test {
         mod threaded_scheduler_4_threads {
             $($t)*
 
+            const NUM_WORKERS: usize = 4;
+
             fn rt() -> Arc<Runtime> {
                 tokio::runtime::Builder::new_multi_thread()
                     .worker_threads(4)
@@ -36,6 +41,8 @@ macro_rules! rt_test {
         mod threaded_scheduler_1_thread {
             $($t)*
 
+            const NUM_WORKERS: usize = 1;
+
             fn rt() -> Arc<Runtime> {
                 tokio::runtime::Builder::new_multi_thread()
                     .worker_threads(1)
@@ -652,7 +659,12 @@ rt_test! {
         for _ in 0..100 {
             rt.spawn(async {
                 loop {
-                    tokio::task::yield_now().await;
+                    // Don't use Tokio's `yield_now()` to avoid special defer
+                    // logic.
+                    let _: () = futures::future::poll_fn(|cx| {
+                        cx.waker().wake_by_ref();
+                        std::task::Poll::Pending
+                    }).await;
                 }
             });
         }
@@ -680,6 +692,71 @@ rt_test! {
         });
     }
 
+    /// Tests that yielded tasks are not scheduled until **after** resource
+    /// drivers are polled.
+    ///
+    /// Note: we may have to delete this test as it is not necessarily reliable.
+    /// The OS does not guarantee when I/O events are delivered, so there may be
+    /// more yields than anticipated.
+    #[test]
+    #[cfg(not(target_os="wasi"))]
+    fn yield_defers_until_park() {
+        use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
+        use std::sync::Barrier;
+
+        let rt = rt();
+
+        let flag = Arc::new(AtomicBool::new(false));
+        let barrier = Arc::new(Barrier::new(NUM_WORKERS));
+
+        rt.block_on(async {
+            // Make sure other workers cannot steal tasks
+            #[allow(clippy::reversed_empty_ranges)]
+            for _ in 0..(NUM_WORKERS-1) {
+                let flag = flag.clone();
+                let barrier = barrier.clone();
+
+                tokio::spawn(async move {
+                    barrier.wait();
+
+                    while !flag.load(SeqCst) {
+                        std::thread::sleep(std::time::Duration::from_millis(1));
+                    }
+                });
+            }
+
+            barrier.wait();
+
+            tokio::spawn(async move {
+                // Create a TCP listener
+                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+                let addr = listener.local_addr().unwrap();
+
+                tokio::join!(
+                    async {
+                        // Done in a blocking manner intentionally.
+                        let _socket = std::net::TcpStream::connect(addr).unwrap();
+
+                        // Yield until connected
+                        let mut cnt = 0;
+                        while !flag.load(SeqCst) {
+                            tokio::task::yield_now().await;
+                            cnt += 1;
+
+                            if cnt >= 10 {
+                                panic!("yielded too many times; TODO: delete this test?");
+                            }
+                        }
+                    },
+                    async {
+                        let _ = listener.accept().await.unwrap();
+                        flag.store(true, SeqCst);
+                    }
+                );
+            }).await.unwrap();
+        });
+    }
+
     #[cfg(not(target_os="wasi"))] // Wasi does not support threads
     #[test]
     fn client_server_block_on() {