diff --git a/futures-util/src/stream/stream/flatten_unordered.rs b/futures-util/src/stream/stream/flatten_unordered.rs index 7090b6cb6f..ec34a9cd2c 100644 --- a/futures-util/src/stream/stream/flatten_unordered.rs +++ b/futures-util/src/stream/stream/flatten_unordered.rs @@ -50,28 +50,47 @@ struct SharedPollState { } impl SharedPollState { - /// Constructs new `SharedPollState` with given state. - fn new(state: u8) -> SharedPollState { - SharedPollState { state: Arc::new(AtomicU8::new(state)) } + /// Constructs new `SharedPollState` with the given state. + fn new(value: u8) -> SharedPollState { + SharedPollState { state: Arc::new(AtomicU8::new(value)) } } /// Attempts to start polling, returning stored state in case of success. /// Returns `None` if state some waker is waking at the moment. - fn start_polling(&self) -> Option { + fn start_polling(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState)>)> { self.state - .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| { - if state & WAKING_ANYTHING == NONE { + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| { + if value & WAKING_ANYTHING == NONE { Some(POLLING) } else { None } }) .ok() + .map(|value| { + ( + value, + PollStateBomb::new(self, move |state| { + state.stop_polling(NEED_TO_POLL_ALL); + }), + ) + }) } /// Starts the waking process and performs bitwise or with the given value. - fn start_waking(&self, to_poll: u8, waking: u8) -> u8 { - self.state.fetch_or(to_poll | waking, Ordering::SeqCst) + fn start_waking( + &self, + to_poll: u8, + waking: u8, + ) -> (u8, PollStateBomb<'_, impl FnOnce(&SharedPollState)>) { + let value = self.state.fetch_or(to_poll | waking, Ordering::SeqCst); + + ( + value, + PollStateBomb::new(self, move |state| { + state.stop_waking(waking); + }), + ) } /// Toggles state to non-waking, allowing to start polling. @@ -82,8 +101,8 @@ impl SharedPollState { /// Sets current state to `!POLLING`, allowing to use wakers. fn stop_polling(&self, to_poll: u8) -> u8 { self.state - .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| { - Some((state | to_poll) & !POLLING) + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| { + Some((value | to_poll) & !POLLING) }) .unwrap() } @@ -98,7 +117,6 @@ struct InnerWaker { } unsafe impl Send for InnerWaker {} - unsafe impl Sync for InnerWaker {} impl InnerWaker { @@ -115,20 +133,39 @@ impl InnerWaker { waker(self_arc.clone()) } - // Flags state that walking is started for the walker with the given value. - fn start_waking(&self) -> u8 { + // Flags state that waking is started for the waker with the given value. + fn start_waking(&self) -> (u8, PollStateBomb<'_, impl FnOnce(&SharedPollState)>) { self.poll_state.start_waking(self.need_to_poll, self.need_to_poll << 3) } +} + +/// +struct PollStateBomb<'a, F: FnOnce(&SharedPollState)> { + state: &'a SharedPollState, + drop: Option, +} + +impl<'a, F: FnOnce(&SharedPollState)> PollStateBomb<'a, F> { + fn new(state: &'a SharedPollState, drop: F) -> Self { + Self { state, drop: Some(drop) } + } + + fn deactivate(mut self) { + self.drop.take(); + } +} - // Flags state that walking is finished for the walker with the given value. - fn stop_waking(&self) -> u8 { - self.poll_state.stop_waking(self.need_to_poll << 3) +impl Drop for PollStateBomb<'_, F> { + fn drop(&mut self) { + if let Some(drop) = self.drop.take() { + (drop)(&self.state); + } } } impl ArcWake for InnerWaker { fn wake_by_ref(self_arc: &Arc) { - let poll_state_value = self_arc.start_waking(); + let (poll_state_value, state_bomb) = self_arc.start_waking(); // Only call waker if stream isn't being polled because of safety reasons. // Waker will be called at the end of polling if state was changed. @@ -137,14 +174,11 @@ impl ArcWake for InnerWaker { unsafe { self_arc.inner_waker.get().as_ref().cloned().flatten() } { // First, stop waking to allow polling stream - self_arc.stop_waking(); + drop(state_bomb); // Wake inner waker inner_waker.wake(); - return; } } - - self_arc.stop_waking(); } } @@ -168,23 +202,19 @@ impl PollStreamFut { } } -impl Future for PollStreamFut { +impl Future for PollStreamFut { type Output = Option<(St::Item, PollStreamFut)>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut stream = self.project().stream; - let item = if let Some(stream) = stream.as_mut().as_pin_mut() { ready!(stream.poll_next(cx)) } else { None }; + let out = item.map(|item| (item, PollStreamFut::new(stream.get_mut().take()))); - Poll::Ready( - item.map(|item| { - (item, PollStreamFut::new(unsafe { stream.get_unchecked_mut().take() })) - }), - ) + Poll::Ready(out) } } @@ -192,6 +222,7 @@ pin_project! { /// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered) /// method. #[must_use = "streams do nothing unless polled"] + #[project = FlattenUnorderedProj] pub struct FlattenUnordered { #[pin] inner_streams: FuturesUnordered>, @@ -224,7 +255,7 @@ where impl FlattenUnordered where St: Stream, - St::Item: Stream, + St::Item: Stream + Unpin, { pub(super) fn new(stream: St, limit: Option) -> FlattenUnordered { let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM); @@ -248,18 +279,24 @@ where } } + delegate_access_inner!(stream, St, ()); +} + +impl FlattenUnorderedProj<'_, St, St::Item> +where + St: Stream, +{ /// Checks if current `inner_streams` size is less than optional limit. fn is_exceeded_limit(&self) -> bool { - self.limit.map(|limit| self.inner_streams.len() >= limit.get()).unwrap_or(false) + self.limit.map_or(false, |limit| self.inner_streams.len() >= limit.get()) } - - delegate_access_inner!(stream, St, ()); } impl FusedStream for FlattenUnordered where St: FusedStream, - St::Item: FusedStream, + St::Item: FusedStream + Unpin, + ::Item: core::fmt::Debug, { fn is_terminated(&self) -> bool { self.stream.is_terminated() && self.inner_streams.is_empty() @@ -269,18 +306,18 @@ where impl Stream for FlattenUnordered where St: Stream, - St::Item: Stream, + St::Item: Stream + Unpin, + ::Item: core::fmt::Debug, { type Item = ::Item; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut next_item = None; let mut need_to_poll_next = NONE; - let limit_exceeded = self.is_exceeded_limit(); let mut this = self.as_mut().project(); - let mut poll_state_value = match this.poll_state.start_polling() { + let (mut poll_state_value, state_bomb) = match this.poll_state.start_polling() { Some(value) => value, _ => { // Waker was called, just wait for the next poll @@ -289,44 +326,48 @@ where }; let mut polling_with_two_wakers = - !limit_exceeded && poll_state_value & NEED_TO_POLL_ALL == NEED_TO_POLL_ALL; + !this.is_exceeded_limit() && poll_state_value & NEED_TO_POLL_ALL == NEED_TO_POLL_ALL; if poll_state_value & NEED_TO_POLL_STREAM != NONE { - if !limit_exceeded && !*this.is_stream_done { - match if polling_with_two_wakers { - // Safety: now state is `POLLING`. - let waker = unsafe { InnerWaker::replace_waker(this.stream_waker, cx) }; - let mut cx = Context::from_waker(&waker); - this.stream.as_mut().poll_next(&mut cx) + loop { + if this.is_exceeded_limit() || *this.is_stream_done { + polling_with_two_wakers = false; + need_to_poll_next |= NEED_TO_POLL_STREAM; + + break; } else { - this.stream.as_mut().poll_next(cx) - } { - Poll::Ready(Some(inner_stream)) => { - this.inner_streams.as_mut().push(PollStreamFut::new(inner_stream)); - need_to_poll_next |= NEED_TO_POLL_STREAM; - // Polling inner streams in current iteration with the same context - // is ok because we already received `Poll::Ready` from - // stream - poll_state_value |= NEED_TO_POLL_INNER_STREAMS; - polling_with_two_wakers = false; - *this.is_stream_done = false; - } - Poll::Ready(None) => { - // Polling inner streams in current iteration with the same context - // is ok because we already received `Poll::Ready` from - // stream - polling_with_two_wakers = false; - *this.is_stream_done = true; - } - Poll::Pending => { - if !polling_with_two_wakers { + match if polling_with_two_wakers { + // Safety: now state is `POLLING`. + let waker = unsafe { InnerWaker::replace_waker(this.stream_waker, cx) }; + let mut cx = Context::from_waker(&waker); + this.stream.as_mut().poll_next(&mut cx) + } else { + this.stream.as_mut().poll_next(cx) + } { + Poll::Ready(Some(inner_stream)) => { + this.inner_streams.as_mut().push(PollStreamFut::new(inner_stream)); need_to_poll_next |= NEED_TO_POLL_STREAM; + // Polling inner streams in current iteration with the same context + // is ok because we already received `Poll::Ready` from + // stream + poll_state_value |= NEED_TO_POLL_INNER_STREAMS; + *this.is_stream_done = false; + } + Poll::Ready(None) => { + // Polling inner streams in current iteration with the same context + // is ok because we already received `Poll::Ready` from + // stream + *this.is_stream_done = true; + } + Poll::Pending => { + if !polling_with_two_wakers { + need_to_poll_next |= NEED_TO_POLL_STREAM; + } + *this.is_stream_done = false; + break; } - *this.is_stream_done = false; } } - } else { - need_to_poll_next |= NEED_TO_POLL_STREAM; } } @@ -345,7 +386,7 @@ where need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS; } Poll::Ready(Some(None)) => { - need_to_poll_next |= NEED_TO_POLL_ALL; + need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS; } Poll::Pending => { if !polling_with_two_wakers { @@ -358,6 +399,7 @@ where } } + state_bomb.deactivate(); poll_state_value = this.poll_state.stop_polling(need_to_poll_next); let is_done = *this.is_stream_done && this.inner_streams.is_empty(); @@ -365,7 +407,7 @@ where Poll::Ready(next_item) } else { if poll_state_value & NEED_TO_POLL_ALL != NONE - || !self.is_exceeded_limit() && need_to_poll_next & NEED_TO_POLL_STREAM != NONE + || !this.is_exceeded_limit() && need_to_poll_next & NEED_TO_POLL_STREAM != NONE { cx.waker().wake_by_ref(); } diff --git a/futures-util/src/stream/stream/mod.rs b/futures-util/src/stream/stream/mod.rs index 5872c60447..f70e1aa256 100644 --- a/futures-util/src/stream/stream/mod.rs +++ b/futures-util/src/stream/stream/mod.rs @@ -210,7 +210,7 @@ delegate_all!( FlattenUnordered( flatten_unordered::FlattenUnordered ): Debug + Sink + Stream + FusedStream + AccessInner[St, (.)] + New[|x: St, limit: Option| flatten_unordered::FlattenUnordered::new(x, limit)] - where St: Stream, St::Item: Stream + where St: Stream, St::Item: Stream, St::Item: Unpin ); #[cfg(not(futures_no_atomic_cas))] @@ -220,7 +220,7 @@ delegate_all!( FlatMapUnordered( FlattenUnordered> ): Debug + Sink + Stream + FusedStream + AccessInner[St, (. .)] + New[|x: St, limit: Option, f: F| FlattenUnordered::new(Map::new(x, f), limit)] - where St: Stream, U: Stream, F: FnMut(St::Item) -> U + where St: Stream, U: Stream, U: Unpin, F: FnMut(St::Item) -> U ); #[cfg(not(futures_no_atomic_cas))] @@ -790,11 +790,11 @@ pub trait StreamExt: Stream { /// assert_eq!(output, vec![1, 2, 3, 4]); /// # }); /// ``` - #[cfg_attr(feature = "cfg-target-has-atomic", cfg(target_has_atomic = "ptr"))] + #[cfg(not(futures_no_atomic_cas))] #[cfg(feature = "alloc")] fn flatten_unordered(self, limit: impl Into>) -> FlattenUnordered where - Self::Item: Stream, + Self::Item: Stream + Unpin, Self: Sized, { FlattenUnordered::new(self, limit.into()) @@ -871,7 +871,8 @@ pub trait StreamExt: Stream { /// /// assert_eq!(vec![1usize, 2, 2, 3, 3, 3, 4, 4, 4, 4], values); /// # }); - #[cfg_attr(feature = "cfg-target-has-atomic", cfg(target_has_atomic = "ptr"))] + /// ``` + #[cfg(not(futures_no_atomic_cas))] #[cfg(feature = "alloc")] fn flat_map_unordered( self, @@ -879,7 +880,7 @@ pub trait StreamExt: Stream { f: F, ) -> FlatMapUnordered where - U: Stream, + U: Stream + Unpin, F: FnMut(Self::Item) -> U, Self: Sized, { diff --git a/futures/tests/stream.rs b/futures/tests/stream.rs index d72b1dffd4..1b2d613154 100644 --- a/futures/tests/stream.rs +++ b/futures/tests/stream.rs @@ -72,7 +72,7 @@ fn scan() { #[test] fn flatten_unordered() { use futures::executor::block_on; - use futures::stream::{self, *}; + use futures::stream::*; use futures::task::*; use std::convert::identity; use std::pin::Pin; @@ -92,7 +92,8 @@ fn flatten_unordered() { if !self.polled { if !self.wake_immediately { let waker = ctx.waker().clone(); - let sleep_time = Duration::from_millis(*self.data.last().unwrap_or(&0) as u64); + let sleep_time = + Duration::from_millis(*self.data.first().unwrap_or(&0) as u64 / 10); thread::spawn(move || { thread::sleep(sleep_time); waker.wake_by_ref(); @@ -133,7 +134,7 @@ fn flatten_unordered() { } Poll::Pending } else { - let data: Vec<_> = (0..6).map(|v| v + self.base * 6).collect(); + let data: Vec<_> = (0..6).rev().map(|v| v + self.base * 6).collect(); self.base += 1; self.polled = false; Poll::Ready(Some(DataStream { @@ -148,17 +149,11 @@ fn flatten_unordered() { // basic behaviour block_on(async { let st = - stream::iter(vec![stream::iter(0..=4u8), stream::iter(6..=10), stream::iter(0..=2)]); - - let mut fl_unordered = st - .map(|s| s.filter(|v| futures::future::ready(v % 2 == 0))) - .flatten_unordered(1) - .collect::>() - .await; + stream::iter(vec![stream::iter(0..=4u8), stream::iter(6..=10), stream::iter(10..=12)]); - fl_unordered.sort(); + let fl_unordered = st.flatten_unordered(3).collect::>().await; - assert_eq!(fl_unordered, vec![0, 0, 2, 2, 4, 6, 8, 10]); + assert_eq!(fl_unordered, vec![0, 6, 10, 1, 7, 11, 2, 8, 12, 3, 9, 4, 10]); }); block_on(async { @@ -206,7 +201,7 @@ fn flatten_unordered() { let mut fl_unordered = Interchanger { polled: false, base: 0, wake_immediately: false } .take(10) .map(|s| s.map(identity)) - .flatten() + .flatten_unordered(10) .collect::>() .await; @@ -236,7 +231,7 @@ fn flatten_unordered() { Interchanger { polled: false, base: 0, wake_immediately: false } .take(10) .map(|s| s.map(identity)) - .flatten() + .flatten_unordered(10) .collect::>() ); @@ -289,9 +284,40 @@ fn flatten_unordered() { assert_eq!(values, (0..60).collect::>()); }); } + + // stream panics + let st = once(async { once(async { panic!("Polled") }).boxed() }.boxed()).chain( + Interchanger { polled: false, base: 0, wake_immediately: true } + .then(|val| async move { val.boxed() }.boxed()) + .take(10), + ); + + let stream = Arc::new(Mutex::new(st.boxed().flat_map_unordered(10, |s| s.map(identity)))); + + std::thread::spawn({ + let stream = stream.clone(); + move || { + let mut st = poll_fn(|cx| { + let mut lock = ready!(stream.lock().poll_unpin(cx)); + let data = ready!(lock.poll_next_unpin(cx)); + + Poll::Ready(data) + }); + + block_on(st.next()) + } + }) + .join() + .unwrap_err(); + + block_on(async move { + let mut values: Vec<_> = stream.lock().await.by_ref().collect().await; + values.sort(); + + assert_eq!(values, (0..60).collect::>()); + }); } -#[cfg(feature = "executor")] // executor:: #[test] fn take_until() { fn make_stop_fut(stop_on: u32) -> impl Future {