diff --git a/futures-util/src/stream/stream/flat_map_unordered.rs b/futures-util/src/stream/stream/flat_map_unordered.rs
new file mode 100644
index 0000000000..e86b180adb
--- /dev/null
+++ b/futures-util/src/stream/stream/flat_map_unordered.rs
@@ -0,0 +1,389 @@
+use super::Map;
+use crate::stream::FuturesUnordered;
+use core::fmt;
+use core::num::NonZeroUsize;
+use core::pin::Pin;
+use futures_core::future::Future;
+use futures_core::stream::FusedStream;
+use futures_core::stream::Stream;
+use futures_core::task::{Context, Poll, Waker};
+#[cfg(feature = "sink")]
+use futures_sink::Sink;
+use futures_task::{waker, ArcWake};
+use pin_utils::{unsafe_pinned, unsafe_unpinned};
+use std::sync::atomic::{AtomicU8, Ordering};
+use std::sync::Arc;
+
+/// Indicates that there is nothing to poll and stream isn't polled at the
+/// moment.
+const NONE: u8 = 0;
+
+/// Indicates that `futures` need to be polled.
+const NEED_TO_POLL_FUTURES: u8 = 0b1;
+
+/// Indicates that `stream` needs to be polled.
+const NEED_TO_POLL_STREAM: u8 = 0b10;
+
+/// Indicates that we need to poll something.
+const NEED_TO_POLL: u8 = NEED_TO_POLL_FUTURES | NEED_TO_POLL_STREAM;
+
+/// Indicates that current stream is polled at the moment.
+const POLLING: u8 = 0b100;
+
+/// State which is used to determine what needs to be polled,
+/// and whether the stream is being polled at the moment or not.
+#[derive(Clone, Debug)]
+struct SharedPollState {
+    state: Arc<AtomicU8>,
+}
+
+impl SharedPollState {
+    /// Constructs new `SharedPollState` with given state.
+    fn new(state: u8) -> Self {
+        Self {
+            state: Arc::new(AtomicU8::new(state)),
+        }
+    }
+
+    /// Swaps state with `POLLING`, returning previous state.
+    fn begin_polling(&self) -> u8 {
+        self.state.swap(POLLING, Ordering::AcqRel)
+    }
+
+    /// Performs bitwise or with `to_poll` and current state, returning
+    /// previous state.
+    fn set_or(&self, to_poll: u8) -> u8 {
+        self.state.fetch_or(to_poll, Ordering::AcqRel)
+    }
+
+    /// Merges `to_poll` into the current state, clears the `POLLING` bit, and
+    /// returns the merged value (as it was before `POLLING` was cleared).
+    ///
+    /// The update runs in a compare-and-swap loop so that a flag set by a
+    /// concurrent `set_or` between the read and the write is never overwritten.
+    fn end_polling(&self, to_poll: u8) -> u8 {
+        let mut current = self.state.load(Ordering::Acquire);
+        loop {
+            let merged = to_poll | current;
+            match self.state.compare_exchange_weak(
+                current,
+                merged & !POLLING,
+                Ordering::AcqRel,
+                Ordering::Acquire,
+            ) {
+                Ok(_) => return merged,
+                Err(actual) => current = actual,
+            }
+        }
+    }
+}
+
+/// Waker which will update `poll_state` with `need_to_poll` value on
+/// `wake_by_ref` call and then, if there is a need, call `inner_waker`.
+struct PollWaker {
+    inner_waker: Waker,
+    poll_state: SharedPollState,
+    need_to_poll: u8,
+}
+
+impl ArcWake for PollWaker {
+    fn wake_by_ref(self_arc: &Arc<Self>) {
+        let poll_state_value = self_arc.poll_state.set_or(self_arc.need_to_poll);
+        // Only call waker if we're not polling because we will call it at the end
+        // of polling if it needs to poll something.
+        if poll_state_value & POLLING == NONE {
+            self_arc.inner_waker.wake_by_ref();
+        }
+    }
+}
+
+/// Future which contains optional stream. If it's `Some`, it will attempt
+/// to call `poll_next` on it, returning `Some((item, stream))` in case of
+/// `Poll::Ready(Some(...))` or `None` in case of `Poll::Ready(None)`.
+/// If `poll_next` will return `Poll::Pending`, it will be forwarded to
+/// the future, and current task will be notified by waker.
+#[must_use = "futures do nothing unless you `.await` or poll them"]
+struct StreamFut<St> {
+    stream: Option<St>,
+}
+
+impl<St> StreamFut<St> {
+    unsafe_pinned!(stream: Option<St>);
+}
+
+impl<St: Stream + Unpin> Unpin for StreamFut<St> {}
+
+impl<St: Stream> Future for StreamFut<St> {
+    type Output = Option<(St::Item, St)>;
+
+    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
+        let item = if let Some(stream) = self.as_mut().stream().as_pin_mut() {
+            ready!(stream.poll_next(ctx))
+        } else {
+            None
+        };
+
+        Poll::Ready(item.map(|item| {
+            // NOTE(review): moves the stream out of its pin; only sound if `St: Unpin`.
+            let stream = unsafe { self.get_unchecked_mut().stream.take().unwrap() };
+            (item, stream)
+        }))
+    }
+}
+
+/// Stream for the [`flat_map_unordered`](super::StreamExt::flat_map_unordered)
+/// method.
+#[must_use = "streams do nothing unless polled"]
+pub struct FlatMapUnordered<St: Stream, U: Stream, F: FnMut(St::Item) -> U> {
+    poll_state: SharedPollState,
+    futures: FuturesUnordered<StreamFut<U>>,
+    stream: Map<St, F>,
+    limit: Option<NonZeroUsize>,
+    is_stream_done: bool,
+}
+
+impl<St, U, F> Unpin for FlatMapUnordered<St, U, F>
+where
+    St: Stream + Unpin,
+    U: Stream + Unpin,
+    F: FnMut(St::Item) -> U,
+{
+}
+
+impl<St, U, F> fmt::Debug for FlatMapUnordered<St, U, F>
+where
+    St: Stream + fmt::Debug,
+    U: Stream + fmt::Debug,
+    F: FnMut(St::Item) -> U,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("FlatMapUnordered")
+            .field("poll_state", &self.poll_state)
+            .field("futures", &self.futures)
+            .field("limit", &self.limit)
+            .field("stream", &self.stream)
+            .field("is_stream_done", &self.is_stream_done)
+            .finish()
+    }
+}
+
+impl<St, U, F> FlatMapUnordered<St, U, F>
+where
+    St: Stream,
+    U: Stream,
+    F: FnMut(St::Item) -> U,
+{
+    unsafe_pinned!(futures: FuturesUnordered<StreamFut<U>>);
+    unsafe_pinned!(stream: Map<St, F>);
+    unsafe_unpinned!(is_stream_done: bool);
+    unsafe_unpinned!(limit: Option<NonZeroUsize>);
+    unsafe_unpinned!(poll_state: SharedPollState);
+
+    pub(super) fn new(stream: St, limit: Option<usize>, f: F) -> FlatMapUnordered<St, U, F> {
+        FlatMapUnordered {
+            // The first future can only be created once an inner stream has
+            // been pulled from `stream`, so start by polling `stream`.
+            poll_state: SharedPollState::new(NEED_TO_POLL_STREAM),
+            futures: FuturesUnordered::new(),
+            stream: Map::new(stream, f),
+            is_stream_done: false,
+            limit: limit.and_then(NonZeroUsize::new),
+        }
+    }
+
+    /// Acquires a reference to the underlying stream that this combinator is
+    /// pulling from.
+    pub fn get_ref(&self) -> &St {
+        self.stream.get_ref()
+    }
+
+    /// Acquires a mutable reference to the underlying stream that this
+    /// combinator is pulling from.
+    ///
+    /// Note that care must be taken to avoid tampering with the state of the
+    /// stream which may otherwise confuse this combinator.
+    pub fn get_mut(&mut self) -> &mut St {
+        self.stream.get_mut()
+    }
+
+    /// Acquires a pinned mutable reference to the underlying stream that this
+    /// combinator is pulling from.
+    ///
+    /// Note that care must be taken to avoid tampering with the state of the
+    /// stream which may otherwise confuse this combinator.
+    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut St> {
+        self.stream().get_pin_mut()
+    }
+
+    /// Consumes this combinator, returning the underlying stream.
+    ///
+    /// Note that this may discard intermediate state of this combinator, so
+    /// care should be taken to avoid losing resources when this is called.
+    pub fn into_inner(self) -> St {
+        self.stream.into_inner()
+    }
+
+    /// Creates waker with given `need_to_poll` value, which will be used to
+    /// update poll state on `wake_by_ref` call.
+    fn create_waker(&self, inner_waker: Waker, need_to_poll: u8) -> Waker {
+        waker(Arc::new(PollWaker {
+            inner_waker,
+            poll_state: self.poll_state.clone(),
+            need_to_poll,
+        }))
+    }
+
+    /// Creates special waker for polling stream which will set poll state
+    /// to poll `stream` on `wake_by_ref` call. Use only if you need several
+    /// contexts.
+    fn create_poll_stream_waker(&self, ctx: &Context<'_>) -> Waker {
+        self.create_waker(ctx.waker().clone(), NEED_TO_POLL_STREAM)
+    }
+
+    /// Creates special waker for polling futures which will set poll state
+    /// to poll `futures` on `wake_by_ref` call. Use only if you need several
+    /// contexts.
+    fn create_poll_futures_waker(&self, ctx: &Context<'_>) -> Waker {
+        self.create_waker(ctx.waker().clone(), NEED_TO_POLL_FUTURES)
+    }
+
+    /// Checks if current `futures` size is less than optional limit.
+    fn not_exceeded_limit(&self) -> bool {
+        self.limit
+            .map(|limit| self.futures.len() < limit.get())
+            .unwrap_or(true)
+    }
+}
+
+impl<St, U, F> FusedStream for FlatMapUnordered<St, U, F>
+where
+    St: FusedStream,
+    U: Unpin + FusedStream,
+    F: FnMut(St::Item) -> U,
+{
+    fn is_terminated(&self) -> bool {
+        self.futures.is_empty() && self.stream.is_terminated()
+    }
+}
+
+impl<St, U, F> Stream for FlatMapUnordered<St, U, F>
+where
+    St: Stream,
+    U: Stream,
+    F: FnMut(St::Item) -> U,
+{
+    type Item = U::Item;
+
+    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let mut poll_state_value = self.as_mut().poll_state().begin_polling();
+
+        let mut next_item = None;
+        let mut need_to_poll_next = NONE;
+        let mut polling_with_two_wakers =
+            poll_state_value & NEED_TO_POLL == NEED_TO_POLL && self.not_exceeded_limit();
+        let mut polled_stream = false;
+        let mut polled_futures = false;
+
+        if poll_state_value & NEED_TO_POLL_STREAM != NONE {
+            if self.not_exceeded_limit() {
+                polled_stream = true;
+                match if polling_with_two_wakers {
+                    let waker = self.create_poll_stream_waker(ctx);
+                    let mut ctx = Context::from_waker(&waker);
+                    self.as_mut().stream().poll_next(&mut ctx)
+                } else {
+                    self.as_mut().stream().poll_next(ctx)
+                } {
+                    Poll::Ready(Some(inner_stream)) => {
+                        self.as_mut().futures().push(StreamFut {
+                            stream: Some(inner_stream),
+                        });
+                        need_to_poll_next |= NEED_TO_POLL_STREAM;
+                        // Polling futures in current iteration with the same context
+                        // is ok because we already received `Poll::Ready` from
+                        // stream
+                        poll_state_value |= NEED_TO_POLL_FUTURES;
+                        polling_with_two_wakers = false;
+                    }
+                    Poll::Ready(None) => {
+                        *self.as_mut().is_stream_done() = true;
+                        // Polling futures in current iteration with the same context
+                        // is ok because we already received `Poll::Ready` from
+                        // stream
+                        polling_with_two_wakers = false;
+                    }
+                    Poll::Pending => {
+                        if !polling_with_two_wakers {
+                            need_to_poll_next |= NEED_TO_POLL_STREAM;
+                        }
+                    }
+                }
+            } else {
+                need_to_poll_next |= NEED_TO_POLL_STREAM;
+            }
+        }
+
+        if poll_state_value & NEED_TO_POLL_FUTURES != NONE {
+            polled_futures = true;
+            match if polling_with_two_wakers {
+                let waker = self.create_poll_futures_waker(ctx);
+                let mut ctx = Context::from_waker(&waker);
+                self.as_mut().futures().poll_next(&mut ctx)
+            } else {
+                self.as_mut().futures().poll_next(ctx)
+            } {
+                Poll::Ready(Some(Some((item, stream)))) => {
+                    self.as_mut().futures().push(StreamFut {
+                        stream: Some(stream),
+                    });
+                    next_item = Some(item);
+                    need_to_poll_next |= NEED_TO_POLL_FUTURES;
+                }
+                Poll::Ready(Some(None)) => {
+                    need_to_poll_next |= NEED_TO_POLL_FUTURES;
+                }
+                Poll::Pending => {
+                    if !polling_with_two_wakers {
+                        need_to_poll_next |= NEED_TO_POLL_FUTURES;
+                    }
+                }
+                Poll::Ready(None) => {
+                    need_to_poll_next &= !NEED_TO_POLL_FUTURES;
+                }
+            }
+        }
+
+        let poll_state_value = self.as_mut().poll_state().end_polling(need_to_poll_next);
+
+        if poll_state_value & NEED_TO_POLL != NONE {
+            if !polling_with_two_wakers {
+                if poll_state_value & NEED_TO_POLL_FUTURES != NONE && !polled_futures
+                    || poll_state_value & NEED_TO_POLL_STREAM != NONE && !polled_stream
+                {
+                    ctx.waker().wake_by_ref();
+                }
+            } else {
+                ctx.waker().wake_by_ref();
+            }
+        }
+
+        if self.futures.is_empty() && self.is_stream_done || next_item.is_some() {
+            Poll::Ready(next_item)
+        } else {
+            Poll::Pending
+        }
+    }
+}
+
+// Forwarding impl of Sink from the underlying stream
+#[cfg(feature = "sink")]
+impl<S, U, F, Item> Sink<Item> for FlatMapUnordered<S, U, F>
+where
+    S: Stream + Sink<Item>,
+    U: Stream,
+    F: FnMut(S::Item) -> U,
+{
+    type Error = S::Error;
+
+    delegate_sink!(stream, Item);
+}
diff --git a/futures-util/src/stream/stream/mod.rs b/futures-util/src/stream/stream/mod.rs
index 4f227326f0..a2d0a68cda 100644
--- a/futures-util/src/stream/stream/mod.rs
+++ b/futures-util/src/stream/stream/mod.rs
@@ -47,6 +47,10 @@ mod flatten;
 #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
 pub use self::flatten::Flatten;
 
