Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stream select macro #2262

Merged
merged 10 commits into from Jul 31, 2021
10 changes: 10 additions & 0 deletions futures-macro/src/lib.rs
Expand Up @@ -19,6 +19,7 @@ use proc_macro::TokenStream;
mod executor;
mod join;
mod select;
mod stream_select;

/// The `join!` macro.
#[cfg_attr(fn_like_proc_macro, proc_macro)]
Expand Down Expand Up @@ -54,3 +55,12 @@ pub fn select_biased_internal(input: TokenStream) -> TokenStream {
pub fn test_internal(input: TokenStream, item: TokenStream) -> TokenStream {
crate::executor::test(input, item)
}

/// The `stream_select!` macro.
#[cfg_attr(fn_like_proc_macro, proc_macro)]
#[cfg_attr(not(fn_like_proc_macro), proc_macro_hack::proc_macro_hack)]
pub fn stream_select_internal(input: TokenStream) -> TokenStream {
crate::stream_select::stream_select(input.into())
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}
108 changes: 108 additions & 0 deletions futures-macro/src/stream_select.rs
@@ -0,0 +1,108 @@
use proc_macro2::TokenStream;
use quote::{format_ident, quote, ToTokens};
use syn::{parse::Parser, punctuated::Punctuated, Expr, Index, Token};

/// The `stream_select!` macro.
pub(crate) fn stream_select(input: TokenStream) -> Result<TokenStream, syn::Error> {
let args = Punctuated::<Expr, Token![,]>::parse_terminated.parse2(input)?;
let generic_idents = (0..args.len()).map(|i| format_ident!("_{}", i)).collect::<Vec<_>>();
let field_idents = (0..args.len()).map(|i| format_ident!("__{}", i)).collect::<Vec<_>>();
let field_idents_2 = (0..args.len()).map(|i| format_ident!("___{}", i)).collect::<Vec<_>>();
let field_indices = (0..args.len()).map(Index::from).collect::<Vec<_>>();
let args = args.iter().map(|e| e.to_token_stream());

Ok(quote! {
{
#[derive(Debug)]
struct StreamSelect<#(#generic_idents),*> (#(Option<#generic_idents>),*);

enum StreamEnum<#(#generic_idents),*> {
#(
#generic_idents(#generic_idents)
),*,
None,
}

impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamEnum<#(#generic_idents),*>
where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
{
type Item = ITEM;

fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
match self.get_mut() {
#(
Self::#generic_idents(#generic_idents) => ::std::pin::Pin::new(#generic_idents).poll_next(cx)
),*,
Self::None => panic!("StreamEnum::None should never be polled!"),
}
}
}

impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamSelect<#(#generic_idents),*>
where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
{
type Item = ITEM;

fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
let Self(#(ref mut #field_idents),*) = self.get_mut();
#(
let mut #field_idents_2 = false;
)*
let mut any_pending = false;
{
let mut stream_array = [#(#field_idents.as_mut().map(|f| StreamEnum::#generic_idents(f)).unwrap_or(StreamEnum::None)),*];
__futures_crate::async_await::shuffle(&mut stream_array);

for mut s in stream_array {
if let StreamEnum::None = s {
continue;
} else {
match __futures_crate::stream::Stream::poll_next(::std::pin::Pin::new(&mut s), cx) {
r @ __futures_crate::task::Poll::Ready(Some(_)) => {
return r;
},
__futures_crate::task::Poll::Pending => {
any_pending = true;
},
__futures_crate::task::Poll::Ready(None) => {
match s {
#(
StreamEnum::#generic_idents(_) => { #field_idents_2 = true; }
),*,
StreamEnum::None => panic!("StreamEnum::None should never be polled!"),
}
},
}
}
}
}
#(
if #field_idents_2 {
*#field_idents = None;
}
)*
if any_pending {
__futures_crate::task::Poll::Pending
} else {
__futures_crate::task::Poll::Ready(None)
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
let mut s = (0, Some(0));
#(
if let Some(new_hint) = self.#field_indices.as_ref().map(|s| s.size_hint()) {
s.0 += new_hint.0;
// We can change this out for `.zip` when the MSRV is 1.46.0 or higher.
s.1 = s.1.and_then(|a| new_hint.1.map(|b| a + b));
}
)*
s
}
}

StreamSelect(#(Some(#args)),*)

}
})
}
7 changes: 7 additions & 0 deletions futures-util/src/async_await/mod.rs
Expand Up @@ -30,6 +30,13 @@ mod select_mod;
#[cfg(feature = "async-await-macro")]
pub use self::select_mod::*;

// Primary export is a macro
#[cfg(feature = "async-await-macro")]
mod stream_select_mod;
#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/64762
#[cfg(feature = "async-await-macro")]
pub use self::stream_select_mod::*;

