Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add StreamExt::scan #2044

Merged
merged 4 commits into from Jan 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion futures-util/src/stream/mod.rs
Expand Up @@ -14,7 +14,7 @@ pub use futures_core::stream::{FusedStream, Stream, TryStream};
mod stream;
pub use self::stream::{
Chain, Collect, Concat, Enumerate, Filter, FilterMap, Flatten, Fold, ForEach, Fuse, Inspect,
Map, Next, Peek, Peekable, SelectNextSome, Skip, SkipWhile, StreamExt, StreamFuture, Take,
Map, Next, Peek, Peekable, Scan, SelectNextSome, Skip, SkipWhile, StreamExt, StreamFuture, Take,
TakeWhile, Then, Zip,
};

Expand Down
35 changes: 35 additions & 0 deletions futures-util/src/stream/stream/mod.rs
Expand Up @@ -120,6 +120,10 @@ mod chunks;
#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
pub use self::chunks::Chunks;

mod scan;
#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
pub use self::scan::Scan;

cfg_target_has_atomic! {
#[cfg(feature = "alloc")]
mod buffer_unordered;
Expand Down Expand Up @@ -540,6 +544,37 @@ pub trait StreamExt: Stream {
Flatten::new(self)
}

/// Combinator similar to [`StreamExt::fold`] that holds internal state and produces a new stream.
///
/// Accepts initial state and closure which will be applied to each element of the stream until provided closure
/// returns `None`. Once `None` is returned, stream will be terminated.
///
/// # Examples
///
/// ```
/// # futures::executor::block_on(async {
/// use futures::future;
/// use futures::stream::{self, StreamExt};
///
/// let stream = stream::iter(1..=10);
///
/// let stream = stream.scan(0, |state, x| {
/// *state += x;
/// future::ready(if *state < 10 { Some(x) } else { None })
/// });
///
/// assert_eq!(vec![1, 2, 3], stream.collect::<Vec<_>>().await);
/// # });
/// ```
fn scan<S, B, Fut, F>(self, initial_state: S, f: F) -> Scan<Self, S, Fut, F>
where
F: FnMut(&mut S, Self::Item) -> Fut,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There isn't a way to fix this at the moment, sadly, but this is subtly the wrong method signature for this function, because it doesn't allow Fut to hold onto &mut S. There's no way to fix this right now due to limitations of HRTBs / GATs, so I don't mind accepting this PR now, but note that this makes this API dramatically less useful than it could be otherwise.

Fut: Future<Output = Option<B>>,
Self: Sized,
{
Scan::new(self, initial_state, f)
}

/// Skip elements on this stream while the provided asynchronous predicate
/// resolves to `true`.
///
Expand Down
165 changes: 165 additions & 0 deletions futures-util/src/stream/stream/scan.rs
@@ -0,0 +1,165 @@
use core::fmt;
use core::pin::Pin;
use futures_core::future::Future;
use futures_core::stream::{FusedStream, Stream};
use futures_core::task::{Context, Poll};
#[cfg(feature = "sink")]
use futures_sink::Sink;
use pin_utils::{unsafe_pinned, unsafe_unpinned};

struct StateFn<S, F> {
state: S,
f: F,
}

/// Stream for the [`scan`](super::StreamExt::scan) method.
#[must_use = "streams do nothing unless polled"]
pub struct Scan<St: Stream, S, Fut, F> {
stream: St,
state_f: Option<StateFn<S, F>>,
future: Option<Fut>,
}

impl<St: Unpin + Stream, S, Fut: Unpin, F> Unpin for Scan<St, S, Fut, F> {}

impl<St, S, Fut, F> fmt::Debug for Scan<St, S, Fut, F>
where
St: Stream + fmt::Debug,
St::Item: fmt::Debug,
S: fmt::Debug,
Fut: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Scan")
.field("stream", &self.stream)
.field("state", &self.state_f.as_ref().map(|s| &s.state))
.field("future", &self.future)
.field("done_taking", &self.is_done_taking())
.finish()
}
}

impl<St: Stream, S, Fut, F> Scan<St, S, Fut, F> {
unsafe_pinned!(stream: St);
unsafe_unpinned!(state_f: Option<StateFn<S, F>>);
unsafe_pinned!(future: Option<Fut>);

/// Checks if internal state is `None`.
fn is_done_taking(&self) -> bool {
self.state_f.is_none()
}
}

