diff --git a/tokio-stream/src/stream_ext.rs b/tokio-stream/src/stream_ext.rs index 7f950102635..5927dbd22ab 100644 --- a/tokio-stream/src/stream_ext.rs +++ b/tokio-stream/src/stream_ext.rs @@ -61,6 +61,8 @@ cfg_time! { use tokio::time::Duration; mod throttle; use throttle::{throttle, Throttle}; + mod chunks_timeout; + use chunks_timeout::ChunksTimeout; } /// An extension trait for the [`Stream`] trait that provides a variety of @@ -1005,6 +1007,62 @@ pub trait StreamExt: Stream { { throttle(duration, self) } + + /// Batches the items in the given stream using a maximum duration and size for each batch. + /// + /// This stream returns the next batch of items in the following situations: + /// 1. The inner stream has returned at least `max_size` many items since the last batch. + /// 2. The time since the first item of a batch is greater than the given duration. + /// 3. The end of the stream is reached. + /// + /// The length of the returned vector is never empty or greater than the maximum size. Empty batches + /// will not be emitted if no items are received upstream. + /// + /// # Panics + /// + /// This function panics if `max_size` is zero + /// + /// # Example + /// + /// ```rust + /// use std::time::Duration; + /// use tokio::time; + /// use tokio_stream::{self as stream, StreamExt}; + /// use futures::FutureExt; + /// + /// #[tokio::main] + /// # async fn _unused() {} + /// # #[tokio::main(flavor = "current_thread", start_paused = true)] + /// async fn main() { + /// let iter = vec![1, 2, 3, 4].into_iter(); + /// let stream0 = stream::iter(iter); + /// + /// let iter = vec![5].into_iter(); + /// let stream1 = stream::iter(iter) + /// .then(move |n| time::sleep(Duration::from_secs(5)).map(move |_| n)); + /// + /// let chunk_stream = stream0 + /// .chain(stream1) + /// .chunks_timeout(3, Duration::from_secs(2)); + /// tokio::pin!(chunk_stream); + /// + /// // a full batch was received + /// assert_eq!(chunk_stream.next().await, Some(vec![1,2,3])); + /// // deadline was reached before max_size was reached + /// assert_eq!(chunk_stream.next().await, Some(vec![4])); + /// // last element in the stream + /// assert_eq!(chunk_stream.next().await, Some(vec![5])); + /// } + /// ``` + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + fn chunks_timeout(self, max_size: usize, duration: Duration) -> ChunksTimeout + where + Self: Sized, + { + assert!(max_size > 0, "`max_size` must be non-zero."); + ChunksTimeout::new(self, max_size, duration) + } } impl StreamExt for St where St: Stream {} diff --git a/tokio-stream/src/stream_ext/chunks_timeout.rs b/tokio-stream/src/stream_ext/chunks_timeout.rs new file mode 100644 index 00000000000..107101317a3 --- /dev/null +++ b/tokio-stream/src/stream_ext/chunks_timeout.rs @@ -0,0 +1,84 @@ +use crate::stream_ext::Fuse; +use crate::Stream; +use tokio::time::{sleep, Instant, Sleep}; + +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; +use std::time::Duration; + +pin_project! { + /// Stream returned by the [`chunks_timeout`](super::StreamExt::chunks_timeout) method. + #[must_use = "streams do nothing unless polled"] + #[derive(Debug)] + pub struct ChunksTimeout { + #[pin] + stream: Fuse, + #[pin] + deadline: Sleep, + duration: Duration, + items: Vec, + cap: usize, // https://github.com/rust-lang/futures-rs/issues/1475 + } +} + +impl ChunksTimeout { + pub(super) fn new(stream: S, max_size: usize, duration: Duration) -> Self { + ChunksTimeout { + stream: Fuse::new(stream), + deadline: sleep(duration), + duration, + items: Vec::with_capacity(max_size), + cap: max_size, + } + } +} + +impl Stream for ChunksTimeout { + type Item = Vec; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut me = self.as_mut().project(); + loop { + match me.stream.as_mut().poll_next(cx) { + Poll::Pending => break, + Poll::Ready(Some(item)) => { + if me.items.is_empty() { + me.deadline.as_mut().reset(Instant::now() + *me.duration); + me.items.reserve_exact(*me.cap); + } + me.items.push(item); + if me.items.len() >= *me.cap { + return Poll::Ready(Some(std::mem::take(me.items))); + } + } + Poll::Ready(None) => { + // Returning Some here is only correct because we fuse the inner stream. + let last = if me.items.is_empty() { + None + } else { + Some(std::mem::take(me.items)) + }; + + return Poll::Ready(last); + } + } + } + + if !me.items.is_empty() { + ready!(me.deadline.poll(cx)); + return Poll::Ready(Some(std::mem::take(me.items))); + } + + Poll::Pending + } + + fn size_hint(&self) -> (usize, Option) { + let chunk_len = if self.items.is_empty() { 0 } else { 1 }; + let (lower, upper) = self.stream.size_hint(); + let lower = (lower / self.cap).saturating_add(chunk_len); + let upper = upper.and_then(|x| x.checked_add(chunk_len)); + (lower, upper) + } +} diff --git a/tokio-stream/tests/chunks_timeout.rs b/tokio-stream/tests/chunks_timeout.rs new file mode 100644 index 00000000000..ffc7deadd70 --- /dev/null +++ b/tokio-stream/tests/chunks_timeout.rs @@ -0,0 +1,84 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "time", feature = "sync", feature = "io-util"))] + +use tokio::time; +use tokio_stream::{self as stream, StreamExt}; +use tokio_test::assert_pending; +use tokio_test::task; + +use futures::FutureExt; +use std::time::Duration; + +#[tokio::test(start_paused = true)] +async fn usage() { + let iter = vec![1, 2, 3].into_iter(); + let stream0 = stream::iter(iter); + + let iter = vec![4].into_iter(); + let stream1 = + stream::iter(iter).then(move |n| time::sleep(Duration::from_secs(3)).map(move |_| n)); + + let chunk_stream = stream0 + .chain(stream1) + .chunks_timeout(4, Duration::from_secs(2)); + + let mut chunk_stream = task::spawn(chunk_stream); + + assert_pending!(chunk_stream.poll_next()); + time::advance(Duration::from_secs(2)).await; + assert_eq!(chunk_stream.next().await, Some(vec![1, 2, 3])); + + assert_pending!(chunk_stream.poll_next()); + time::advance(Duration::from_secs(2)).await; + assert_eq!(chunk_stream.next().await, Some(vec![4])); +} + +#[tokio::test(start_paused = true)] +async fn full_chunk_with_timeout() { + let iter = vec![1, 2].into_iter(); + let stream0 = stream::iter(iter); + + let iter = vec![3].into_iter(); + let stream1 = + stream::iter(iter).then(move |n| time::sleep(Duration::from_secs(1)).map(move |_| n)); + + let iter = vec![4].into_iter(); + let stream2 = + stream::iter(iter).then(move |n| time::sleep(Duration::from_secs(3)).map(move |_| n)); + + let chunk_stream = stream0 + .chain(stream1) + .chain(stream2) + .chunks_timeout(3, Duration::from_secs(2)); + + let mut chunk_stream = task::spawn(chunk_stream); + + assert_pending!(chunk_stream.poll_next()); + time::advance(Duration::from_secs(2)).await; + assert_eq!(chunk_stream.next().await, Some(vec![1, 2, 3])); + + assert_pending!(chunk_stream.poll_next()); + time::advance(Duration::from_secs(2)).await; + assert_eq!(chunk_stream.next().await, Some(vec![4])); +} + +#[tokio::test] +#[ignore] +async fn real_time() { + let iter = vec![1, 2, 3, 4].into_iter(); + let stream0 = stream::iter(iter); + + let iter = vec![5].into_iter(); + let stream1 = + stream::iter(iter).then(move |n| time::sleep(Duration::from_secs(5)).map(move |_| n)); + + let chunk_stream = stream0 + .chain(stream1) + .chunks_timeout(3, Duration::from_secs(2)); + + let mut chunk_stream = task::spawn(chunk_stream); + + assert_eq!(chunk_stream.next().await, Some(vec![1, 2, 3])); + assert_eq!(chunk_stream.next().await, Some(vec![4])); + assert_eq!(chunk_stream.next().await, Some(vec![5])); +}