Skip to content

Commit

Permalink
Remove Fuses from select, and only poll non-terminated streams
Browse files Browse the repository at this point in the history
  • Loading branch information
414owen committed Mar 21, 2022
1 parent 75de7a4 commit 53651ca
Showing 1 changed file with 96 additions and 35 deletions.
131 changes: 96 additions & 35 deletions futures-util/src/stream/select_with_strategy.rs
@@ -1,5 +1,4 @@
use super::assert_stream;
use crate::stream::{Fuse, StreamExt};
use core::{fmt, pin::Pin};
use futures_core::stream::{FusedStream, Stream};
use futures_core::task::{Context, Poll};
Expand All @@ -14,6 +13,13 @@ pub enum PollNext {
Right,
}

enum PollSide {
/// Poll the first stream.
Left,
/// Poll the second stream.
Right,
}

impl PollNext {
/// Toggle the value and return the old one.
#[must_use]
Expand All @@ -35,14 +41,40 @@ impl Default for PollNext {
}
}

enum InternalState {
Start,
LeftFinished,
RightFinished,
BothFinished,
}

impl InternalState {
fn finish(&mut self, ps: PollSide) {
match (&self, ps) {
(InternalState::Start, PollSide::Left) => {
*self = InternalState::LeftFinished;
}
(InternalState::Start, PollSide::Right) => {
*self = InternalState::RightFinished;
}
(InternalState::LeftFinished, PollSide::Right)
| (InternalState::RightFinished, PollSide::Left) => {
*self = InternalState::BothFinished;
}
_ => {}
}
}
}

pin_project! {
/// Stream for the [`select_with_strategy()`] function. See function docs for details.
#[must_use = "streams do nothing unless polled"]
pub struct SelectWithStrategy<St1, St2, Clos, State> {
#[pin]
stream1: Fuse<St1>,
stream1: St1,
#[pin]
stream2: Fuse<St2>,
stream2: St2,
internal_state: InternalState,
state: State,
clos: Clos,
}
Expand Down Expand Up @@ -121,9 +153,10 @@ where
State: Default,
{
assert_stream::<St1::Item, _>(SelectWithStrategy {
stream1: stream1.fuse(),
stream2: stream2.fuse(),
stream1,
stream2,
state: Default::default(),
internal_state: InternalState::Start,
clos: which,
})
}
Expand All @@ -132,7 +165,7 @@ impl<St1, St2, Clos, State> SelectWithStrategy<St1, St2, Clos, State> {
/// Acquires a reference to the underlying streams that this combinator is
/// pulling from.
pub fn get_ref(&self) -> (&St1, &St2) {
(self.stream1.get_ref(), self.stream2.get_ref())
(&self.stream1, &self.stream2)
}

/// Acquires a mutable reference to the underlying streams that this
Expand All @@ -141,7 +174,7 @@ impl<St1, St2, Clos, State> SelectWithStrategy<St1, St2, Clos, State> {
/// Note that care must be taken to avoid tampering with the state of the
/// stream which may otherwise confuse this combinator.
pub fn get_mut(&mut self) -> (&mut St1, &mut St2) {
(self.stream1.get_mut(), self.stream2.get_mut())
(&mut self.stream1, &mut self.stream2)
}

/// Acquires a pinned mutable reference to the underlying streams that this
Expand All @@ -151,15 +184,15 @@ impl<St1, St2, Clos, State> SelectWithStrategy<St1, St2, Clos, State> {
/// stream which may otherwise confuse this combinator.
pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut St1>, Pin<&mut St2>) {
let this = self.project();
(this.stream1.get_pin_mut(), this.stream2.get_pin_mut())
(this.stream1, this.stream2)
}

/// Consumes this combinator, returning the underlying streams.
///
/// Note that this may discard intermediate state of this combinator, so
/// care should be taken to avoid losing resources when this is called.
pub fn into_inner(self) -> (St1, St2) {
(self.stream1.into_inner(), self.stream2.into_inner())
(self.stream1, self.stream2)
}
}

Expand All @@ -170,7 +203,10 @@ where
Clos: FnMut(&mut State) -> PollNext,
{
fn is_terminated(&self) -> bool {
self.stream1.is_terminated() && self.stream2.is_terminated()
match self.internal_state {
InternalState::BothFinished => true,
_ => false,
}
}
}

Expand All @@ -185,35 +221,60 @@ where
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<St1::Item>> {
let this = self.project();

match (this.clos)(this.state) {
PollNext::Left => poll_inner(this.stream1, this.stream2, cx),
PollNext::Right => poll_inner(this.stream2, this.stream1, cx),
match this.internal_state {
InternalState::Start => match (this.clos)(this.state) {
PollNext::Left => {
match this.stream1.poll_next(cx) {
Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
Poll::Ready(None) => {
this.internal_state.finish(PollSide::Left);
}
Poll::Pending => (),
};
match this.stream2.poll_next(cx) {
Poll::Ready(None) => {
this.internal_state.finish(PollSide::Right);
Poll::Ready(None)
}
a => a,
}
}
PollNext::Right => {
match this.stream2.poll_next(cx) {
Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
Poll::Ready(None) => {
this.internal_state.finish(PollSide::Right);
}
Poll::Pending => (),
};
match this.stream1.poll_next(cx) {
Poll::Ready(None) => {
this.internal_state.finish(PollSide::Left);
Poll::Ready(None)
}
a => a,
}
}
},
InternalState::LeftFinished => match this.stream2.poll_next(cx) {
Poll::Ready(None) => {
*this.internal_state = InternalState::BothFinished;
Poll::Ready(None)
}
a => a,
},
InternalState::RightFinished => match this.stream1.poll_next(cx) {
Poll::Ready(None) => {
*this.internal_state = InternalState::BothFinished;
Poll::Ready(None)
}
a => a,
},
InternalState::BothFinished => Poll::Ready(None),
}
}
}

fn poll_inner<St1, St2>(
a: Pin<&mut St1>,
b: Pin<&mut St2>,
cx: &mut Context<'_>,
) -> Poll<Option<St1::Item>>
where
St1: Stream,
St2: Stream<Item = St1::Item>,
{
let a_done = match a.poll_next(cx) {
Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
Poll::Ready(None) => true,
Poll::Pending => false,
};

match b.poll_next(cx) {
Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
Poll::Ready(None) if a_done => Poll::Ready(None),
Poll::Ready(None) | Poll::Pending => Poll::Pending,
}
}

impl<St1, St2, Clos, State> fmt::Debug for SelectWithStrategy<St1, St2, Clos, State>
where
St1: fmt::Debug,
Expand Down

0 comments on commit 53651ca

Please sign in to comment.