rt: yield_now defers task until after driver poll (#5223)
Previously, calling `task::yield_now().await` would yield the current
task to the scheduler, but the scheduler would poll it again before
polling the resource drivers. This behavior can result in starving the
resource drivers.

This patch creates a queue tracking yielded tasks. The scheduler
notifies those tasks **after** polling the resource drivers.

Refs: #5209
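For illustration only (not part of the commit): after this patch, each `yield_now().await` hands its waker to a per-runtime defer queue, and the scheduler notifies that waker only once the I/O and time drivers have been polled. A minimal sketch, assuming a current-thread runtime:

// Hedged sketch: each iteration now gives the resource drivers a chance to
// be polled before this task runs again; previously the task could simply
// be re-polled first.
#[tokio::main(flavor = "current_thread")]
async fn main() {
    for _ in 0..3 {
        tokio::task::yield_now().await;
    }
}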
carllerche committed Nov 30, 2022
1 parent 993a60b commit 2286273
Showing 11 changed files with 240 additions and 7 deletions.
1 change: 1 addition & 0 deletions .github/workflows/loom.yml
@@ -45,4 +45,5 @@ jobs:
env:
RUSTFLAGS: --cfg loom --cfg tokio_unstable -Dwarnings
LOOM_MAX_PREEMPTIONS: 2
LOOM_MAX_BRANCHES: 10000
SCOPE: ${{ matrix.scope }}
30 changes: 28 additions & 2 deletions tokio/src/runtime/context.rs
@@ -7,8 +7,7 @@ use std::cell::Cell;
use crate::util::rand::{FastRand, RngSeed};

cfg_rt! {
use crate::runtime::scheduler;
use crate::runtime::task::Id;
use crate::runtime::{scheduler, task::Id, Defer};

use std::cell::RefCell;
use std::marker::PhantomData;
@@ -19,6 +18,7 @@ struct Context {
/// Handle to the runtime scheduler running on the current thread.
#[cfg(feature = "rt")]
handle: RefCell<Option<scheduler::Handle>>,

#[cfg(feature = "rt")]
current_task_id: Cell<Option<Id>>,

@@ -30,6 +30,11 @@ struct Context {
#[cfg(feature = "rt")]
runtime: Cell<EnterRuntime>,

/// Yielded task wakers are stored here and notified after resource drivers
/// are polled.
#[cfg(feature = "rt")]
defer: RefCell<Option<Defer>>,

#[cfg(any(feature = "rt", feature = "macros"))]
rng: FastRand,

@@ -56,6 +61,9 @@ tokio_thread_local! {
#[cfg(feature = "rt")]
runtime: Cell::new(EnterRuntime::NotEntered),

#[cfg(feature = "rt")]
defer: RefCell::new(None),

#[cfg(any(feature = "rt", feature = "macros"))]
rng: FastRand::new(RngSeed::new()),

@@ -159,7 +167,12 @@ cfg_rt! {
if c.runtime.get().is_entered() {
None
} else {
// Set the entered flag
c.runtime.set(EnterRuntime::Entered { allow_block_in_place });

// Initialize queue to track yielded tasks
*c.defer.borrow_mut() = Some(Defer::new());

Some(EnterRuntimeGuard {
blocking: BlockingRegionGuard::new(),
handle: c.set_current(handle),
@@ -201,6 +214,14 @@ cfg_rt! {
DisallowBlockInPlaceGuard(reset)
}

pub(crate) fn with_defer<R>(f: impl FnOnce(&mut Defer) -> R) -> Option<R> {
CONTEXT.with(|c| {
let mut defer = c.defer.borrow_mut();

defer.as_mut().map(f)
})
}

impl Context {
fn set_current(&self, handle: &scheduler::Handle) -> SetCurrentGuard {
let rng_seed = handle.seed_generator().next_seed();
@@ -235,6 +256,7 @@ cfg_rt! {
CONTEXT.with(|c| {
assert!(c.runtime.get().is_entered());
c.runtime.set(EnterRuntime::NotEntered);
*c.defer.borrow_mut() = None;
});
}
}
@@ -286,6 +308,10 @@ cfg_rt! {
return Err(());
}

// Wake any yielded tasks before parking in order to avoid
// blocking.
with_defer(|defer| defer.wake());

park.park_timeout(when - now);
}
}
27 changes: 27 additions & 0 deletions tokio/src/runtime/defer.rs
@@ -0,0 +1,27 @@
use std::task::Waker;

pub(crate) struct Defer {
deferred: Vec<Waker>,
}

impl Defer {
pub(crate) fn new() -> Defer {
Defer {
deferred: Default::default(),
}
}

pub(crate) fn defer(&mut self, waker: Waker) {
self.deferred.push(waker);
}

pub(crate) fn is_empty(&self) -> bool {
self.deferred.is_empty()
}

pub(crate) fn wake(&mut self) {
for waker in self.deferred.drain(..) {
waker.wake();
}
}
}
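As an aside (not from the commit), a self-contained sketch of the push/drain lifecycle the scheduler relies on. `Defer` is re-declared locally here because the real type is crate-private, and the no-op waker comes from the `futures` crate:

use std::task::Waker;

// Local stand-in for the crate-private `Defer` shown above.
struct Defer {
    deferred: Vec<Waker>,
}

impl Defer {
    fn new() -> Defer { Defer { deferred: Vec::new() } }
    fn defer(&mut self, waker: Waker) { self.deferred.push(waker); }
    fn is_empty(&self) -> bool { self.deferred.is_empty() }
    fn wake(&mut self) { self.deferred.drain(..).for_each(Waker::wake); }
}

fn main() {
    let mut defer = Defer::new();

    // A yielding task pushes its waker instead of waking itself.
    defer.defer(futures::task::noop_waker());
    assert!(!defer.is_empty());

    // The scheduler drains the queue only after polling the resource
    // drivers, which re-queues the yielded tasks.
    defer.wake();
    assert!(defer.is_empty());
}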
3 changes: 3 additions & 0 deletions tokio/src/runtime/mod.rs
@@ -228,6 +228,9 @@ cfg_rt! {
pub use crate::util::rand::RngSeed;
}

mod defer;
pub(crate) use defer::Defer;

mod handle;
pub use handle::{EnterGuard, Handle, TryCurrentError};

21 changes: 21 additions & 0 deletions tokio/src/runtime/park.rs
@@ -32,6 +32,12 @@ tokio_thread_local! {
static CURRENT_PARKER: ParkThread = ParkThread::new();
}

// Bit of a hack, but it is only for loom
#[cfg(loom)]
tokio_thread_local! {
static CURRENT_THREAD_PARK_COUNT: AtomicUsize = AtomicUsize::new(0);
}

// ==== impl ParkThread ====

impl ParkThread {
@@ -51,10 +57,15 @@ impl ParkThread {
}

pub(crate) fn park(&mut self) {
#[cfg(loom)]
CURRENT_THREAD_PARK_COUNT.with(|count| count.fetch_add(1, SeqCst));
self.inner.park();
}

pub(crate) fn park_timeout(&mut self, duration: Duration) {
#[cfg(loom)]
CURRENT_THREAD_PARK_COUNT.with(|count| count.fetch_add(1, SeqCst));

// Wasm doesn't have threads, so just sleep.
#[cfg(not(tokio_wasm))]
self.inner.park_timeout(duration);
@@ -273,6 +284,11 @@ impl CachedParkThread {
return Ok(v);
}

// Wake any yielded tasks before parking in order to avoid
// blocking.
#[cfg(feature = "rt")]
crate::runtime::context::with_defer(|defer| defer.wake());

self.park();
}
}
@@ -330,3 +346,8 @@ unsafe fn wake_by_ref(raw: *const ()) {
// We don't actually own a reference to the unparker
mem::forget(unparker);
}

#[cfg(loom)]
pub(crate) fn current_thread_park_count() -> usize {
CURRENT_THREAD_PARK_COUNT.with(|count| count.load(SeqCst))
}
18 changes: 16 additions & 2 deletions tokio/src/runtime/scheduler/current_thread.rs
@@ -3,7 +3,7 @@ use crate::loom::sync::atomic::AtomicBool;
use crate::loom::sync::{Arc, Mutex};
use crate::runtime::driver::{self, Driver};
use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task};
use crate::runtime::{blocking, scheduler, Config};
use crate::runtime::{blocking, context, scheduler, Config};
use crate::runtime::{MetricsBatch, SchedulerMetrics, WorkerMetrics};
use crate::sync::notify::Notify;
use crate::util::atomic_cell::AtomicCell;
@@ -267,6 +267,14 @@ impl Core {
}
}

fn did_defer_tasks() -> bool {
context::with_defer(|deferred| !deferred.is_empty()).unwrap()
}

fn wake_deferred_tasks() {
context::with_defer(|deferred| deferred.wake());
}

// ===== impl Context =====

impl Context {
@@ -299,6 +307,7 @@ impl Context {

let (c, _) = self.enter(core, || {
driver.park(&handle.driver);
wake_deferred_tasks();
});

core = c;
@@ -324,6 +333,7 @@ impl Context {
core.metrics.submit(&handle.shared.worker_metrics);
let (mut core, _) = self.enter(core, || {
driver.park_timeout(&handle.driver, Duration::from_millis(0));
wake_deferred_tasks();
});

core.driver = Some(driver);
@@ -557,7 +567,11 @@ impl CoreGuard<'_> {
let task = match entry {
Some(entry) => entry,
None => {
core = context.park(core, handle);
core = if did_defer_tasks() {
context.park_yield(core, handle)
} else {
context.park(core, handle)
};

// Try polling the `block_on` future next
continue 'outer;
16 changes: 15 additions & 1 deletion tokio/src/runtime/scheduler/multi_thread/worker.rs
@@ -412,7 +412,11 @@ impl Context {
core = self.run_task(task, core)?;
} else {
// Wait for work
core = self.park(core);
core = if did_defer_tasks() {
self.park_timeout(core, Some(Duration::from_millis(0)))
} else {
self.park(core)
};
}
}

@@ -535,6 +539,8 @@ impl Context {
park.park(&self.worker.handle.driver);
}

wake_deferred_tasks();

// Remove `core` from context
core = self.core.borrow_mut().take().expect("core missing");

@@ -853,6 +859,14 @@ impl Handle {
}
}

fn did_defer_tasks() -> bool {
context::with_defer(|deferred| !deferred.is_empty()).unwrap()
}

fn wake_deferred_tasks() {
context::with_defer(|deferred| deferred.wake());
}

cfg_metrics! {
impl Shared {
pub(super) fn injection_queue_depth(&self) -> usize {
37 changes: 37 additions & 0 deletions tokio/src/runtime/tests/loom_yield.rs
@@ -0,0 +1,37 @@
use crate::runtime::park;
use crate::runtime::tests::loom_oneshot as oneshot;
use crate::runtime::{self, Runtime};

#[test]
fn yield_calls_park_before_scheduling_again() {
// Don't need to check all permutations
let mut loom = loom::model::Builder::default();
loom.max_permutations = Some(1);
loom.check(|| {
let rt = mk_runtime(2);
let (tx, rx) = oneshot::channel::<()>();

rt.spawn(async {
let tid = loom::thread::current().id();
let park_count = park::current_thread_park_count();

crate::task::yield_now().await;

if tid == loom::thread::current().id() {
let new_park_count = park::current_thread_park_count();
assert_eq!(park_count + 1, new_park_count);
}

tx.send(());
});

rx.recv();
});
}

fn mk_runtime(num_threads: usize) -> Runtime {
runtime::Builder::new_multi_thread()
.worker_threads(num_threads)
.build()
.unwrap()
}
1 change: 1 addition & 0 deletions tokio/src/runtime/tests/mod.rs
@@ -41,6 +41,7 @@ cfg_loom! {
mod loom_queue;
mod loom_shutdown_join;
mod loom_join_set;
mod loom_yield;
}

cfg_not_loom! {
14 changes: 13 additions & 1 deletion tokio/src/task/yield_now.rs
@@ -1,3 +1,5 @@
use crate::runtime::context;

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -49,7 +51,17 @@ pub async fn yield_now() {
}

self.yielded = true;
cx.waker().wake_by_ref();

let defer = context::with_defer(|rt| {
rt.defer(cx.waker().clone());
});

if defer.is_none() {
// Not currently in a runtime, just notify ourselves
// immediately.
cx.waker().wake_by_ref();
}

Poll::Pending
}
}
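One consequence worth noting (an observation, not stated in the commit): `with_defer` returns `None` when no runtime context is active, so `yield_now` keeps working on other executors via the immediate-wake fallback above. A hedged sketch, assuming the `futures` crate:

// Outside a Tokio runtime there is no defer queue, so the fallback branch
// wakes the task immediately and the future still completes, e.g. under
// the `futures` crate's basic executor.
fn main() {
    futures::executor::block_on(async {
        tokio::task::yield_now().await;
        println!("yield_now completed without a Tokio runtime");
    });
}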