diff --git a/futures-util/src/stream/mod.rs b/futures-util/src/stream/mod.rs index 517e9519b8..f0db571413 100644 --- a/futures-util/src/stream/mod.rs +++ b/futures-util/src/stream/mod.rs @@ -92,6 +92,9 @@ pub use self::poll_fn::{poll_fn, PollFn}; mod select; pub use self::select::{select, Select}; +mod select_with_strategy; +pub use self::select_with_strategy::{select_with_strategy, PollNext, SelectWithStrategy}; + mod unfold; pub use self::unfold::{unfold, Unfold}; diff --git a/futures-util/src/stream/select.rs b/futures-util/src/stream/select.rs index 133ac6c7ac..0c1e3af782 100644 --- a/futures-util/src/stream/select.rs +++ b/futures-util/src/stream/select.rs @@ -1,5 +1,5 @@ use super::assert_stream; -use crate::stream::{Fuse, StreamExt}; +use crate::stream::{select_with_strategy, PollNext, SelectWithStrategy}; use core::pin::Pin; use futures_core::stream::{FusedStream, Stream}; use futures_core::task::{Context, Poll}; @@ -11,10 +11,7 @@ pin_project! { #[must_use = "streams do nothing unless polled"] pub struct Select { #[pin] - stream1: Fuse, - #[pin] - stream2: Fuse, - flag: bool, + inner: SelectWithStrategy PollNext, PollNext>, } } @@ -22,21 +19,42 @@ pin_project! { /// stream will be polled in a round-robin fashion, and whenever a stream is /// ready to yield an item that item is yielded. /// -/// After one of the two input stream completes, the remaining one will be +/// After one of the two input streams completes, the remaining one will be /// polled exclusively. The returned stream completes when both input /// streams have completed. /// /// Note that this function consumes both streams and returns a wrapped /// version of them. +/// +/// ## Examples +/// +/// ```rust +/// # futures::executor::block_on(async { +/// use futures::stream::{ repeat, select, StreamExt }; +/// +/// let left = repeat(1); +/// let right = repeat(2); +/// +/// let mut out = select(left, right); +/// +/// for _ in 0..100 { +/// // We should be alternating. +/// assert_eq!(1, out.select_next_some().await); +/// assert_eq!(2, out.select_next_some().await); +/// } +/// # }); +/// ``` pub fn select(stream1: St1, stream2: St2) -> Select where St1: Stream, St2: Stream, { + fn round_robin(last: &mut PollNext) -> PollNext { + last.toggle() + } + assert_stream::(Select { - stream1: stream1.fuse(), - stream2: stream2.fuse(), - flag: false, + inner: select_with_strategy(stream1, stream2, round_robin), }) } @@ -44,7 +62,7 @@ impl Select { /// Acquires a reference to the underlying streams that this combinator is /// pulling from. pub fn get_ref(&self) -> (&St1, &St2) { - (self.stream1.get_ref(), self.stream2.get_ref()) + self.inner.get_ref() } /// Acquires a mutable reference to the underlying streams that this @@ -53,7 +71,7 @@ impl Select { /// Note that care must be taken to avoid tampering with the state of the /// stream which may otherwise confuse this combinator. pub fn get_mut(&mut self) -> (&mut St1, &mut St2) { - (self.stream1.get_mut(), self.stream2.get_mut()) + self.inner.get_mut() } /// Acquires a pinned mutable reference to the underlying streams that this @@ -63,7 +81,7 @@ impl Select { /// stream which may otherwise confuse this combinator. pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut St1>, Pin<&mut St2>) { let this = self.project(); - (this.stream1.get_pin_mut(), this.stream2.get_pin_mut()) + this.inner.get_pin_mut() } /// Consumes this combinator, returning the underlying streams. @@ -71,7 +89,7 @@ impl Select { /// Note that this may discard intermediate state of this combinator, so /// care should be taken to avoid losing resources when this is called. pub fn into_inner(self) -> (St1, St2) { - (self.stream1.into_inner(), self.stream2.into_inner()) + self.inner.into_inner() } } @@ -81,7 +99,7 @@ where St2: Stream, { fn is_terminated(&self) -> bool { - self.stream1.is_terminated() && self.stream2.is_terminated() + self.inner.is_terminated() } } @@ -94,37 +112,6 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); - if !*this.flag { - poll_inner(this.flag, this.stream1, this.stream2, cx) - } else { - poll_inner(this.flag, this.stream2, this.stream1, cx) - } - } -} - -fn poll_inner( - flag: &mut bool, - a: Pin<&mut St1>, - b: Pin<&mut St2>, - cx: &mut Context<'_>, -) -> Poll> -where - St1: Stream, - St2: Stream, -{ - let a_done = match a.poll_next(cx) { - Poll::Ready(Some(item)) => { - // give the other stream a chance to go first next time - *flag = !*flag; - return Poll::Ready(Some(item)); - } - Poll::Ready(None) => true, - Poll::Pending => false, - }; - - match b.poll_next(cx) { - Poll::Ready(Some(item)) => Poll::Ready(Some(item)), - Poll::Ready(None) if a_done => Poll::Ready(None), - Poll::Ready(None) | Poll::Pending => Poll::Pending, + this.inner.poll_next(cx) } } diff --git a/futures-util/src/stream/select_with_strategy.rs b/futures-util/src/stream/select_with_strategy.rs new file mode 100644 index 0000000000..1b501132c5 --- /dev/null +++ b/futures-util/src/stream/select_with_strategy.rs @@ -0,0 +1,230 @@ +use super::assert_stream; +use crate::stream::{Fuse, StreamExt}; +use core::{fmt, pin::Pin}; +use futures_core::stream::{FusedStream, Stream}; +use futures_core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +/// Type to tell [`SelectWithStrategy`] which stream to poll next. +#[derive(Debug, PartialEq, Eq, Copy, Clone, Hash)] +pub enum PollNext { + /// Poll the first stream. + Left, + /// Poll the second stream. + Right, +} + +impl PollNext { + /// Toggle the value and return the old one. + pub fn toggle(&mut self) -> Self { + let old = *self; + + match self { + PollNext::Left => *self = PollNext::Right, + PollNext::Right => *self = PollNext::Left, + } + + old + } +} + +impl Default for PollNext { + fn default() -> Self { + PollNext::Left + } +} + +pin_project! { + /// Stream for the [`select_with_strategy()`] function. See function docs for details. + #[must_use = "streams do nothing unless polled"] + pub struct SelectWithStrategy { + #[pin] + stream1: Fuse, + #[pin] + stream2: Fuse, + state: State, + clos: Clos, + } +} + +/// This function will attempt to pull items from both streams. You provide a +/// closure to tell [`SelectWithStrategy`] which stream to poll. The closure can +/// store state on `SelectWithStrategy` to which it will receive a `&mut` on every +/// invocation. This allows basing the strategy on prior choices. +/// +/// After one of the two input streams completes, the remaining one will be +/// polled exclusively. The returned stream completes when both input +/// streams have completed. +/// +/// Note that this function consumes both streams and returns a wrapped +/// version of them. +/// +/// ## Examples +/// +/// ### Priority +/// This example shows how to always prioritize the left stream. +/// +/// ```rust +/// # futures::executor::block_on(async { +/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt }; +/// +/// let left = repeat(1); +/// let right = repeat(2); +/// +/// // We don't need any state, so let's make it an empty tuple. +/// // We must provide some type here, as there is no way for the compiler +/// // to infer it. As we don't need to capture variables, we can just +/// // use a function pointer instead of a closure. +/// fn prio_left(_: &mut ()) -> PollNext { PollNext::Left } +/// +/// let mut out = select_with_strategy(left, right, prio_left); +/// +/// for _ in 0..100 { +/// // Whenever we poll out, we will alwas get `1`. +/// assert_eq!(1, out.select_next_some().await); +/// } +/// # }); +/// ``` +/// +/// ### Round Robin +/// This example shows how to select from both streams round robin. +/// Note: this special case is provided by [`futures-util::stream::select`]. +/// +/// ```rust +/// # futures::executor::block_on(async { +/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt }; +/// +/// let left = repeat(1); +/// let right = repeat(2); +/// +/// // We don't need any state, so let's make it an empty tuple. +/// let rrobin = |last: &mut PollNext| last.toggle(); +/// +/// let mut out = select_with_strategy(left, right, rrobin); +/// +/// for _ in 0..100 { +/// // We should be alternating now. +/// assert_eq!(1, out.select_next_some().await); +/// assert_eq!(2, out.select_next_some().await); +/// } +/// # }); +/// ``` +pub fn select_with_strategy( + stream1: St1, + stream2: St2, + which: Clos, +) -> SelectWithStrategy +where + St1: Stream, + St2: Stream, + Clos: FnMut(&mut State) -> PollNext, + State: Default, +{ + assert_stream::(SelectWithStrategy { + stream1: stream1.fuse(), + stream2: stream2.fuse(), + state: Default::default(), + clos: which, + }) +} + +impl SelectWithStrategy { + /// Acquires a reference to the underlying streams that this combinator is + /// pulling from. + pub fn get_ref(&self) -> (&St1, &St2) { + (self.stream1.get_ref(), self.stream2.get_ref()) + } + + /// Acquires a mutable reference to the underlying streams that this + /// combinator is pulling from. + /// + /// Note that care must be taken to avoid tampering with the state of the + /// stream which may otherwise confuse this combinator. + pub fn get_mut(&mut self) -> (&mut St1, &mut St2) { + (self.stream1.get_mut(), self.stream2.get_mut()) + } + + /// Acquires a pinned mutable reference to the underlying streams that this + /// combinator is pulling from. + /// + /// Note that care must be taken to avoid tampering with the state of the + /// stream which may otherwise confuse this combinator. + pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut St1>, Pin<&mut St2>) { + let this = self.project(); + (this.stream1.get_pin_mut(), this.stream2.get_pin_mut()) + } + + /// Consumes this combinator, returning the underlying streams. + /// + /// Note that this may discard intermediate state of this combinator, so + /// care should be taken to avoid losing resources when this is called. + pub fn into_inner(self) -> (St1, St2) { + (self.stream1.into_inner(), self.stream2.into_inner()) + } +} + +impl FusedStream for SelectWithStrategy +where + St1: Stream, + St2: Stream, + Clos: FnMut(&mut State) -> PollNext, +{ + fn is_terminated(&self) -> bool { + self.stream1.is_terminated() && self.stream2.is_terminated() + } +} + +impl Stream for SelectWithStrategy +where + St1: Stream, + St2: Stream, + Clos: FnMut(&mut State) -> PollNext, +{ + type Item = St1::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match (this.clos)(this.state) { + PollNext::Left => poll_inner(this.stream1, this.stream2, cx), + PollNext::Right => poll_inner(this.stream2, this.stream1, cx), + } + } +} + +fn poll_inner( + a: Pin<&mut St1>, + b: Pin<&mut St2>, + cx: &mut Context<'_>, +) -> Poll> +where + St1: Stream, + St2: Stream, +{ + let a_done = match a.poll_next(cx) { + Poll::Ready(Some(item)) => return Poll::Ready(Some(item)), + Poll::Ready(None) => true, + Poll::Pending => false, + }; + + match b.poll_next(cx) { + Poll::Ready(Some(item)) => Poll::Ready(Some(item)), + Poll::Ready(None) if a_done => Poll::Ready(None), + Poll::Ready(None) | Poll::Pending => Poll::Pending, + } +} + +impl fmt::Debug for SelectWithStrategy +where + St1: fmt::Debug, + St2: fmt::Debug, + State: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SelectWithStrategy") + .field("stream1", &self.stream1) + .field("stream2", &self.stream2) + .field("state", &self.state) + .finish() + } +}