From bbbddccd553c26bcc24aecf12edde7776da04069 Mon Sep 17 00:00:00 2001 From: Owen Shepherd Date: Fri, 25 Mar 2022 15:47:12 +0000 Subject: [PATCH] Add stream selection early exit --- futures-util/src/stream/mod.rs | 4 +- futures-util/src/stream/select.rs | 21 ++++- .../src/stream/select_with_strategy.rs | 92 +++++++++++++++---- futures/tests/stream.rs | 16 ++++ 4 files changed, 113 insertions(+), 20 deletions(-) diff --git a/futures-util/src/stream/mod.rs b/futures-util/src/stream/mod.rs index 5a1f766aaa..2914d1c4b6 100644 --- a/futures-util/src/stream/mod.rs +++ b/futures-util/src/stream/mod.rs @@ -100,10 +100,10 @@ mod poll_immediate; pub use self::poll_immediate::{poll_immediate, PollImmediate}; mod select; -pub use self::select::{select, Select}; +pub use self::select::{select, select_early_exit, Select}; mod select_with_strategy; -pub use self::select_with_strategy::{select_with_strategy, PollNext, SelectWithStrategy}; +pub use self::select_with_strategy::{select_with_strategy, PollNext, SelectWithStrategy, ExitStrategy}; mod unfold; pub use self::unfold::{unfold, Unfold}; diff --git a/futures-util/src/stream/select.rs b/futures-util/src/stream/select.rs index 0c1e3af782..2c61157523 100644 --- a/futures-util/src/stream/select.rs +++ b/futures-util/src/stream/select.rs @@ -1,5 +1,5 @@ use super::assert_stream; -use crate::stream::{select_with_strategy, PollNext, SelectWithStrategy}; +use crate::stream::{select_with_strategy, PollNext, SelectWithStrategy, ExitStrategy}; use core::pin::Pin; use futures_core::stream::{FusedStream, Stream}; use futures_core::task::{Context, Poll}; @@ -45,6 +45,23 @@ pin_project! { /// # }); /// ``` pub fn select(stream1: St1, stream2: St2) -> Select +where + St1: Stream, + St2: Stream, +{ + select_with_exit(stream1, stream2, ExitStrategy::WhenBothFinish) +} + +/// Same as `select`, but finishes when either stream finishes +pub fn select_early_exit(stream1: St1, stream2: St2) -> Select +where + St1: Stream, + St2: Stream, +{ + select_with_exit(stream1, stream2, ExitStrategy::WhenEitherFinish) +} + +fn select_with_exit(stream1: St1, stream2: St2, exit_strategy: ExitStrategy) -> Select where St1: Stream, St2: Stream, @@ -54,7 +71,7 @@ where } assert_stream::(Select { - inner: select_with_strategy(stream1, stream2, round_robin), + inner: select_with_strategy(stream1, stream2, round_robin, exit_strategy), }) } diff --git a/futures-util/src/stream/select_with_strategy.rs b/futures-util/src/stream/select_with_strategy.rs index 37dc5fe338..623dbbfd56 100644 --- a/futures-util/src/stream/select_with_strategy.rs +++ b/futures-util/src/stream/select_with_strategy.rs @@ -36,6 +36,7 @@ impl Default for PollNext { } } +#[derive(PartialEq, Eq, Clone, Copy)] enum InternalState { Start, LeftFinished, @@ -61,6 +62,29 @@ impl InternalState { } } +/// Decides whether to exit when both streams are completed, or only one +/// is completed. If you need to exit when a specific stream has finished, +/// feel free to add a case here. +#[derive(Clone, Copy, Debug)] +pub enum ExitStrategy { + /// Select stream finishes when both substreams finish + WhenBothFinish, + /// Select stream finishes when either substream finishes + WhenEitherFinish, +} + +impl ExitStrategy { + #[inline] + fn is_finished(self, state: InternalState) -> bool { + match (state, self) { + (InternalState::BothFinished, _) => true, + (InternalState::Start, ExitStrategy::WhenEitherFinish) => false, + (_, ExitStrategy::WhenBothFinish) => false, + _ => true, + } + } +} + pin_project! { /// Stream for the [`select_with_strategy()`] function. See function docs for details. #[must_use = "streams do nothing unless polled"] @@ -73,6 +97,7 @@ pin_project! { internal_state: InternalState, state: State, clos: Clos, + exit_strategy: ExitStrategy, } } @@ -95,7 +120,7 @@ pin_project! { /// /// ```rust /// # futures::executor::block_on(async { -/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt }; +/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt, ExitStrategy }; /// /// let left = repeat(1); /// let right = repeat(2); @@ -106,7 +131,7 @@ pin_project! { /// // use a function pointer instead of a closure. /// fn prio_left(_: &mut ()) -> PollNext { PollNext::Left } /// -/// let mut out = select_with_strategy(left, right, prio_left); +/// let mut out = select_with_strategy(left, right, prio_left, ExitStrategy::WhenBothFinish); /// /// for _ in 0..100 { /// // Whenever we poll out, we will always get `1`. @@ -121,26 +146,54 @@ pin_project! { /// /// ```rust /// # futures::executor::block_on(async { -/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt }; +/// use futures::stream::{ repeat, select_with_strategy, FusedStream, PollNext, StreamExt, ExitStrategy }; /// -/// let left = repeat(1); -/// let right = repeat(2); +/// // Finishes when both streams finish +/// { +/// let left = repeat(1).take(10); +/// let right = repeat(2); /// -/// let rrobin = |last: &mut PollNext| last.toggle(); +/// let rrobin = |last: &mut PollNext| last.toggle(); /// -/// let mut out = select_with_strategy(left, right, rrobin); +/// let mut out = select_with_strategy(left, right, rrobin, ExitStrategy::WhenBothFinish); /// -/// for _ in 0..100 { -/// // We should be alternating now. -/// assert_eq!(1, out.select_next_some().await); -/// assert_eq!(2, out.select_next_some().await); +/// for _ in 0..10 { +/// // We should be alternating now. +/// assert_eq!(1, out.select_next_some().await); +/// assert_eq!(2, out.select_next_some().await); +/// } +/// for _ in 0..100 { +/// // First stream has finished +/// assert_eq!(2, out.select_next_some().await); +/// } +/// assert!(!out.is_terminated()); +/// } +/// +/// // Finishes when either stream finishes +/// { +/// let left = repeat(1).take(10); +/// let right = repeat(2); +/// +/// let rrobin = |last: &mut PollNext| last.toggle(); +/// +/// let mut out = select_with_strategy(left, right, rrobin, ExitStrategy::WhenEitherFinish); +/// +/// for _ in 0..10 { +/// // We should be alternating now. +/// assert_eq!(1, out.select_next_some().await); +/// assert_eq!(2, out.select_next_some().await); +/// } +/// assert_eq!(None, out.next().await); +/// assert!(out.is_terminated()); /// } /// # }); /// ``` +/// pub fn select_with_strategy( stream1: St1, stream2: St2, which: Clos, + exit_strategy: ExitStrategy, ) -> SelectWithStrategy where St1: Stream, @@ -154,6 +207,7 @@ where state: Default::default(), internal_state: InternalState::Start, clos: which, + exit_strategy, }) } @@ -199,10 +253,7 @@ where Clos: FnMut(&mut State) -> PollNext, { fn is_terminated(&self) -> bool { - match self.internal_state { - InternalState::BothFinished => true, - _ => false, - } + self.exit_strategy.is_finished(self.internal_state) } } @@ -227,6 +278,7 @@ fn poll_inner( select: &mut SelectWithStrategyProj<'_, St1, St2, Clos, State>, side: PollNext, cx: &mut Context<'_>, + exit_strat: ExitStrategy, ) -> Poll> where St1: Stream, @@ -236,6 +288,9 @@ where Poll::Ready(Some(item)) => return Poll::Ready(Some(item)), Poll::Ready(None) => { select.internal_state.finish(side); + if exit_strat.is_finished(*select.internal_state) { + return Poll::Ready(None); + } } Poll::Pending => (), }; @@ -259,11 +314,16 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); + let exit_strategy: ExitStrategy = *this.exit_strategy; + + if exit_strategy.is_finished(*this.internal_state) { + return Poll::Ready(None); + } match this.internal_state { InternalState::Start => { let next_side = (this.clos)(this.state); - poll_inner(&mut this, next_side, cx) + poll_inner(&mut this, next_side, cx, exit_strategy) } InternalState::LeftFinished => match this.stream2.poll_next(cx) { Poll::Ready(None) => { diff --git a/futures/tests/stream.rs b/futures/tests/stream.rs index 6781a102d2..0c02179361 100644 --- a/futures/tests/stream.rs +++ b/futures/tests/stream.rs @@ -25,6 +25,22 @@ fn select() { select_and_compare(vec![1, 2], vec![4, 5, 6], vec![1, 4, 2, 5, 6]); } +#[test] +fn select_early_exit() { + fn select_and_compare(a: Vec, b: Vec, expected: Vec) { + let a = stream::iter(a); + let b = stream::iter(b); + let vec = block_on(stream::select_early_exit(a, b).collect::>()); + assert_eq!(vec, expected); + } + + select_and_compare(vec![1, 2, 3], vec![4, 5, 6], vec![1, 4, 2, 5, 3, 6]); + select_and_compare(vec![], vec![4, 5], vec![]); + select_and_compare(vec![4, 5], vec![], vec![4]); + select_and_compare(vec![1, 2, 3], vec![4, 5], vec![1, 4, 2, 5, 3]); + select_and_compare(vec![1, 2], vec![4, 5, 6], vec![1, 4, 2, 5]); +} + #[test] fn flat_map() { block_on(async {