diff --git a/tokio-util/src/io/inspect.rs b/tokio-util/src/io/inspect.rs new file mode 100644 index 00000000000..8fbd50f3513 --- /dev/null +++ b/tokio-util/src/io/inspect.rs @@ -0,0 +1,46 @@ +use futures_core::ready; +use pin_project_lite::pin_project; +use std::{ + io::Result, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, ReadBuf}; + +pin_project! { + /// An adapter that lets you inspect the data that's being read. + /// + /// This is useful for things like hashing data as it's read in. + pub struct InspectReader { + #[pin] + reader: R, + f: F, + } +} + +impl InspectReader { + /// Create a new InspectReader, wrapping `reader` and calling `f` for the + /// new data supplied by each read call. + pub fn new(reader: R, f: F) -> InspectReader { + InspectReader { reader, f } + } + + /// Consumes the `InspectReader`, returning the wrapped reader + pub fn into_inner(self) -> R { + self.reader + } +} + +impl AsyncRead for InspectReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let me = self.project(); + let filled_length = buf.filled().len(); + ready!(me.reader.poll_read(cx, buf))?; + (me.f)(&buf.filled()[filled_length..]); + Poll::Ready(Ok(())) + } +} diff --git a/tokio-util/src/io/mod.rs b/tokio-util/src/io/mod.rs index eb48a21fb98..df0dd02fdc0 100644 --- a/tokio-util/src/io/mod.rs +++ b/tokio-util/src/io/mod.rs @@ -10,14 +10,17 @@ //! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html //! [`AsyncRead`]: tokio::io::AsyncRead +mod inspect; mod read_buf; mod reader_stream; mod stream_reader; + cfg_io_util! { mod sync_bridge; pub use self::sync_bridge::SyncIoBridge; } +pub use self::inspect::InspectReader; pub use self::read_buf::read_buf; pub use self::reader_stream::ReaderStream; pub use self::stream_reader::StreamReader; diff --git a/tokio-util/tests/io_inspect.rs b/tokio-util/tests/io_inspect.rs new file mode 100644 index 00000000000..aa907c0c148 --- /dev/null +++ b/tokio-util/tests/io_inspect.rs @@ -0,0 +1,43 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; +use tokio_util::io::InspectReader; + +/// An AsyncRead implementation that works byte-by-byte, to catch out callers +/// who don't allow for `buf` being part-filled before the call +struct SmallReader { + contents: Vec, +} + +impl Unpin for SmallReader {} + +impl AsyncRead for SmallReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(byte) = self.contents.pop() { + buf.put_slice(&[byte]) + } + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn read_tee() { + let contents = b"This could be really long, you know".to_vec(); + let reader = SmallReader { + contents: contents.clone(), + }; + let mut altout: Vec = Vec::new(); + let mut teeout = Vec::new(); + { + let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes)); + tee.read_to_end(&mut teeout).await.unwrap(); + } + assert_eq!(teeout, altout); + assert_eq!(altout.len(), contents.len()); +}