diff --git a/futures-util/src/stream/select_with_strategy.rs b/futures-util/src/stream/select_with_strategy.rs index 6ccb321aaf..97417d3858 100644 --- a/futures-util/src/stream/select_with_strategy.rs +++ b/futures-util/src/stream/select_with_strategy.rs @@ -1,5 +1,4 @@ use super::assert_stream; -use crate::stream::{Fuse, StreamExt}; use core::{fmt, pin::Pin}; use futures_core::stream::{FusedStream, Stream}; use futures_core::task::{Context, Poll}; @@ -14,6 +13,13 @@ pub enum PollNext { Right, } +enum PollSide { + /// Poll the first stream. + Left, + /// Poll the second stream. + Right, +} + impl PollNext { /// Toggle the value and return the old one. #[must_use] @@ -35,14 +41,40 @@ impl Default for PollNext { } } +enum InternalState { + Start, + LeftFinished, + RightFinished, + BothFinished, +} + +impl InternalState { + fn finish(&mut self, ps: PollSide) { + match (&self, ps) { + (InternalState::Start, PollSide::Left) => { + *self = InternalState::LeftFinished; + } + (InternalState::Start, PollSide::Right) => { + *self = InternalState::RightFinished; + } + (InternalState::LeftFinished, PollSide::Right) + | (InternalState::RightFinished, PollSide::Left) => { + *self = InternalState::BothFinished; + } + _ => {} + } + } +} + pin_project! { /// Stream for the [`select_with_strategy()`] function. See function docs for details. #[must_use = "streams do nothing unless polled"] pub struct SelectWithStrategy { #[pin] - stream1: Fuse, + stream1: St1, #[pin] - stream2: Fuse, + stream2: St2, + internal_state: InternalState, state: State, clos: Clos, } @@ -121,9 +153,10 @@ where State: Default, { assert_stream::(SelectWithStrategy { - stream1: stream1.fuse(), - stream2: stream2.fuse(), + stream1, + stream2, state: Default::default(), + internal_state: InternalState::Start, clos: which, }) } @@ -132,7 +165,7 @@ impl SelectWithStrategy { /// Acquires a reference to the underlying streams that this combinator is /// pulling from. pub fn get_ref(&self) -> (&St1, &St2) { - (self.stream1.get_ref(), self.stream2.get_ref()) + (&self.stream1, &self.stream2) } /// Acquires a mutable reference to the underlying streams that this @@ -141,7 +174,7 @@ impl SelectWithStrategy { /// Note that care must be taken to avoid tampering with the state of the /// stream which may otherwise confuse this combinator. pub fn get_mut(&mut self) -> (&mut St1, &mut St2) { - (self.stream1.get_mut(), self.stream2.get_mut()) + (&mut self.stream1, &mut self.stream2) } /// Acquires a pinned mutable reference to the underlying streams that this @@ -151,7 +184,7 @@ impl SelectWithStrategy { /// stream which may otherwise confuse this combinator. pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut St1>, Pin<&mut St2>) { let this = self.project(); - (this.stream1.get_pin_mut(), this.stream2.get_pin_mut()) + (this.stream1, this.stream2) } /// Consumes this combinator, returning the underlying streams. @@ -159,7 +192,7 @@ impl SelectWithStrategy { /// Note that this may discard intermediate state of this combinator, so /// care should be taken to avoid losing resources when this is called. pub fn into_inner(self) -> (St1, St2) { - (self.stream1.into_inner(), self.stream2.into_inner()) + (self.stream1, self.stream2) } } @@ -170,7 +203,10 @@ where Clos: FnMut(&mut State) -> PollNext, { fn is_terminated(&self) -> bool { - self.stream1.is_terminated() && self.stream2.is_terminated() + match self.internal_state { + InternalState::BothFinished => true, + _ => false, + } } } @@ -185,35 +221,60 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); - match (this.clos)(this.state) { - PollNext::Left => poll_inner(this.stream1, this.stream2, cx), - PollNext::Right => poll_inner(this.stream2, this.stream1, cx), + match this.internal_state { + InternalState::Start => match (this.clos)(this.state) { + PollNext::Left => { + match this.stream1.poll_next(cx) { + Poll::Ready(Some(item)) => return Poll::Ready(Some(item)), + Poll::Ready(None) => { + this.internal_state.finish(PollSide::Left); + } + Poll::Pending => (), + }; + match this.stream2.poll_next(cx) { + Poll::Ready(None) => { + this.internal_state.finish(PollSide::Right); + Poll::Ready(None) + } + a => a, + } + } + PollNext::Right => { + match this.stream2.poll_next(cx) { + Poll::Ready(Some(item)) => return Poll::Ready(Some(item)), + Poll::Ready(None) => { + this.internal_state.finish(PollSide::Right); + } + Poll::Pending => (), + }; + match this.stream1.poll_next(cx) { + Poll::Ready(None) => { + this.internal_state.finish(PollSide::Left); + Poll::Ready(None) + } + a => a, + } + } + }, + InternalState::LeftFinished => match this.stream2.poll_next(cx) { + Poll::Ready(None) => { + *this.internal_state = InternalState::BothFinished; + Poll::Ready(None) + } + a => a, + }, + InternalState::RightFinished => match this.stream1.poll_next(cx) { + Poll::Ready(None) => { + *this.internal_state = InternalState::BothFinished; + Poll::Ready(None) + } + a => a, + }, + InternalState::BothFinished => Poll::Ready(None), } } } -fn poll_inner( - a: Pin<&mut St1>, - b: Pin<&mut St2>, - cx: &mut Context<'_>, -) -> Poll> -where - St1: Stream, - St2: Stream, -{ - let a_done = match a.poll_next(cx) { - Poll::Ready(Some(item)) => return Poll::Ready(Some(item)), - Poll::Ready(None) => true, - Poll::Pending => false, - }; - - match b.poll_next(cx) { - Poll::Ready(Some(item)) => Poll::Ready(Some(item)), - Poll::Ready(None) if a_done => Poll::Ready(None), - Poll::Ready(None) | Poll::Pending => Poll::Pending, - } -} - impl fmt::Debug for SelectWithStrategy where St1: fmt::Debug,