diff --git a/tokio/src/macros/join.rs b/tokio/src/macros/join.rs index f91b5f19145..af033f4ee61 100644 --- a/tokio/src/macros/join.rs +++ b/tokio/src/macros/join.rs @@ -71,24 +71,50 @@ macro_rules! join { // the requirement of `Pin::new_unchecked` called below. let mut futures = ( $( maybe_done($e), )* ); + // How many futures were passed to join!. + const FUTURE_COUNT: u32 = $crate::count!( $($count)* ); + + // When poll_fn is polled, start polling the future at this index. + let mut start_index = 0; + poll_fn(move |cx| { let mut is_pending = false; - $( - // Extract the future for this branch from the tuple. - let ( $($skip,)* fut, .. ) = &mut futures; + for i in 0..FUTURE_COUNT { + let turn; + + #[allow(clippy::modulo_one)] + { + turn = (start_index + i) % FUTURE_COUNT + }; + + match turn { + $( + #[allow(unreachable_code)] + $crate::count!( $($skip)* ) => { + let ( $($skip,)* fut, .. ) = &mut futures; - // Safety: future is stored on the stack above - // and never moved. - let mut fut = unsafe { Pin::new_unchecked(fut) }; + // Safety: future is stored on the stack above + // and never moved. + let mut fut = unsafe { Pin::new_unchecked(fut) }; - // Try polling - if fut.poll(cx).is_pending() { - is_pending = true; + // Try polling + if fut.poll(cx).is_pending() { + is_pending = true; + } + } + )* + _ => unreachable!("reaching this means there probably is an off by one bug") } - )* + } if is_pending { + // Start by polling the next future first the next time poll_fn is polled + #[allow(clippy::modulo_one)] + { + start_index = (start_index + 1) % FUTURE_COUNT; + } + Pending } else { Ready(($({ diff --git a/tokio/tests/macros_join.rs b/tokio/tests/macros_join.rs index d4f20b3862c..25dc4ad3f8e 100644 --- a/tokio/tests/macros_join.rs +++ b/tokio/tests/macros_join.rs @@ -1,6 +1,8 @@ #![cfg(feature = "macros")] #![allow(clippy::blacklisted_name)] +use std::{sync::Arc, time::Duration}; + #[cfg(target_arch = "wasm32")] use wasm_bindgen_test::wasm_bindgen_test as test; #[cfg(target_arch = "wasm32")] @@ -9,7 +11,7 @@ use wasm_bindgen_test::wasm_bindgen_test as maybe_tokio_test; #[cfg(not(target_arch = "wasm32"))] use tokio::test as maybe_tokio_test; -use tokio::sync::oneshot; +use tokio::sync::{oneshot, Semaphore}; use tokio_test::{assert_pending, assert_ready, task}; #[maybe_tokio_test] @@ -71,12 +73,50 @@ fn join_size() { let ready = future::ready(0i32); tokio::join!(ready) }; - assert_eq!(mem::size_of_val(&fut), 16); + assert_eq!(mem::size_of_val(&fut), 20); let fut = async { let ready1 = future::ready(0i32); let ready2 = future::ready(0i32); tokio::join!(ready1, ready2) }; - assert_eq!(mem::size_of_val(&fut), 28); + assert_eq!(mem::size_of_val(&fut), 32); +} + +async fn non_cooperative_task(permits: Arc) -> usize { + let mut exceeded_budget = 0; + + for _ in 0..5 { + // Another task should run after after this task uses its whole budget + for _ in 0..128 { + let _permit = permits.clone().acquire_owned().await.unwrap(); + } + + exceeded_budget += 1; + } + + exceeded_budget +} + +async fn poor_little_task() -> usize { + let mut how_many_times_i_got_to_run = 0; + + for _ in 0..5 { + tokio::time::sleep(Duration::from_millis(100)).await; + how_many_times_i_got_to_run += 1; + } + + how_many_times_i_got_to_run +} + +#[tokio::test] +async fn join_does_not_allow_tasks_to_starve() { + let permits = Arc::new(Semaphore::new(10)); + + // non_cooperative_task should yield after its budget is exceeded and then poor_little_task should run. + let (non_cooperative_result, little_task_result) = + tokio::join!(non_cooperative_task(permits), poor_little_task()); + + assert_eq!(5, non_cooperative_result); + assert_eq!(5, little_task_result); }