#[cfg(feature = "std")]
#[cfg(feature = "async-await-macro")]
mod random;
Expand Down
46 changes: 46 additions & 0 deletions futures-util/src/async_await/stream_select_mod.rs
@@ -0,0 +1,46 @@
//! The `stream_select` macro.

#[cfg(feature = "std")]
#[allow(unreachable_pub)]
#[doc(hidden)]
#[cfg_attr(not(fn_like_proc_macro), proc_macro_hack::proc_macro_hack(support_nested))]
pub use futures_macro::stream_select_internal;

/// Combines several streams, all producing the same `Item` type, into one stream.
/// This is similar to `select_all` but does not require the streams to all be the same type.
/// It also keeps the streams inline, and does not require `Box<dyn Stream>`s to be allocated.
/// Streams passed to this macro must be `Unpin`.
///
/// Fairness for this stream is implemented in terms of the futures `select` macro. If multiple
/// streams are ready, one will be pseudo randomly selected at runtime.
///
/// This macro is gated behind the `async-await` feature of this library, which is activated by default.
/// Note that `stream_select!` relies on `proc-macro-hack`, and may require to set the compiler's recursion
/// limit very high, e.g. `#![recursion_limit="1024"]`.
///
/// # Examples
///
/// ```
/// # futures::executor::block_on(async {
/// use futures::{stream, StreamExt, stream_select};
/// let endless_ints = |i| stream::iter(vec![i].into_iter().cycle()).fuse();
///
/// let mut endless_numbers = stream_select!(endless_ints(1i32), endless_ints(2), endless_ints(3));
/// match endless_numbers.next().await {
/// Some(1) => println!("Got a 1"),
/// Some(2) => println!("Got a 2"),
/// Some(3) => println!("Got a 3"),
/// _ => unreachable!(),
/// }
/// # });
/// ```
#[cfg(feature = "std")]
#[macro_export]
macro_rules! stream_select {
($($tokens:tt)*) => {{
use $crate::__private as __futures_crate;
$crate::stream_select_internal! {
$( $tokens )*
}
}}
}
5 changes: 5 additions & 0 deletions futures/src/lib.rs
Expand Up @@ -137,6 +137,11 @@ pub use futures_util::{join, pending, poll, select_biased, try_join}; // Async-a
#[doc(inline)]
pub use futures_util::{future, sink, stream, task};

#[cfg(feature = "std")]
#[cfg(feature = "async-await")]
pub use futures_util::stream_select;

#[cfg_attr(feature = "cfg-target-has-atomic", cfg(target_has_atomic = "ptr"))]
#[cfg(feature = "alloc")]
#[doc(inline)]
pub use futures_channel as channel;
Expand Down
40 changes: 39 additions & 1 deletion futures/tests/async_await_macros.rs
Expand Up @@ -4,7 +4,9 @@ use futures::future::{self, poll_fn, FutureExt};
use futures::sink::SinkExt;
use futures::stream::StreamExt;
use futures::task::{Context, Poll};
use futures::{join, pending, pin_mut, poll, select, select_biased, try_join};
use futures::{
join, pending, pin_mut, poll, select, select_biased, stream, stream_select, try_join,
};
use std::mem;

#[test]
Expand Down Expand Up @@ -308,6 +310,42 @@ fn select_on_mutable_borrowing_future_with_same_borrow_in_block_and_default() {
});
}

#[test]
#[allow(unused_assignments)]
fn stream_select() {
// stream_select! macro
block_on(async {
let endless_ints = |i| stream::iter(vec![i].into_iter().cycle());

let mut endless_ones = stream_select!(endless_ints(1i32), stream::pending());
assert_eq!(endless_ones.next().await, Some(1));
assert_eq!(endless_ones.next().await, Some(1));

let mut finite_list = stream_select!(stream::iter(vec![1, 2, 3].into_iter()));
assert_eq!(finite_list.next().await, Some(1));
assert_eq!(finite_list.next().await, Some(2));
assert_eq!(finite_list.next().await, Some(3));
assert_eq!(finite_list.next().await, None);
taiki-e marked this conversation as resolved.
Show resolved Hide resolved

let endless_mixed = stream_select!(endless_ints(1i32), endless_ints(2), endless_ints(3));
// Take 1000, and assert a somewhat even distribution of values.
// The fairness is randomized, but over 1000 samples we should be pretty close to even.
// This test may be a bit flaky. Feel free to adjust the margins as you see fit.
let mut count = 0;
let results = endless_mixed
.take_while(move |_| {
count += 1;
let ret = count < 1000;
async move { ret }
})
.collect::<Vec<_>>()
.await;
assert!(results.iter().filter(|x| **x == 1).count() >= 299);
assert!(results.iter().filter(|x| **x == 2).count() >= 299);
assert!(results.iter().filter(|x| **x == 3).count() >= 299);
});
}

#[test]
fn join_size() {
let fut = async {
Expand Down