diff --git a/futures-util/Cargo.toml b/futures-util/Cargo.toml index 46ec854b0..e32b642aa 100644 --- a/futures-util/Cargo.toml +++ b/futures-util/Cargo.toml @@ -45,7 +45,7 @@ memchr = { version = "2.2", optional = true } futures_01 = { version = "0.1.25", optional = true, package = "futures" } tokio-io = { version = "0.1.9", optional = true } pin-utils = "0.1.0" -pin-project-lite = "0.2.4" +pin-project-lite = "0.2.6" [dev-dependencies] futures = { path = "../futures", features = ["async-await", "thread-pool"] } diff --git a/futures-util/src/stream/select_with_strategy.rs b/futures-util/src/stream/select_with_strategy.rs index bd86990cd..7423519df 100644 --- a/futures-util/src/stream/select_with_strategy.rs +++ b/futures-util/src/stream/select_with_strategy.rs @@ -1,5 +1,4 @@ use super::assert_stream; -use crate::stream::{Fuse, StreamExt}; use core::{fmt, pin::Pin}; use futures_core::stream::{FusedStream, Stream}; use futures_core::task::{Context, Poll}; @@ -18,13 +17,15 @@ impl PollNext { /// Toggle the value and return the old one. pub fn toggle(&mut self) -> Self { let old = *self; + *self = self.other(); + old + } + fn other(&self) -> PollNext { match self { - PollNext::Left => *self = PollNext::Right, - PollNext::Right => *self = PollNext::Left, + PollNext::Left => PollNext::Right, + PollNext::Right => PollNext::Left, } - - old } } @@ -34,14 +35,41 @@ impl Default for PollNext { } } +enum InternalState { + Start, + LeftFinished, + RightFinished, + BothFinished, +} + +impl InternalState { + fn finish(&mut self, ps: PollNext) { + match (&self, ps) { + (InternalState::Start, PollNext::Left) => { + *self = InternalState::LeftFinished; + } + (InternalState::Start, PollNext::Right) => { + *self = InternalState::RightFinished; + } + (InternalState::LeftFinished, PollNext::Right) + | (InternalState::RightFinished, PollNext::Left) => { + *self = InternalState::BothFinished; + } + _ => {} + } + } +} + pin_project! { /// Stream for the [`select_with_strategy()`] function. See function docs for details. #[must_use = "streams do nothing unless polled"] + #[project = SelectWithStrategyProj] pub struct SelectWithStrategy { #[pin] - stream1: Fuse, + stream1: St1, #[pin] - stream2: Fuse, + stream2: St2, + internal_state: InternalState, state: State, clos: Clos, } @@ -120,9 +148,10 @@ where State: Default, { assert_stream::(SelectWithStrategy { - stream1: stream1.fuse(), - stream2: stream2.fuse(), + stream1, + stream2, state: Default::default(), + internal_state: InternalState::Start, clos: which, }) } @@ -131,7 +160,7 @@ impl SelectWithStrategy { /// Acquires a reference to the underlying streams that this combinator is /// pulling from. pub fn get_ref(&self) -> (&St1, &St2) { - (self.stream1.get_ref(), self.stream2.get_ref()) + (&self.stream1, &self.stream2) } /// Acquires a mutable reference to the underlying streams that this @@ -140,7 +169,7 @@ impl SelectWithStrategy { /// Note that care must be taken to avoid tampering with the state of the /// stream which may otherwise confuse this combinator. pub fn get_mut(&mut self) -> (&mut St1, &mut St2) { - (self.stream1.get_mut(), self.stream2.get_mut()) + (&mut self.stream1, &mut self.stream2) } /// Acquires a pinned mutable reference to the underlying streams that this @@ -150,7 +179,7 @@ impl SelectWithStrategy { /// stream which may otherwise confuse this combinator. pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut St1>, Pin<&mut St2>) { let this = self.project(); - (this.stream1.get_pin_mut(), this.stream2.get_pin_mut()) + (this.stream1, this.stream2) } /// Consumes this combinator, returning the underlying streams. @@ -158,7 +187,7 @@ impl SelectWithStrategy { /// Note that this may discard intermediate state of this combinator, so /// care should be taken to avoid losing resources when this is called. pub fn into_inner(self) -> (St1, St2) { - (self.stream1.into_inner(), self.stream2.into_inner()) + (self.stream1, self.stream2) } } @@ -169,47 +198,88 @@ where Clos: FnMut(&mut State) -> PollNext, { fn is_terminated(&self) -> bool { - self.stream1.is_terminated() && self.stream2.is_terminated() + match self.internal_state { + InternalState::BothFinished => true, + _ => false, + } } } -impl Stream for SelectWithStrategy +#[inline] +fn poll_side( + select: &mut SelectWithStrategyProj<'_, St1, St2, Clos, State>, + side: PollNext, + cx: &mut Context<'_>, +) -> Poll> where St1: Stream, St2: Stream, - Clos: FnMut(&mut State) -> PollNext, { - type Item = St1::Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - match (this.clos)(this.state) { - PollNext::Left => poll_inner(this.stream1, this.stream2, cx), - PollNext::Right => poll_inner(this.stream2, this.stream1, cx), - } + match side { + PollNext::Left => select.stream1.as_mut().poll_next(cx), + PollNext::Right => select.stream2.as_mut().poll_next(cx), } } -fn poll_inner( - a: Pin<&mut St1>, - b: Pin<&mut St2>, +#[inline] +fn poll_inner( + select: &mut SelectWithStrategyProj<'_, St1, St2, Clos, State>, + side: PollNext, cx: &mut Context<'_>, ) -> Poll> where St1: Stream, St2: Stream, { - let a_done = match a.poll_next(cx) { + match poll_side(select, side, cx) { Poll::Ready(Some(item)) => return Poll::Ready(Some(item)), - Poll::Ready(None) => true, - Poll::Pending => false, + Poll::Ready(None) => { + select.internal_state.finish(side); + } + Poll::Pending => (), }; + let other = side.other(); + match poll_side(select, other, cx) { + Poll::Ready(None) => { + select.internal_state.finish(other); + Poll::Ready(None) + } + a => a, + } +} + +impl Stream for SelectWithStrategy +where + St1: Stream, + St2: Stream, + Clos: FnMut(&mut State) -> PollNext, +{ + type Item = St1::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); - match b.poll_next(cx) { - Poll::Ready(Some(item)) => Poll::Ready(Some(item)), - Poll::Ready(None) if a_done => Poll::Ready(None), - Poll::Ready(None) | Poll::Pending => Poll::Pending, + match this.internal_state { + InternalState::Start => { + let next_side = (this.clos)(this.state); + poll_inner(&mut this, next_side, cx) + } + InternalState::LeftFinished => match this.stream2.poll_next(cx) { + Poll::Ready(None) => { + *this.internal_state = InternalState::BothFinished; + Poll::Ready(None) + } + a => a, + }, + InternalState::RightFinished => match this.stream1.poll_next(cx) { + Poll::Ready(None) => { + *this.internal_state = InternalState::BothFinished; + Poll::Ready(None) + } + a => a, + }, + InternalState::BothFinished => Poll::Ready(None), + } } }