From 9ae19f89327fe639f5c7f281254924e0060741b1 Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Sat, 20 Aug 2022 23:11:24 -0700 Subject: [PATCH] Fix incorrect termination of `select_with_strategy` streams (#2635) --- .../src/stream/select_with_strategy.rs | 11 +++-- futures/tests/stream.rs | 42 +++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/futures-util/src/stream/select_with_strategy.rs b/futures-util/src/stream/select_with_strategy.rs index 7423519df1..224d5f821c 100644 --- a/futures-util/src/stream/select_with_strategy.rs +++ b/futures-util/src/stream/select_with_strategy.rs @@ -231,18 +231,23 @@ where St1: Stream, St2: Stream, { - match poll_side(select, side, cx) { + let first_done = match poll_side(select, side, cx) { Poll::Ready(Some(item)) => return Poll::Ready(Some(item)), Poll::Ready(None) => { select.internal_state.finish(side); + true } - Poll::Pending => (), + Poll::Pending => false, }; let other = side.other(); match poll_side(select, other, cx) { Poll::Ready(None) => { select.internal_state.finish(other); - Poll::Ready(None) + if first_done { + Poll::Ready(None) + } else { + Poll::Pending + } } a => a, } diff --git a/futures/tests/stream.rs b/futures/tests/stream.rs index 71ec654bfb..5cde45833f 100644 --- a/futures/tests/stream.rs +++ b/futures/tests/stream.rs @@ -1,5 +1,9 @@ +use std::cell::Cell; use std::iter; +use std::pin::Pin; +use std::rc::Rc; use std::sync::Arc; +use std::task::Context; use futures::channel::mpsc; use futures::executor::block_on; @@ -9,6 +13,7 @@ use futures::sink::SinkExt; use futures::stream::{self, StreamExt}; use futures::task::Poll; use futures::{ready, FutureExt}; +use futures_core::Stream; use futures_test::task::noop_context; #[test] @@ -419,3 +424,40 @@ fn ready_chunks() { assert_eq!(s.next().await.unwrap(), vec![4]); }); } + +struct SlowStream { + times_should_poll: usize, + times_polled: Rc>, +} +impl Stream for SlowStream { + type Item = usize; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.times_polled.set(self.times_polled.get() + 1); + if self.times_polled.get() % 2 == 0 { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + if self.times_polled.get() >= self.times_should_poll { + return Poll::Ready(None); + } + Poll::Ready(Some(self.times_polled.get())) + } +} + +#[test] +fn select_with_strategy_doesnt_terminate_early() { + for side in [stream::PollNext::Left, stream::PollNext::Right] { + let times_should_poll = 10; + let count = Rc::new(Cell::new(0)); + let b = stream::iter([10, 20]); + + let mut selected = stream::select_with_strategy( + SlowStream { times_should_poll, times_polled: count.clone() }, + b, + |_: &mut ()| side, + ); + block_on(async move { while selected.next().await.is_some() {} }); + assert_eq!(count.get(), times_should_poll + 1); + } +}