diff --git a/Cargo.toml b/Cargo.toml index 040ad970..93550b42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ console = { version = ">=0.9.1, <1.0.0", default-features = false } unicode-segmentation = { version = "1.6.0", optional = true } unicode-width = { version = "0.1.7", optional = true } rayon = { version = "1.0", optional = true } +tokio = { version = "1.0", optional = true, features = ["fs", "io-util"] } [dev-dependencies] rand = "0.8" diff --git a/src/iter.rs b/src/iter.rs index 31b5c8a7..c4e8d160 100644 --- a/src/iter.rs +++ b/src/iter.rs @@ -2,6 +2,13 @@ use crate::progress_bar::ProgressBar; use std::convert::TryFrom; use std::io::{self, IoSliceMut}; use std::iter::FusedIterator; +#[cfg(feature = "tokio")] +use std::{ + pin::Pin, + task::{Context, Poll}, +}; +#[cfg(feature = "tokio")] +use tokio::io::{ReadBuf, SeekFrom}; /// Wraps an iterator to display its progress. pub trait ProgressIterator @@ -141,6 +148,58 @@ impl io::Seek for ProgressBarIter { } } +#[cfg(feature = "tokio")] +impl tokio::io::AsyncWrite for ProgressBarIter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.it).poll_write(cx, buf).map(|poll| { + poll.map(|inc| { + self.progress.inc(inc as u64); + inc + }) + }) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.it).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.it).poll_shutdown(cx) + } +} + +#[cfg(feature = "tokio")] +impl tokio::io::AsyncRead for ProgressBarIter { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let prev_len = buf.filled().len() as u64; + if let Poll::Ready(e) = Pin::new(&mut self.it).poll_read(cx, buf) { + self.progress.inc(buf.filled().len() as u64 - prev_len); + Poll::Ready(e) + } else { + Poll::Pending + } + } +} + +#[cfg(feature = "tokio")] +impl tokio::io::AsyncSeek for ProgressBarIter { + fn start_seek(mut self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> { + Pin::new(&mut self.it).start_seek(position) + } + + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.it).poll_complete(cx) + } +} + impl io::Write for ProgressBarIter { fn write(&mut self, buf: &[u8]) -> io::Result { self.it.write(buf).map(|inc| { diff --git a/src/progress_bar.rs b/src/progress_bar.rs index 34e82ff3..c8eca743 100644 --- a/src/progress_bar.rs +++ b/src/progress_bar.rs @@ -457,6 +457,52 @@ impl ProgressBar { } } + #[cfg(feature = "tokio")] + /// Wraps an [`tokio::io::AsyncWrite`] with the progress bar + /// + /// ```rust,no_run + /// # use tokio::fs::File; + /// # use tokio::io; + /// # use indicatif::ProgressBar; + /// # async fn test() -> io::Result<()> { + /// let mut source = File::open("work.txt").await?; + /// let mut target = File::open("done.txt").await?; + /// let pb = ProgressBar::new(source.metadata().await?.len()); + /// io::copy(&mut source, &mut pb.wrap_async_write(target)).await?; + /// # Ok(()) + /// # } + /// ``` + pub fn wrap_async_write( + &self, + write: W, + ) -> ProgressBarIter { + ProgressBarIter { + progress: self.clone(), + it: write, + } + } + #[cfg(feature = "tokio")] + /// Wraps an [`tokio::io::AsyncRead`] with the progress bar + /// + /// ```rust,no_run + /// # use tokio::fs::File; + /// # use tokio::io; + /// # use indicatif::ProgressBar; + /// # async fn test() -> io::Result<()> { + /// let mut source = File::open("work.txt").await?; + /// let mut target = File::open("done.txt").await?; + /// let pb = ProgressBar::new(source.metadata().await?.len()); + /// io::copy(&mut pb.wrap_async_read(source), &mut target).await?; + /// # Ok(()) + /// # } + /// ``` + pub fn wrap_async_read(&self, write: W) -> ProgressBarIter { + ProgressBarIter { + progress: self.clone(), + it: write, + } + } + fn update_and_draw(&self, f: F) { // Delegate to the wrapped state. let mut state = self.state.lock().unwrap();