diff --git a/tokio/src/coop.rs b/tokio/src/coop.rs index 96905319b65..453ef14791e 100644 --- a/tokio/src/coop.rs +++ b/tokio/src/coop.rs @@ -68,7 +68,7 @@ impl Budget { /// Runs the given closure with a cooperative task budget. When the function /// returns, the budget is reset to the value prior to calling the function. #[inline(always)] -pub(crate) fn budget(f: impl FnOnce() -> R) -> R { +pub fn budget(f: impl FnOnce() -> R) -> R { with_budget(Budget::initial(), f) } diff --git a/tokio/src/lib.rs b/tokio/src/lib.rs index 27d4dc83855..ac6516e0f46 100644 --- a/tokio/src/lib.rs +++ b/tokio/src/lib.rs @@ -423,7 +423,12 @@ cfg_rt! { pub mod runtime; } -pub(crate) mod coop; +// Includes the `budget` function which is used by the `join` macros. +// +// This module is not intended to be part of the public API. In general, any +// `doc(hidden)` code is not part of Tokio's public and stable API. +#[doc(hidden)] +pub mod coop; cfg_signal! { pub mod signal; diff --git a/tokio/src/macros/join.rs b/tokio/src/macros/join.rs index f91b5f19145..f97cdef1859 100644 --- a/tokio/src/macros/join.rs +++ b/tokio/src/macros/join.rs @@ -82,8 +82,11 @@ macro_rules! join { // and never moved. let mut fut = unsafe { Pin::new_unchecked(fut) }; - // Try polling - if fut.poll(cx).is_pending() { + // Try polling. + // Give each future its own budget to avoid starvation in the case + // where some of the futures consume a resource that is always ready. + let poll = $crate::coop::budget(|| { fut.poll(cx) }); + if poll.is_pending() { is_pending = true; } )* diff --git a/tokio/tests/macros_join.rs b/tokio/tests/macros_join.rs index d4f20b3862c..1a388553951 100644 --- a/tokio/tests/macros_join.rs +++ b/tokio/tests/macros_join.rs @@ -1,6 +1,8 @@ #![cfg(feature = "macros")] #![allow(clippy::blacklisted_name)] +use std::{sync::Arc, time::Duration}; + #[cfg(target_arch = "wasm32")] use wasm_bindgen_test::wasm_bindgen_test as test; #[cfg(target_arch = "wasm32")] @@ -9,7 +11,7 @@ use wasm_bindgen_test::wasm_bindgen_test as maybe_tokio_test; #[cfg(not(target_arch = "wasm32"))] use tokio::test as maybe_tokio_test; -use tokio::sync::oneshot; +use tokio::sync::{oneshot, Semaphore}; use tokio_test::{assert_pending, assert_ready, task}; #[maybe_tokio_test] @@ -80,3 +82,41 @@ fn join_size() { }; assert_eq!(mem::size_of_val(&fut), 28); } + +async fn non_cooperative_task(permits: Arc) -> usize { + let mut exceeded_budget = 0; + + for _ in 0..5 { + // Another task should run after after this task uses its whole budget + for _ in 0..128 { + let _permit = permits.clone().acquire_owned().await.unwrap(); + } + + exceeded_budget += 1; + } + + exceeded_budget +} + +async fn poor_little_task() -> usize { + let mut how_many_times_i_got_to_run = 0; + + for _ in 0..5 { + tokio::time::sleep(Duration::from_millis(100)).await; + how_many_times_i_got_to_run += 1; + } + + how_many_times_i_got_to_run +} + +#[tokio::test] +async fn join_does_not_allow_tasks_to_starve() { + let permits = Arc::new(Semaphore::new(10)); + + // non_cooperative_task should yield after its budget is exceeded and then poor_little_task should run. + let (non_cooperative_result, little_task_result) = + tokio::join!(non_cooperative_task(permits), poor_little_task()); + + assert_eq!(5, non_cooperative_result); + assert_eq!(5, little_task_result); +}