diff --git a/futures-util/src/stream/mod.rs b/futures-util/src/stream/mod.rs index a352331bd4..ba3575bf02 100644 --- a/futures-util/src/stream/mod.rs +++ b/futures-util/src/stream/mod.rs @@ -14,7 +14,7 @@ pub use futures_core::stream::{FusedStream, Stream, TryStream}; mod stream; pub use self::stream::{ Chain, Collect, Concat, Enumerate, Filter, FilterMap, Flatten, Fold, ForEach, Fuse, Inspect, - Map, Next, Peek, Peekable, SelectNextSome, Skip, SkipWhile, StreamExt, StreamFuture, Take, + Map, Next, Peek, Peekable, Scan, SelectNextSome, Skip, SkipWhile, StreamExt, StreamFuture, Take, TakeWhile, Then, Zip, }; diff --git a/futures-util/src/stream/stream/mod.rs b/futures-util/src/stream/stream/mod.rs index 44726ca0c7..da5ade85bb 100644 --- a/futures-util/src/stream/stream/mod.rs +++ b/futures-util/src/stream/stream/mod.rs @@ -120,6 +120,10 @@ mod chunks; #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 pub use self::chunks::Chunks; +mod scan; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::scan::Scan; + cfg_target_has_atomic! { #[cfg(feature = "alloc")] mod buffer_unordered; @@ -540,6 +544,37 @@ pub trait StreamExt: Stream { Flatten::new(self) } + /// Combinator similar to [`StreamExt::fold`] that holds internal state and produces a new stream. + /// + /// Accepts initial state and closure which will be applied to each element of the stream until provided closure + /// returns `None`. Once `None` is returned, stream will be terminated. + /// + /// # Examples + /// + /// ``` + /// # futures::executor::block_on(async { + /// use futures::future; + /// use futures::stream::{self, StreamExt}; + /// + /// let stream = stream::iter(1..=10); + /// + /// let stream = stream.scan(0, |state, x| { + /// *state += x; + /// future::ready(if *state < 10 { Some(x) } else { None }) + /// }); + /// + /// assert_eq!(vec![1, 2, 3], stream.collect::>().await); + /// # }); + /// ``` + fn scan(self, initial_state: S, f: F) -> Scan + where + F: FnMut(&mut S, Self::Item) -> Fut, + Fut: Future>, + Self: Sized, + { + Scan::new(self, initial_state, f) + } + /// Skip elements on this stream while the provided asynchronous predicate /// resolves to `true`. /// diff --git a/futures-util/src/stream/stream/scan.rs b/futures-util/src/stream/stream/scan.rs new file mode 100644 index 0000000000..4f937f4fd9 --- /dev/null +++ b/futures-util/src/stream/stream/scan.rs @@ -0,0 +1,165 @@ +use core::fmt; +use core::pin::Pin; +use futures_core::future::Future; +use futures_core::stream::{FusedStream, Stream}; +use futures_core::task::{Context, Poll}; +#[cfg(feature = "sink")] +use futures_sink::Sink; +use pin_utils::{unsafe_pinned, unsafe_unpinned}; + +struct StateFn { + state: S, + f: F, +} + +/// Stream for the [`scan`](super::StreamExt::scan) method. +#[must_use = "streams do nothing unless polled"] +pub struct Scan { + stream: St, + state_f: Option>, + future: Option, +} + +impl Unpin for Scan {} + +impl fmt::Debug for Scan +where + St: Stream + fmt::Debug, + St::Item: fmt::Debug, + S: fmt::Debug, + Fut: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Scan") + .field("stream", &self.stream) + .field("state", &self.state_f.as_ref().map(|s| &s.state)) + .field("future", &self.future) + .field("done_taking", &self.is_done_taking()) + .finish() + } +} + +impl Scan { + unsafe_pinned!(stream: St); + unsafe_unpinned!(state_f: Option>); + unsafe_pinned!(future: Option); + + /// Checks if internal state is `None`. + fn is_done_taking(&self) -> bool { + self.state_f.is_none() + } +} + +impl Scan +where + St: Stream, + F: FnMut(&mut S, St::Item) -> Fut, + Fut: Future>, +{ + pub(super) fn new(stream: St, initial_state: S, f: F) -> Scan { + Scan { + stream, + state_f: Some(StateFn { + state: initial_state, + f, + }), + future: None, + } + } + + /// Acquires a reference to the underlying stream that this combinator is + /// pulling from. + pub fn get_ref(&self) -> &St { + &self.stream + } + + /// Acquires a mutable reference to the underlying stream that this + /// combinator is pulling from. + /// + /// Note that care must be taken to avoid tampering with the state of the + /// stream which may otherwise confuse this combinator. + pub fn get_mut(&mut self) -> &mut St { + &mut self.stream + } + + /// Acquires a pinned mutable reference to the underlying stream that this + /// combinator is pulling from. + /// + /// Note that care must be taken to avoid tampering with the state of the + /// stream which may otherwise confuse this combinator. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut St> { + self.stream() + } + + /// Consumes this combinator, returning the underlying stream. + /// + /// Note that this may discard intermediate state of this combinator, so + /// care should be taken to avoid losing resources when this is called. + pub fn into_inner(self) -> St { + self.stream + } +} + +impl Stream for Scan +where + St: Stream, + F: FnMut(&mut S, St::Item) -> Fut, + Fut: Future>, +{ + type Item = B; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.is_done_taking() { + return Poll::Ready(None); + } + + if self.future.is_none() { + let item = match ready!(self.as_mut().stream().poll_next(cx)) { + Some(e) => e, + None => return Poll::Ready(None), + }; + let state_f = self.as_mut().state_f().as_mut().unwrap(); + let fut = (state_f.f)(&mut state_f.state, item); + self.as_mut().future().set(Some(fut)); + } + + let item = ready!(self.as_mut().future().as_pin_mut().unwrap().poll(cx)); + self.as_mut().future().set(None); + + if item.is_none() { + self.as_mut().state_f().take(); + } + + Poll::Ready(item) + } + + fn size_hint(&self) -> (usize, Option) { + if self.is_done_taking() { + (0, Some(0)) + } else { + self.stream.size_hint() // can't know a lower bound, due to the predicate + } + } +} + +impl FusedStream for Scan +where + St: FusedStream, + F: FnMut(&mut S, St::Item) -> Fut, + Fut: Future>, +{ + fn is_terminated(&self) -> bool { + self.is_done_taking() || self.future.is_none() && self.stream.is_terminated() + } +} + +// Forwarding impl of Sink from the underlying stream +#[cfg(feature = "sink")] +impl Sink for Scan +where + S: Stream + Sink, +{ + type Error = S::Error; + + delegate_sink!(stream, Item); +} diff --git a/futures/src/lib.rs b/futures/src/lib.rs index fa11cb9412..d4f248257a 100644 --- a/futures/src/lib.rs +++ b/futures/src/lib.rs @@ -444,7 +444,7 @@ pub mod stream { StreamExt, Chain, Collect, Concat, Enumerate, Filter, FilterMap, Flatten, Fold, Forward, ForEach, Fuse, StreamFuture, Inspect, Map, Next, - SelectNextSome, Peek, Peekable, Skip, SkipWhile, Take, TakeWhile, + SelectNextSome, Peek, Peekable, Scan, Skip, SkipWhile, Take, TakeWhile, Then, Zip, TryStreamExt, diff --git a/futures/tests/stream.rs b/futures/tests/stream.rs index 2ed2f418ac..fd6a8b6da7 100644 --- a/futures/tests/stream.rs +++ b/futures/tests/stream.rs @@ -14,3 +14,19 @@ fn select() { select_and_compare(vec![1, 2, 3], vec![4, 5], vec![1, 4, 2, 5, 3]); select_and_compare(vec![1, 2], vec![4, 5, 6], vec![1, 4, 2, 5, 6]); } + +#[test] +fn scan() { + futures::executor::block_on(async { + assert_eq!( + stream::iter(vec![1u8, 2, 3, 4, 6, 8, 2]) + .scan(1, |acc, e| { + *acc += 1; + futures::future::ready(if e < *acc { Some(e) } else { None }) + }) + .collect::>() + .await, + vec![1u8, 2, 3, 4] + ); + }); +}