diff --git a/futures-util/src/stream/stream/flat_map_unordered.rs b/futures-util/src/stream/stream/flat_map_unordered.rs index af8ba12c2e..fbd1434e22 100644 --- a/futures-util/src/stream/stream/flat_map_unordered.rs +++ b/futures-util/src/stream/stream/flat_map_unordered.rs @@ -13,6 +13,7 @@ use futures_core::task::{Context, Poll, Waker}; use futures_sink::Sink; use futures_task::{waker, ArcWake}; use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use core::cell::UnsafeCell; /// Indicates that there is nothing to poll and stream isn't polled at the /// moment. @@ -68,18 +69,24 @@ impl SharedPollState { /// Waker which will update `poll_state` with `need_to_poll` value on /// `wake_by_ref` call and then, if there is a need, call `inner_waker`. struct PollWaker { - inner_waker: Waker, + inner_waker: UnsafeCell>, poll_state: SharedPollState, need_to_poll: u8, } +unsafe impl Send for PollWaker {} + +unsafe impl Sync for PollWaker {} + impl ArcWake for PollWaker { fn wake_by_ref(self_arc: &Arc) { let poll_state_value = self_arc.poll_state.set_or(self_arc.need_to_poll); // Only call waker if stream isn't polled because it will be called // at the end of polling if state was changed. if poll_state_value & POLLING == NONE { - self_arc.inner_waker.wake_by_ref(); + if let Some(Some(inner_waker)) = unsafe { self_arc.inner_waker.get().as_ref() } { + inner_waker.wake_by_ref(); + } } } } @@ -135,6 +142,8 @@ pub struct FlatMapUnordered U> { stream: Map, limit: Option, is_stream_done: bool, + futures_waker: Arc, + stream_waker: Arc } impl Unpin for FlatMapUnordered @@ -173,16 +182,29 @@ where unsafe_unpinned!(is_stream_done: bool); unsafe_unpinned!(limit: Option); unsafe_unpinned!(poll_state: SharedPollState); + unsafe_unpinned!(futures_waker: Arc); + unsafe_unpinned!(stream_waker: Arc); pub(super) fn new(stream: St, limit: Option, f: F) -> FlatMapUnordered { + // Because to create first future, it needs to get inner + // stream from `stream` + let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM); FlatMapUnordered { - // Because to create first future, it needs to get inner - // stream from `stream` - poll_state: SharedPollState::new(NEED_TO_POLL_STREAM), futures: FuturesUnordered::new(), stream: Map::new(stream, f), is_stream_done: false, limit: limit.and_then(NonZeroUsize::new), + futures_waker: Arc::new(PollWaker { + inner_waker: UnsafeCell::new(None), + poll_state: poll_state.clone(), + need_to_poll: NEED_TO_POLL_FUTURES, + }), + stream_waker: Arc::new(PollWaker { + inner_waker: UnsafeCell::new(None), + poll_state: poll_state.clone(), + need_to_poll: NEED_TO_POLL_STREAM, + }), + poll_state } } @@ -218,28 +240,30 @@ where self.stream.into_inner() } - /// Creates waker with given `need_to_poll` value, which will be used to - /// update poll state on `wake_by_ref` call. - fn create_waker(&self, inner_waker: Waker, need_to_poll: u8) -> Waker { - waker(Arc::new(PollWaker { - inner_waker, - poll_state: self.poll_state.clone(), - need_to_poll, - })) - } - /// Creates special waker for polling stream which will set poll state /// to poll `stream` on `wake_by_ref` call. Use only if you need several /// contexts. - fn create_poll_stream_waker(&self, ctx: &Context<'_>) -> Waker { - self.create_waker(ctx.waker().clone(), NEED_TO_POLL_STREAM) + /// + /// ## Safety + /// + /// This function will modify current `stream_waker`'s `inner_waker` + /// via `UnsafeCell`, so it should be used only in `POLLING` phase. + unsafe fn create_poll_stream_waker(mut self: Pin<&mut Self>, ctx: &Context<'_>) -> Waker { + *self.as_mut().stream_waker.inner_waker.get() = ctx.waker().clone().into(); + waker(self.stream_waker.clone()) } /// Creates special waker for polling futures which willset poll state /// to poll `futures` on `wake_by_ref` call. Use only if you need several - /// contexts. - fn create_poll_futures_waker(&self, ctx: &Context<'_>) -> Waker { - self.create_waker(ctx.waker().clone(), NEED_TO_POLL_FUTURES) + /// contexts. + /// + /// ## Safety + /// + /// This function will modify current `futures_waker`'s `inner_waker` + /// via `UnsafeCell`, so it should be used only in `POLLING` phase. + unsafe fn create_poll_futures_waker(mut self: Pin<&mut Self>, ctx: &Context<'_>) -> Waker { + *self.as_mut().futures_waker.inner_waker.get() = ctx.waker().clone().into(); + waker(self.futures_waker.clone()) } /// Checks if current `futures` size is less than optional limit. @@ -273,15 +297,15 @@ where let mut poll_state_value = self.as_mut().poll_state().begin_polling(); let mut next_item = None; let mut need_to_poll_next = NONE; - let mut polling_with_two_wakers = - poll_state_value & NEED_TO_POLL == NEED_TO_POLL && self.not_exceeded_limit(); - let mut stream_will_be_woken = false; + let mut stream_will_be_woken_or_polled_later = !self.not_exceeded_limit(); let mut futures_will_be_woken = false; + let mut polling_with_two_wakers = poll_state_value & NEED_TO_POLL == NEED_TO_POLL && !stream_will_be_woken_or_polled_later; if poll_state_value & NEED_TO_POLL_STREAM != NONE { - if self.not_exceeded_limit() { + if !stream_will_be_woken_or_polled_later { match if polling_with_two_wakers { - let waker = self.create_poll_stream_waker(ctx); + // Safety: now state is `POLLING`. + let waker = unsafe { self.as_mut().create_poll_stream_waker(ctx) }; let mut ctx = Context::from_waker(&waker); self.as_mut().stream().poll_next(&mut ctx) } else { @@ -304,7 +328,7 @@ where polling_with_two_wakers = false; } Poll::Pending => { - stream_will_be_woken = true; + stream_will_be_woken_or_polled_later = true; if !polling_with_two_wakers { need_to_poll_next |= NEED_TO_POLL_STREAM; } @@ -317,7 +341,8 @@ where if poll_state_value & NEED_TO_POLL_FUTURES != NONE { match if polling_with_two_wakers { - let waker = self.create_poll_futures_waker(ctx); + // Safety: now state is `POLLING`. + let waker = unsafe { self.as_mut().create_poll_futures_waker(ctx) }; let mut ctx = Context::from_waker(&waker); self.as_mut().futures().poll_next(&mut ctx) } else { @@ -348,7 +373,7 @@ where if poll_state_value & NEED_TO_POLL != NONE && (polling_with_two_wakers || poll_state_value & NEED_TO_POLL_FUTURES != NONE && !futures_will_be_woken - || poll_state_value & NEED_TO_POLL_STREAM != NONE && !stream_will_be_woken) + || poll_state_value & NEED_TO_POLL_STREAM != NONE && !stream_will_be_woken_or_polled_later) { ctx.waker().wake_by_ref(); }