diff --git a/futures-util/src/stream/stream/mod.rs b/futures-util/src/stream/stream/mod.rs index da5ade85bb..3a24d38428 100644 --- a/futures-util/src/stream/stream/mod.rs +++ b/futures-util/src/stream/stream/mod.rs @@ -566,10 +566,10 @@ pub trait StreamExt: Stream { /// assert_eq!(vec![1, 2, 3], stream.collect::>().await); /// # }); /// ``` - fn scan(self, initial_state: S, f: F) -> Scan + fn scan<'a, S: 'a, B, Fut, F>(self, initial_state: S, f: F) -> Scan where - F: FnMut(&mut S, Self::Item) -> Fut, - Fut: Future>, + F: FnMut(&'a mut S, Self::Item) -> Fut, + Fut: Future> + 'a, Self: Sized, { Scan::new(self, initial_state, f) diff --git a/futures-util/src/stream/stream/scan.rs b/futures-util/src/stream/stream/scan.rs index 4f937f4fd9..ba6b424d3f 100644 --- a/futures-util/src/stream/stream/scan.rs +++ b/futures-util/src/stream/stream/scan.rs @@ -6,6 +6,7 @@ use futures_core::task::{Context, Poll}; #[cfg(feature = "sink")] use futures_sink::Sink; use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use core::mem::transmute; struct StateFn { state: S, @@ -25,7 +26,6 @@ impl Unpin for Scan {} impl fmt::Debug for Scan where St: Stream + fmt::Debug, - St::Item: fmt::Debug, S: fmt::Debug, Fut: fmt::Debug, { @@ -50,11 +50,11 @@ impl Scan { } } -impl Scan +impl<'a, B, St, S: 'a, Fut, F> Scan where St: Stream, - F: FnMut(&mut S, St::Item) -> Fut, - Fut: Future>, + F: FnMut(&'a mut S, St::Item) -> Fut, + Fut: Future> + 'a, { pub(super) fn new(stream: St, initial_state: S, f: F) -> Scan { Scan { @@ -100,11 +100,11 @@ where } } -impl Stream for Scan +impl<'a, B, St, S: 'a, Fut, F> Stream for Scan where St: Stream, - F: FnMut(&mut S, St::Item) -> Fut, - Fut: Future>, + F: FnMut(&'a mut S, St::Item) -> Fut, + Fut: Future> + 'a, { type Item = B; @@ -119,7 +119,12 @@ where None => return Poll::Ready(None), }; let state_f = self.as_mut().state_f().as_mut().unwrap(); - let fut = (state_f.f)(&mut state_f.state, item); + let fut = (state_f.f)( + // Safety: this's safe because state is internal and can only be accessed + // via provided function. + unsafe { transmute(&mut state_f.state) }, + item, + ); self.as_mut().future().set(Some(fut)); } @@ -142,11 +147,11 @@ where } } -impl FusedStream for Scan +impl<'a, B, St, S: 'a, Fut, F> FusedStream for Scan where St: FusedStream, - F: FnMut(&mut S, St::Item) -> Fut, - Fut: Future>, + F: FnMut(&'a mut S, St::Item) -> Fut, + Fut: Future> + 'a, { fn is_terminated(&self) -> bool { self.is_done_taking() || self.future.is_none() && self.stream.is_terminated() diff --git a/futures/tests/stream.rs b/futures/tests/stream.rs index fd6a8b6da7..38a85fe403 100644 --- a/futures/tests/stream.rs +++ b/futures/tests/stream.rs @@ -1,4 +1,5 @@ use futures::executor::block_on; +use futures::future::*; use futures::stream::{self, StreamExt}; #[test] @@ -15,14 +16,34 @@ fn select() { select_and_compare(vec![1, 2], vec![4, 5, 6], vec![1, 4, 2, 5, 6]); } +async fn async_scan_fn(state: &mut u8, e: u8) -> Option { + *state += 1; + if e < *state { + Some(e) + } else { + None + } +} + +fn impl_scan_fn<'a>(state: &'a mut u8, e: u8) -> impl Future> + 'a { + async move { + *state += 1; + if e < *state { + Some(e) + } else { + None + } + } +} + #[test] fn scan() { futures::executor::block_on(async { assert_eq!( stream::iter(vec![1u8, 2, 3, 4, 6, 8, 2]) - .scan(1, |acc, e| { - *acc += 1; - futures::future::ready(if e < *acc { Some(e) } else { None }) + .scan(1, |state, e| { + *state += 1; + futures::future::ready(if e < *state { Some(e) } else { None }) }) .collect::>() .await, @@ -30,3 +51,29 @@ fn scan() { ); }); } + +#[test] +fn scan_with_async() { + futures::executor::block_on(async { + assert_eq!( + stream::iter(vec![1u8, 2, 3, 4, 6, 8, 2]) + .scan(1u8, async_scan_fn) + .collect::>() + .await, + vec![1u8, 2, 3, 4] + ); + }); +} + +#[test] +fn scan_with_impl() { + futures::executor::block_on(async { + assert_eq!( + stream::iter(vec![1u8, 2, 3, 4, 6, 8, 2]) + .scan(1u8, impl_scan_fn) + .collect::>() + .await, + vec![1u8, 2, 3, 4] + ); + }); +}