Skip to content

Commit

Permalink
Add stream select macro
Browse files Browse the repository at this point in the history
  • Loading branch information
Xaeroxe committed Nov 20, 2020
1 parent 98e4dfc commit 4ddcdf1
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 2 deletions.
7 changes: 7 additions & 0 deletions futures-macro/src/lib.rs
Expand Up @@ -25,6 +25,7 @@ use proc_macro_hack::proc_macro_hack;

mod join;
mod select;
mod stream_select;

/// The `join!` macro.
#[proc_macro_hack]
Expand All @@ -49,3 +50,9 @@ pub fn select_internal(input: TokenStream) -> TokenStream {
pub fn select_biased_internal(input: TokenStream) -> TokenStream {
crate::select::select_biased(input)
}

/// The `stream_select!` macro.
#[proc_macro_hack]
pub fn stream_select_internal(input: TokenStream) -> TokenStream {
crate::stream_select::stream_select(input)
}
97 changes: 97 additions & 0 deletions futures-macro/src/stream_select.rs
@@ -0,0 +1,97 @@
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{Expr, Index, parse::Parser, punctuated::Punctuated, Token};

/// The `stream_select!` macro.
pub(crate) fn stream_select(input: TokenStream) -> TokenStream {
let args = Punctuated::<Expr, Token![,]>::parse_terminated.parse(input).expect("macro expects a comma separated list of expressions");
let generic_idents = (0..args.len()).map(|i| format_ident!("_{}", i)).collect::<Vec<_>>();
let field_idents = (0..args.len()).map(|i| format_ident!("__{}", i)).collect::<Vec<_>>();
let field_indices = (0..args.len()).map(Index::from).collect::<Vec<_>>();

TokenStream::from(quote! {
{
struct StreamSelect<#(#generic_idents),*> (#(#generic_idents),*);

enum StreamFutures<#(#generic_idents),*> {
#(
#generic_idents(#generic_idents)
),*
}

impl<OUTPUT, #(#generic_idents),*> __futures_crate::future::Future for StreamFutures<#(#generic_idents),*>
where #(#generic_idents: __futures_crate::future::Future<Output=OUTPUT> + __futures_crate::future::FusedFuture + ::std::marker::Unpin,)*
{
type Output = OUTPUT;

fn poll(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Self::Output> {
match self.get_mut() {
#(
Self::#generic_idents(#generic_idents) => ::std::pin::Pin::new(#generic_idents).poll(cx)
),*
}
}
}

impl<OUTPUT, #(#generic_idents),*> __futures_crate::future::FusedFuture for StreamFutures<#(#generic_idents),*>
where #(#generic_idents: __futures_crate::future::Future<Output=OUTPUT> + __futures_crate::future::FusedFuture + ::std::marker::Unpin,)*
{
fn is_terminated(&self) -> bool {
match self {
#(
Self::#generic_idents(#generic_idents) => __futures_crate::future::FusedFuture::is_terminated(#generic_idents)
),*
}
}
}

impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamSelect<#(#generic_idents),*>
where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + __futures_crate::stream::FusedStream + ::std::marker::Unpin,)*
{
type Item = ITEM;

fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
let Self(#(ref mut #field_idents),*) = self.get_mut();
let mut future_array = [#(StreamFutures::#generic_idents(#field_idents.next())),*];
__futures_crate::async_await::shuffle(&mut future_array);
let mut any_pending = false;
for f in &mut future_array {
if __futures_crate::future::FusedFuture::is_terminated(f) {
continue;
} else {
match __futures_crate::future::Future::poll(::std::pin::Pin::new(f), cx) {
r @ __futures_crate::task::Poll::Ready(Some(_)) => {
return r;
},
__futures_crate::task::Poll::Pending => {
any_pending = true;
},
__futures_crate::task::Poll::Ready(None) => {},
}
}
}
if any_pending {
__futures_crate::task::Poll::Pending
} else {
__futures_crate::task::Poll::Ready(None)
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
let mut s = (0, Some(0));
#(
{
let new_hint = self.#field_indices.size_hint();
s.0 += new_hint.0;
// We can change this out for `.zip` when the MSRV is 1.46.0 or higher.
s.1 = s.1.and_then(|a| new_hint.1.map(|b| a + b));
}
)*
s
}
}

StreamSelect(#args)
}
})
}
7 changes: 7 additions & 0 deletions futures-util/src/async_await/mod.rs
Expand Up @@ -30,6 +30,13 @@ mod select_mod;
#[cfg(feature = "async-await-macro")]
pub use self::select_mod::*;

// Primary export is a macro
#[cfg(feature = "async-await-macro")]
mod stream_select_mod;
#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/64762
#[cfg(feature = "async-await-macro")]
pub use self::stream_select_mod::*;

