Skip to content

Commit

Permalink
macros: join! start by polling a different future each time poll_fn i…
Browse files Browse the repository at this point in the history
…s polled

Fixes: tokio-rs#4612
  • Loading branch information
PoorlyDefinedBehaviour committed Apr 24, 2022
1 parent c43832a commit 44dda22
Show file tree
Hide file tree
Showing 6 changed files with 420 additions and 14 deletions.
123 changes: 123 additions & 0 deletions tokio-macros/src/count.rs
@@ -0,0 +1,123 @@
use proc_macro::TokenStream;
use proc_macro::{Punct, TokenTree};
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::{Expr, Token};

pub(crate) fn count(input: TokenStream) -> TokenStream {
let count: usize = input.into_iter().count();

TokenStream::from(quote!(#count))
}

struct Join {
fut_exprs: Vec<Expr>,
}

impl Parse for Join {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let mut exprs = Vec::new();

while !input.is_empty() {
exprs.push(input.parse::<Expr>()?);

if !input.is_empty() {
input.parse::<Token![,]>()?;
}
}

Ok(Join { fut_exprs: exprs })
}
}

pub(crate) fn test(input: TokenStream) -> TokenStream {
let parsed = syn::parse_macro_input!(input as Join);

let futures_count = parsed.fut_exprs.len();

let futures = parsed.fut_exprs.into_iter().map(|fut_expr| {
quote! {
maybe_done(#fut_expr)
}
});

// let futures = tokens.map(|fut| {
// quote! {
// $crate::maybe_done(#fut)
// }
// });

let if_statements = (0..futures_count).map(|i| {
quote! {
if #i == turn {
let fut = &mut futures.#1;

// Safety: future is stored on the stack above
// and never moved.
let mut fut = unsafe { Pin::new_unchecked(fut) };

// Try polling
if fut.poll(cx).is_pending() {
is_pending = true;
}

continue;
}
}
});

let build_ready_output = (0..futures_count).map(|i| {
quote! {
let fut = &mut futures.#i;

// Safety: future is stored on the stack above
// and never moved.
let mut fut = unsafe { Pin::new_unchecked(fut) };

fut.take_output().expect("expected completed future")
}
});

TokenStream::from(quote! {
use $crate::macros::support::{maybe_done, poll_fn, Future, Pin};
use $crate::macros::support::Poll::{Ready, Pending};

// Safety: nothing must be moved out of `futures`. This is to satisfy
// the requirement of `Pin::new_unchecked` called below.
// let mut futures = #futures;
// #( #futures )*

// How many futures were passed to join!.
const FUTURE_COUNT: u32 = #futures_count;

// When poll_fn is polled, start polling the future at this index.
// let mut start_index = 0;

poll_fn(move |cx| {
let mut is_pending = false;

for i in 0..FUTURE_COUNT {
let turn;

#[allow(clippy::modulo_one)]
{
turn = (start_index + i) % FUTURE_COUNT
};

// #( #if_statements )*
}

if is_pending {
// Start by polling the next future first the next time poll_fn is polled
#[allow(clippy::modulo_one)]
{
start_index = (start_index + 1) % FUTURE_COUNT;
}

Pending
} else {
// Ready( #( #build_ready_output )* )
}
}).await
})
}
19 changes: 19 additions & 0 deletions tokio-macros/src/lib.rs
Expand Up @@ -17,6 +17,7 @@
#[allow(unused_extern_crates)]
extern crate proc_macro;

mod count;
mod entry;
mod select;

Expand Down Expand Up @@ -336,3 +337,21 @@ pub fn select_priv_declare_output_enum(input: TokenStream) -> TokenStream {
pub fn select_priv_clean_pattern(input: TokenStream) -> TokenStream {
select::clean_pattern_macro(input)
}

#[proc_macro]
#[doc(hidden)]
pub fn count_proc_macro(input: TokenStream) -> TokenStream {
count::count(input)
}

#[proc_macro]
#[doc(hidden)]
pub fn count_proc_macro2(input: TokenStream) -> TokenStream {
count::count(input)
}

#[proc_macro]
#[doc(hidden)]
pub fn join_v2(input: TokenStream) -> TokenStream {
count::test(input)
}
8 changes: 8 additions & 0 deletions tokio/src/lib.rs
Expand Up @@ -528,6 +528,14 @@ cfg_macros! {
#[doc(hidden)]
pub use tokio_macros::select_priv_clean_pattern;

/// Implementation detail of the `join!` macro. This macro is **not**
/// intended to be used as part of the public API and is permitted to
/// change.
#[doc(hidden)]
pub use tokio_macros::count_proc_macro;
pub use tokio_macros::count_proc_macro2;
pub use tokio_macros::join_v2;

cfg_rt! {
#[cfg(feature = "rt-multi-thread")]
#[cfg(not(test))] // Work around for rust-lang/rust#62127
Expand Down
46 changes: 35 additions & 11 deletions tokio/src/macros/join.rs
Expand Up @@ -71,24 +71,48 @@ macro_rules! join {
// the requirement of `Pin::new_unchecked` called below.
let mut futures = ( $( maybe_done($e), )* );

// How many futures were passed to join!.
const FUTURE_COUNT: u32 = $crate::count_proc_macro!( $($count)* ) as u32;

// When poll_fn is polled, start polling the future at this index.
let mut start_index = 0;

poll_fn(move |cx| {
let mut is_pending = false;

$(
// Extract the future for this branch from the tuple.
let ( $($skip,)* fut, .. ) = &mut futures;
for i in 0..FUTURE_COUNT {
let turn;

// Safety: future is stored on the stack above
// and never moved.
let mut fut = unsafe { Pin::new_unchecked(fut) };
#[allow(clippy::modulo_one)]
{
turn = (start_index + i) % FUTURE_COUNT
};

// Try polling
if fut.poll(cx).is_pending() {
is_pending = true;
}
)*
$(
if $crate::count_proc_macro!( $($skip)* ) as u32 == turn {
let ( $($skip,)* fut, .. ) = &mut futures;

// Safety: future is stored on the stack above
// and never moved.
let mut fut = unsafe { Pin::new_unchecked(fut) };

// Try polling
if fut.poll(cx).is_pending() {
is_pending = true;
}

continue;
}
)*
}

if is_pending {
// Start by polling the next future first the next time poll_fn is polled
#[allow(clippy::modulo_one)]
{
start_index = (start_index + 1) % FUTURE_COUNT;
}

Pending
} else {
Ready(($({
Expand Down

0 comments on commit 44dda22

Please sign in to comment.