diff --git a/tokio/src/macros/join.rs b/tokio/src/macros/join.rs index f91b5f19145..866af5fa9de 100644 --- a/tokio/src/macros/join.rs +++ b/tokio/src/macros/join.rs @@ -60,8 +60,11 @@ macro_rules! join { // normalization is complete. ( $($count:tt)* ) + // The expression `0+1+1+ ... +1` equal to the number of branches. + ( $($total:tt)* ) + // Normalized join! branches - $( ( $($skip:tt)* ) $e:expr, )* + $( ( $($skip:tt)* ) ( $($branch_index:tt)* ) $e:expr, )* }) => {{ use $crate::macros::support::{maybe_done, poll_fn, Future, Pin}; @@ -71,24 +74,52 @@ macro_rules! join { // the requirement of `Pin::new_unchecked` called below. let mut futures = ( $( maybe_done($e), )* ); + // When poll_fn is polled, start polling the future at this index. + let mut start_index = 0; + poll_fn(move |cx| { + const FUTURE_COUNT: u32 = $($total)*; + let mut is_pending = false; - $( - // Extract the future for this branch from the tuple. - let ( $($skip,)* fut, .. ) = &mut futures; + let mut turn = start_index; + + for _ in 0..FUTURE_COUNT { + $( + { + const INDEX: u32 = $($branch_index)*; + if turn == INDEX { + let ( $($skip,)* fut, .. ) = &mut futures; - // Safety: future is stored on the stack above - // and never moved. - let mut fut = unsafe { Pin::new_unchecked(fut) }; + // Safety: future is stored on the stack above + // and never moved. + let mut fut = unsafe { Pin::new_unchecked(fut) }; - // Try polling - if fut.poll(cx).is_pending() { - is_pending = true; - } - )* + // Try polling + if fut.poll(cx).is_pending() { + is_pending = true; + } + + turn = if turn + 1 == FUTURE_COUNT { + 0 + } else { + turn + 1 + }; + + continue; + } + } + )* + } if is_pending { + // Start by polling the next future first the next time poll_fn is polled + start_index = if start_index + 1 == FUTURE_COUNT { + 0 + } else { + start_index + 1 + }; + Pending } else { Ready(($({ @@ -107,13 +138,13 @@ macro_rules! join { // ===== Normalize ===== - (@ { ( $($s:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { - $crate::join!(@{ ($($s)* _) $($t)* ($($s)*) $e, } $($r)*) + (@ { ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { + $crate::join!(@{ ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) ($($n)*) $e, } $($r)*) }; // ===== Entry point ===== ( $($e:expr),* $(,)?) => { - $crate::join!(@{ () } $($e,)*) + $crate::join!(@{ () (0) } $($e,)*) }; } diff --git a/tokio/src/macros/try_join.rs b/tokio/src/macros/try_join.rs index 6d3a893b7e1..44b0453a8ad 100644 --- a/tokio/src/macros/try_join.rs +++ b/tokio/src/macros/try_join.rs @@ -106,8 +106,11 @@ macro_rules! try_join { // normalization is complete. ( $($count:tt)* ) + // The expression `0+1+1+ ... +1` equal to the number of branches. + ( $($total:tt)* ) + // Normalized try_join! branches - $( ( $($skip:tt)* ) $e:expr, )* + $( ( $($skip:tt)* ) ( $($branch_index:tt)* ) $e:expr, )* }) => {{ use $crate::macros::support::{maybe_done, poll_fn, Future, Pin}; @@ -117,26 +120,55 @@ macro_rules! try_join { // the requirement of `Pin::new_unchecked` called below. let mut futures = ( $( maybe_done($e), )* ); + // When poll_fn is polled, start polling the future at this index. + let mut start_index = 0; + poll_fn(move |cx| { + const FUTURE_COUNT: u32 = $($total)*; + let mut is_pending = false; - $( - // Extract the future for this branch from the tuple. - let ( $($skip,)* fut, .. ) = &mut futures; + let mut turn = start_index; + + for _ in 0..FUTURE_COUNT { + $( + { + const INDEX: u32 = $($branch_index)*; + if turn == INDEX { + // Extract the future for this branch from the tuple. + let ( $($skip,)* fut, .. ) = &mut futures; - // Safety: future is stored on the stack above - // and never moved. - let mut fut = unsafe { Pin::new_unchecked(fut) }; + // Safety: future is stored on the stack above + // and never moved. + let mut fut = unsafe { Pin::new_unchecked(fut) }; - // Try polling - if fut.as_mut().poll(cx).is_pending() { - is_pending = true; - } else if fut.as_mut().output_mut().expect("expected completed future").is_err() { - return Ready(Err(fut.take_output().expect("expected completed future").err().unwrap())) - } - )* + // Try polling + if fut.as_mut().poll(cx).is_pending() { + is_pending = true; + } else if fut.as_mut().output_mut().expect("expected completed future").is_err() { + return Ready(Err(fut.take_output().expect("expected completed future").err().unwrap())) + } + + turn = if turn + 1 == FUTURE_COUNT { + 0 + } else { + turn + 1 + }; + + continue; + } + } + )* + } if is_pending { + // Start by polling the next future first the next time poll_fn is polled + start_index = if start_index + 1 == FUTURE_COUNT { + 0 + } else { + start_index + 1 + }; + Pending } else { Ready(Ok(($({ @@ -159,13 +191,13 @@ macro_rules! try_join { // ===== Normalize ===== - (@ { ( $($s:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { - $crate::try_join!(@{ ($($s)* _) $($t)* ($($s)*) $e, } $($r)*) + (@ { ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { + $crate::try_join!(@{ ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) ($($n)+) $e, } $($r)*) }; // ===== Entry point ===== ( $($e:expr),* $(,)?) => { - $crate::try_join!(@{ () } $($e,)*) + $crate::try_join!(@{ () (0) } $($e,)*) }; } diff --git a/tokio/tests/macros_join.rs b/tokio/tests/macros_join.rs index d4f20b3862c..25dc4ad3f8e 100644 --- a/tokio/tests/macros_join.rs +++ b/tokio/tests/macros_join.rs @@ -1,6 +1,8 @@ #![cfg(feature = "macros")] #![allow(clippy::blacklisted_name)] +use std::{sync::Arc, time::Duration}; + #[cfg(target_arch = "wasm32")] use wasm_bindgen_test::wasm_bindgen_test as test; #[cfg(target_arch = "wasm32")] @@ -9,7 +11,7 @@ use wasm_bindgen_test::wasm_bindgen_test as maybe_tokio_test; #[cfg(not(target_arch = "wasm32"))] use tokio::test as maybe_tokio_test; -use tokio::sync::oneshot; +use tokio::sync::{oneshot, Semaphore}; use tokio_test::{assert_pending, assert_ready, task}; #[maybe_tokio_test] @@ -71,12 +73,50 @@ fn join_size() { let ready = future::ready(0i32); tokio::join!(ready) }; - assert_eq!(mem::size_of_val(&fut), 16); + assert_eq!(mem::size_of_val(&fut), 20); let fut = async { let ready1 = future::ready(0i32); let ready2 = future::ready(0i32); tokio::join!(ready1, ready2) }; - assert_eq!(mem::size_of_val(&fut), 28); + assert_eq!(mem::size_of_val(&fut), 32); +} + +async fn non_cooperative_task(permits: Arc) -> usize { + let mut exceeded_budget = 0; + + for _ in 0..5 { + // Another task should run after after this task uses its whole budget + for _ in 0..128 { + let _permit = permits.clone().acquire_owned().await.unwrap(); + } + + exceeded_budget += 1; + } + + exceeded_budget +} + +async fn poor_little_task() -> usize { + let mut how_many_times_i_got_to_run = 0; + + for _ in 0..5 { + tokio::time::sleep(Duration::from_millis(100)).await; + how_many_times_i_got_to_run += 1; + } + + how_many_times_i_got_to_run +} + +#[tokio::test] +async fn join_does_not_allow_tasks_to_starve() { + let permits = Arc::new(Semaphore::new(10)); + + // non_cooperative_task should yield after its budget is exceeded and then poor_little_task should run. + let (non_cooperative_result, little_task_result) = + tokio::join!(non_cooperative_task(permits), poor_little_task()); + + assert_eq!(5, non_cooperative_result); + assert_eq!(5, little_task_result); } diff --git a/tokio/tests/macros_try_join.rs b/tokio/tests/macros_try_join.rs index 60a726b659a..f5bb4acdf86 100644 --- a/tokio/tests/macros_try_join.rs +++ b/tokio/tests/macros_try_join.rs @@ -1,7 +1,9 @@ #![cfg(feature = "macros")] #![allow(clippy::blacklisted_name)] -use tokio::sync::oneshot; +use std::{sync::Arc, time::Duration}; + +use tokio::sync::{oneshot, Semaphore}; use tokio_test::{assert_pending, assert_ready, task}; #[cfg(target_arch = "wasm32")] @@ -94,16 +96,55 @@ fn join_size() { let ready = future::ready(ok(0i32)); tokio::try_join!(ready) }; - assert_eq!(mem::size_of_val(&fut), 16); + assert_eq!(mem::size_of_val(&fut), 20); let fut = async { let ready1 = future::ready(ok(0i32)); let ready2 = future::ready(ok(0i32)); tokio::try_join!(ready1, ready2) }; - assert_eq!(mem::size_of_val(&fut), 28); + assert_eq!(mem::size_of_val(&fut), 2328); } fn ok(val: T) -> Result { Ok(val) } + +async fn non_cooperative_task(permits: Arc) -> Result { + let mut exceeded_budget = 0; + + for _ in 0..5 { + // Another task should run after after this task uses its whole budget + for _ in 0..128 { + let _permit = permits.clone().acquire_owned().await.unwrap(); + } + + exceeded_budget += 1; + } + + Ok(exceeded_budget) +} + +async fn poor_little_task() -> Result { + let mut how_many_times_i_got_to_run = 0; + + for _ in 0..5 { + tokio::time::sleep(Duration::from_millis(100)).await; + how_many_times_i_got_to_run += 1; + } + + Ok(how_many_times_i_got_to_run) +} + +#[tokio::test] +async fn try_join_does_not_allow_tasks_to_starve() { + let permits = Arc::new(Semaphore::new(10)); + + // non_cooperative_task should yield after its budget is exceeded and then poor_little_task should run. + let result = tokio::try_join!(non_cooperative_task(permits), poor_little_task()); + + let (non_cooperative_result, little_task_result) = result.unwrap(); + + assert_eq!(5, non_cooperative_result); + assert_eq!(5, little_task_result); +}