#[cfg(feature = "std")]
#[cfg(feature = "async-await-macro")]
mod random;
Expand Down
49 changes: 49 additions & 0 deletions futures-util/src/async_await/stream_select_mod.rs
@@ -0,0 +1,49 @@
//! The `stream_select` macro.

#[cfg(feature = "std")]
use proc_macro_hack::proc_macro_hack;

#[cfg(feature = "std")]
#[doc(hidden)]
#[proc_macro_hack(support_nested, only_hack_old_rustc)]
pub use futures_macro::stream_select_internal;

/// Combines several streams, all producing the same `Item` type, into one stream.
/// This is similar to `select_all` but does not require the streams to all be the same type.
/// It also keeps the streams inline, and does not require `Box<dyn Stream>`s to be allocated.
/// Streams passed to this macro must be `Unpin` and implement `FusedStream`.
///
/// Fairness for this stream is implemented in terms of the futures `select` macro. If multiple
/// streams are ready, one will be pseudo randomly selected at runtime. Streams which are not
/// already fused can be fused by using the `.fuse()` method.
///
/// This macro is gated behind the `async-await` feature of this library, which is activated by default.
/// Note that `stream_select!` relies on `proc-macro-hack`, and may require to set the compiler's recursion
/// limit very high, e.g. `#![recursion_limit="1024"]`.
///
/// # Examples
///
/// ```
/// # futures::executor::block_on(async {
/// use futures::{stream, StreamExt, stream_select};
/// let endless_ints = |i| stream::iter(vec![i].into_iter().cycle()).fuse();
///
/// let mut endless_numbers = stream_select!(endless_ints(1i32), endless_ints(2), endless_ints(3));
/// match endless_numbers.next().await {
/// Some(1) => println!("Got a 1"),
/// Some(2) => println!("Got a 2"),
/// Some(3) => println!("Got a 3"),
/// _ => unreachable!(),
/// }
/// # });
/// ```
#[cfg(feature = "std")]
#[macro_export]
macro_rules! stream_select {
($($tokens:tt)*) => {{
use $crate::__private as __futures_crate;
$crate::stream_select_internal! {
$( $tokens )*
}
}}
}
4 changes: 4 additions & 0 deletions futures/src/lib.rs
Expand Up @@ -131,6 +131,10 @@ pub use futures_util::{pending, poll, join, try_join, select_biased}; // Async-a
#[cfg(feature = "async-await")]
pub use futures_util::select;

#[cfg(feature = "std")]
#[cfg(feature = "async-await")]
pub use futures_util::stream_select;

#[cfg_attr(feature = "cfg-target-has-atomic", cfg(target_has_atomic = "ptr"))]
#[cfg(feature = "alloc")]
pub mod channel {
Expand Down
35 changes: 33 additions & 2 deletions futures/tests/macro-tests/src/main.rs
@@ -1,7 +1,7 @@
// Check that it works even if proc-macros are reexported.

fn main() {
use futures03::{executor::block_on, future};
use futures03::{executor::block_on, future, stream, StreamExt};

// join! macro
let _ = block_on(async {
Expand Down Expand Up @@ -66,4 +66,35 @@ fn main() {
};
});

}
// stream_select! macro
let _ = block_on(async {
let endless_ints = |i| stream::iter(vec![i].into_iter().cycle()).fuse();

let mut endless_ones = futures03::stream_select!(endless_ints(1i32), stream::pending().fuse());
assert_eq!(endless_ones.next().await, Some(1));
assert_eq!(endless_ones.next().await, Some(1));

let mut finite_list = futures03::stream_select!(stream::iter(vec![1, 2, 3].into_iter()).fuse());
assert_eq!(finite_list.next().await, Some(1));
assert_eq!(finite_list.next().await, Some(2));
assert_eq!(finite_list.next().await, Some(3));
assert_eq!(finite_list.next().await, None);

let endless_mixed = futures03::stream_select!(endless_ints(1i32), endless_ints(2), endless_ints(3));
// Take 100, and assert a somewhat even distribution of values.
// The fairness is randomized, but over 100 samples we should be pretty close to even.
// This test may be a bit flaky. Feel free to adjust the margins as you see fit.
let mut count = 0;
let results = endless_mixed.take_while(move |_| {
count += 1;
let ret = count < 100;
async move { ret }
})
.collect::<Vec<_>>()
.await;
assert!(results.iter().filter(|x| **x == 1).count() >= 29);
assert!(results.iter().filter(|x| **x == 2).count() >= 29);
assert!(results.iter().filter(|x| **x == 3).count() >= 29);
});

}

0 comments on commit 4ddcdf1

Please sign in to comment.