Skip to content

Commit

Permalink
Add stream selection early exit
Browse files Browse the repository at this point in the history
  • Loading branch information
414owen committed May 10, 2022
1 parent 2e30ec3 commit bbbddcc
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 20 deletions.
4 changes: 2 additions & 2 deletions futures-util/src/stream/mod.rs
Expand Up @@ -100,10 +100,10 @@ mod poll_immediate;
pub use self::poll_immediate::{poll_immediate, PollImmediate};

mod select;
pub use self::select::{select, Select};
pub use self::select::{select, select_early_exit, Select};

mod select_with_strategy;
pub use self::select_with_strategy::{select_with_strategy, PollNext, SelectWithStrategy};
pub use self::select_with_strategy::{select_with_strategy, PollNext, SelectWithStrategy, ExitStrategy};

mod unfold;
pub use self::unfold::{unfold, Unfold};
Expand Down
21 changes: 19 additions & 2 deletions futures-util/src/stream/select.rs
@@ -1,5 +1,5 @@
use super::assert_stream;
use crate::stream::{select_with_strategy, PollNext, SelectWithStrategy};
use crate::stream::{select_with_strategy, PollNext, SelectWithStrategy, ExitStrategy};
use core::pin::Pin;
use futures_core::stream::{FusedStream, Stream};
use futures_core::task::{Context, Poll};
Expand Down Expand Up @@ -45,6 +45,23 @@ pin_project! {
/// # });
/// ```
pub fn select<St1, St2>(stream1: St1, stream2: St2) -> Select<St1, St2>
where
St1: Stream,
St2: Stream<Item = St1::Item>,
{
select_with_exit(stream1, stream2, ExitStrategy::WhenBothFinish)
}

/// Same as `select`, but finishes when either stream finishes
pub fn select_early_exit<St1, St2>(stream1: St1, stream2: St2) -> Select<St1, St2>
where
St1: Stream,
St2: Stream<Item = St1::Item>,
{
select_with_exit(stream1, stream2, ExitStrategy::WhenEitherFinish)
}

fn select_with_exit<St1, St2>(stream1: St1, stream2: St2, exit_strategy: ExitStrategy) -> Select<St1, St2>
where
St1: Stream,
St2: Stream<Item = St1::Item>,
Expand All @@ -54,7 +71,7 @@ where
}

assert_stream::<St1::Item, _>(Select {
inner: select_with_strategy(stream1, stream2, round_robin),
inner: select_with_strategy(stream1, stream2, round_robin, exit_strategy),
})
}

Expand Down
92 changes: 76 additions & 16 deletions futures-util/src/stream/select_with_strategy.rs
Expand Up @@ -36,6 +36,7 @@ impl Default for PollNext {
}
}

#[derive(PartialEq, Eq, Clone, Copy)]
enum InternalState {
Start,
LeftFinished,
Expand All @@ -61,6 +62,29 @@ impl InternalState {
}
}

/// Decides whether to exit when both streams are completed, or only one
/// is completed. If you need to exit when a specific stream has finished,
/// feel free to add a case here.
#[derive(Clone, Copy, Debug)]
pub enum ExitStrategy {
/// Select stream finishes when both substreams finish
WhenBothFinish,
/// Select stream finishes when either substream finishes
WhenEitherFinish,
}

impl ExitStrategy {
#[inline]
fn is_finished(self, state: InternalState) -> bool {
match (state, self) {
(InternalState::BothFinished, _) => true,
(InternalState::Start, ExitStrategy::WhenEitherFinish) => false,
(_, ExitStrategy::WhenBothFinish) => false,
_ => true,
}
}
}

pin_project! {
/// Stream for the [`select_with_strategy()`] function. See function docs for details.
#[must_use = "streams do nothing unless polled"]
Expand All @@ -73,6 +97,7 @@ pin_project! {
internal_state: InternalState,
state: State,
clos: Clos,
exit_strategy: ExitStrategy,
}
}