impl<B, St, S, Fut, F> Scan<St, S, Fut, F>
where
St: Stream,
F: FnMut(&mut S, St::Item) -> Fut,
Fut: Future<Output = Option<B>>,
{
pub(super) fn new(stream: St, initial_state: S, f: F) -> Scan<St, S, Fut, F> {
Scan {
stream,
state_f: Some(StateFn {
state: initial_state,
f,
}),
future: None,
}
}

/// Acquires a reference to the underlying stream that this combinator is
/// pulling from.
pub fn get_ref(&self) -> &St {
&self.stream
}

/// Acquires a mutable reference to the underlying stream that this
/// combinator is pulling from.
///
/// Note that care must be taken to avoid tampering with the state of the
/// stream which may otherwise confuse this combinator.
pub fn get_mut(&mut self) -> &mut St {
&mut self.stream
}

/// Acquires a pinned mutable reference to the underlying stream that this
/// combinator is pulling from.
///
/// Note that care must be taken to avoid tampering with the state of the
/// stream which may otherwise confuse this combinator.
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut St> {
self.stream()
}

/// Consumes this combinator, returning the underlying stream.
///
/// Note that this may discard intermediate state of this combinator, so
/// care should be taken to avoid losing resources when this is called.
pub fn into_inner(self) -> St {
self.stream
}
}

impl<B, St, S, Fut, F> Stream for Scan<St, S, Fut, F>
where
St: Stream,
F: FnMut(&mut S, St::Item) -> Fut,
Fut: Future<Output = Option<B>>,
{
type Item = B;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<B>> {
if self.is_done_taking() {
return Poll::Ready(None);
}

if self.future.is_none() {
let item = match ready!(self.as_mut().stream().poll_next(cx)) {
Some(e) => e,
None => return Poll::Ready(None),
};
let state_f = self.as_mut().state_f().as_mut().unwrap();
let fut = (state_f.f)(&mut state_f.state, item);
self.as_mut().future().set(Some(fut));
}

let item = ready!(self.as_mut().future().as_pin_mut().unwrap().poll(cx));
self.as_mut().future().set(None);

if item.is_none() {
self.as_mut().state_f().take();
}

Poll::Ready(item)
}

fn size_hint(&self) -> (usize, Option<usize>) {
if self.is_done_taking() {
(0, Some(0))
} else {
self.stream.size_hint() // can't know a lower bound, due to the predicate
}
}
}

impl<B, St, S, Fut, F> FusedStream for Scan<St, S, Fut, F>
where
St: FusedStream,
F: FnMut(&mut S, St::Item) -> Fut,
Fut: Future<Output = Option<B>>,
{
fn is_terminated(&self) -> bool {
self.is_done_taking() || self.future.is_none() && self.stream.is_terminated()
}
}

// Forwarding impl of Sink from the underlying stream
#[cfg(feature = "sink")]
impl<S, Fut, F, Item> Sink<Item> for Scan<S, S, Fut, F>
where
S: Stream + Sink<Item>,
{
type Error = S::Error;

delegate_sink!(stream, Item);
}
2 changes: 1 addition & 1 deletion futures/src/lib.rs
Expand Up @@ -444,7 +444,7 @@ pub mod stream {
StreamExt,
Chain, Collect, Concat, Enumerate, Filter, FilterMap, Flatten, Fold,
Forward, ForEach, Fuse, StreamFuture, Inspect, Map, Next,
SelectNextSome, Peek, Peekable, Skip, SkipWhile, Take, TakeWhile,
SelectNextSome, Peek, Peekable, Scan, Skip, SkipWhile, Take, TakeWhile,
Then, Zip,

TryStreamExt,
Expand Down
16 changes: 16 additions & 0 deletions futures/tests/stream.rs
Expand Up @@ -14,3 +14,19 @@ fn select() {
select_and_compare(vec![1, 2, 3], vec![4, 5], vec![1, 4, 2, 5, 3]);
select_and_compare(vec![1, 2], vec![4, 5, 6], vec![1, 4, 2, 5, 6]);
}

#[test]
fn scan() {
futures::executor::block_on(async {
assert_eq!(
stream::iter(vec![1u8, 2, 3, 4, 6, 8, 2])
.scan(1, |acc, e| {
*acc += 1;
futures::future::ready(if e < *acc { Some(e) } else { None })
})
.collect::<Vec<_>>()
.await,
vec![1u8, 2, 3, 4]
);
});
}