+mod flat_map_unordered;
+#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
+pub use self::flat_map_unordered::FlatMapUnordered;
+
 mod fold;
 #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
 pub use self::fold::Fold;
@@ -544,6 +548,57 @@ pub trait StreamExt: Stream {
         Flatten::new(self)
     }
 
+    /// Maps a stream like [`StreamExt::map`] but flattens nested `Stream`s
+    /// and polls them concurrently, yielding items in any order, as they become
+    /// available.
+    ///
+    /// [`StreamExt::map`] is very useful, but if it produces `Stream`s
+    /// instead, and you need to poll all of them concurrently, you would
+    /// have to use something like `for_each_concurrent` and merge values
+    /// by hand. This combinator provides ability to collect all values
+    /// from concurrently polled streams into one stream.
+    ///
+    /// The first argument is an optional limit on the number of concurrently
+    /// polled streams. If this limit is not `None`, no more than `limit` streams
+    /// will be polled concurrently. The `limit` argument is of type
+    /// `Into<Option<usize>>`, and so can be provided as either `None`,
+    /// `Some(10)`, or just `10`. Note: a limit of zero is interpreted as
+    /// no limit at all, and will have the same result as passing in `None`.
+    ///
+    /// The provided closure, which produces inner streams, is executed over
+    /// all elements of stream as next stream item is available and limit
+    /// of concurrently processed streams isn't exceeded.
+    ///
+    /// Note that this function consumes the stream passed into it and
+    /// returns a wrapped version of it.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # futures::executor::block_on(async {
+    /// use futures::stream::{self, StreamExt};
+    ///
+    /// let stream = stream::iter(1..5);
+    /// let stream = stream.flat_map_unordered(1, |x| stream::iter(vec![x; x]));
+    /// let mut values = stream.collect::<Vec<_>>().await;
+    /// values.sort();
+    ///
+    /// assert_eq!(vec![1usize, 2, 2, 3, 3, 3, 4, 4, 4, 4], values);
+    /// # });
+    /// ```
+    fn flat_map_unordered<U, F>(
+        self,
+        limit: impl Into<Option<usize>>,
+        f: F,
+    ) -> FlatMapUnordered<Self, U, F>
+    where
+        U: Stream,
+        F: FnMut(Self::Item) -> U,
+        Self: Sized,
+    {
+        FlatMapUnordered::new(self, limit.into(), f)
+    }
+
     /// Combinator similar to [`StreamExt::fold`] that holds internal state
     /// and produces a new stream.
     ///
diff --git a/futures/tests/stream.rs b/futures/tests/stream.rs
index 54f49c668d..bf8bb9adbe 100644
--- a/futures/tests/stream.rs
+++ b/futures/tests/stream.rs
@@ -1,5 +1,5 @@
 use futures::executor::block_on;
-use futures::stream::{self, StreamExt};
+use futures::stream::{self, *};
 
 #[test]
 fn select() {
@@ -30,3 +30,22 @@ fn scan() {
         );
     });
 }
+
+#[test]
+fn flat_map_unordered() {
+    futures::executor::block_on(async {
+        let st = stream::iter(vec![
+            stream::iter(0..=4u8),
+            stream::iter(6..=10),
+            stream::iter(0..=2),
+        ]);
+
+        let mut fm_unordered = st
+            .flat_map_unordered(1, |s| s.filter(|v| futures::future::ready(v % 2 == 0)))
+            .collect::<Vec<_>>()
+            .await;
+
+        // `sort*` returns `()`, so sort in place first and compare the contents.
+        fm_unordered.sort_unstable();
+        assert_eq!(fm_unordered, vec![0, 0, 2, 2, 4, 6, 8, 10]);
+    });
+}