Expand All @@ -95,7 +120,7 @@ pin_project! {
///
/// ```rust
/// # futures::executor::block_on(async {
/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt };
/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt, ExitStrategy };
///
/// let left = repeat(1);
/// let right = repeat(2);
Expand All @@ -106,7 +131,7 @@ pin_project! {
/// // use a function pointer instead of a closure.
/// fn prio_left(_: &mut ()) -> PollNext { PollNext::Left }
///
/// let mut out = select_with_strategy(left, right, prio_left);
/// let mut out = select_with_strategy(left, right, prio_left, ExitStrategy::WhenBothFinish);
///
/// for _ in 0..100 {
/// // Whenever we poll out, we will always get `1`.
Expand All @@ -121,26 +146,54 @@ pin_project! {
///
/// ```rust
/// # futures::executor::block_on(async {
/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt };
/// use futures::stream::{ repeat, select_with_strategy, FusedStream, PollNext, StreamExt, ExitStrategy };
///
/// let left = repeat(1);
/// let right = repeat(2);
/// // Finishes when both streams finish
/// {
/// let left = repeat(1).take(10);
/// let right = repeat(2);
///
/// let rrobin = |last: &mut PollNext| last.toggle();
/// let rrobin = |last: &mut PollNext| last.toggle();
///
/// let mut out = select_with_strategy(left, right, rrobin);
/// let mut out = select_with_strategy(left, right, rrobin, ExitStrategy::WhenBothFinish);
///
/// for _ in 0..100 {
/// // We should be alternating now.
/// assert_eq!(1, out.select_next_some().await);
/// assert_eq!(2, out.select_next_some().await);
/// for _ in 0..10 {
/// // We should be alternating now.
/// assert_eq!(1, out.select_next_some().await);
/// assert_eq!(2, out.select_next_some().await);
/// }
/// for _ in 0..100 {
/// // First stream has finished
/// assert_eq!(2, out.select_next_some().await);
/// }
/// assert!(!out.is_terminated());
/// }
///
/// // Finishes when either stream finishes
/// {
/// let left = repeat(1).take(10);
/// let right = repeat(2);
///
/// let rrobin = |last: &mut PollNext| last.toggle();
///
/// let mut out = select_with_strategy(left, right, rrobin, ExitStrategy::WhenEitherFinish);
///
/// for _ in 0..10 {
/// // We should be alternating now.
/// assert_eq!(1, out.select_next_some().await);
/// assert_eq!(2, out.select_next_some().await);
/// }
/// assert_eq!(None, out.next().await);
/// assert!(out.is_terminated());
/// }
/// # });
/// ```
///
pub fn select_with_strategy<St1, St2, Clos, State>(
stream1: St1,
stream2: St2,
which: Clos,
exit_strategy: ExitStrategy,
) -> SelectWithStrategy<St1, St2, Clos, State>
where
St1: Stream,
Expand All @@ -154,6 +207,7 @@ where
state: Default::default(),
internal_state: InternalState::Start,
clos: which,
exit_strategy,
})
}

Expand Down Expand Up @@ -199,10 +253,7 @@ where
Clos: FnMut(&mut State) -> PollNext,
{
fn is_terminated(&self) -> bool {
match self.internal_state {
InternalState::BothFinished => true,
_ => false,
}
self.exit_strategy.is_finished(self.internal_state)
}
}

Expand All @@ -227,6 +278,7 @@ fn poll_inner<St1, St2, Clos, State>(
select: &mut SelectWithStrategyProj<'_, St1, St2, Clos, State>,
side: PollNext,
cx: &mut Context<'_>,
exit_strat: ExitStrategy,
) -> Poll<Option<St1::Item>>
where
St1: Stream,
Expand All @@ -236,6 +288,9 @@ where
Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
Poll::Ready(None) => {
select.internal_state.finish(side);
if exit_strat.is_finished(*select.internal_state) {
return Poll::Ready(None);
}
}
Poll::Pending => (),
};
Expand All @@ -259,11 +314,16 @@ where

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<St1::Item>> {
let mut this = self.project();
let exit_strategy: ExitStrategy = *this.exit_strategy;

if exit_strategy.is_finished(*this.internal_state) {
return Poll::Ready(None);
}

match this.internal_state {
InternalState::Start => {
let next_side = (this.clos)(this.state);
poll_inner(&mut this, next_side, cx)
poll_inner(&mut this, next_side, cx, exit_strategy)
}
InternalState::LeftFinished => match this.stream2.poll_next(cx) {
Poll::Ready(None) => {
Expand Down
16 changes: 16 additions & 0 deletions futures/tests/stream.rs
Expand Up @@ -25,6 +25,22 @@ fn select() {
select_and_compare(vec![1, 2], vec![4, 5, 6], vec![1, 4, 2, 5, 6]);
}

#[test]
fn select_early_exit() {
fn select_and_compare(a: Vec<u32>, b: Vec<u32>, expected: Vec<u32>) {
let a = stream::iter(a);
let b = stream::iter(b);
let vec = block_on(stream::select_early_exit(a, b).collect::<Vec<_>>());
assert_eq!(vec, expected);
}

select_and_compare(vec![1, 2, 3], vec![4, 5, 6], vec![1, 4, 2, 5, 3, 6]);
select_and_compare(vec![], vec![4, 5], vec![]);
select_and_compare(vec![4, 5], vec![], vec![4]);
select_and_compare(vec![1, 2, 3], vec![4, 5], vec![1, 4, 2, 5, 3]);
select_and_compare(vec![1, 2], vec![4, 5, 6], vec![1, 4, 2, 5]);
}

#[test]
fn flat_map() {
block_on(async {
Expand Down

0 comments on commit bbbddcc

Please sign in to comment.