From a0a9051d5f80a68f66ccf43f0cebd94c1e868077 Mon Sep 17 00:00:00 2001 From: Simon Farnsworth Date: Tue, 20 Sep 2022 10:46:33 +0100 Subject: [PATCH] tokio_util: add inspection wrapper for `AsyncRead` There are use cases like checking hashes of files that benefit from being able to inspect bytes read as they come in, while still letting the main code process the bytes as normal (e.g. deserializing into objects, knowing that if there's a hash failure, you'll discard the result). As this is non-trivial to get right (e.g. handling a `buf` that's not empty when passed to `poll_read`, add a wrapper `InspectReader` that gets this right, passing all newly read bytes to a supplied `FnMut` closure. Fixes: #4584 --- tokio-util/src/io/inspect.rs | 46 ++++++++++++++++++++++++++++++++++ tokio-util/src/io/mod.rs | 3 +++ tokio-util/tests/io_inspect.rs | 43 +++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+) create mode 100644 tokio-util/src/io/inspect.rs create mode 100644 tokio-util/tests/io_inspect.rs diff --git a/tokio-util/src/io/inspect.rs b/tokio-util/src/io/inspect.rs new file mode 100644 index 00000000000..8fbd50f3513 --- /dev/null +++ b/tokio-util/src/io/inspect.rs @@ -0,0 +1,46 @@ +use futures_core::ready; +use pin_project_lite::pin_project; +use std::{ + io::Result, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, ReadBuf}; + +pin_project! { + /// An adapter that lets you inspect the data that's being read. + /// + /// This is useful for things like hashing data as it's read in. + pub struct InspectReader { + #[pin] + reader: R, + f: F, + } +} + +impl InspectReader { + /// Create a new InspectReader, wrapping `reader` and calling `f` for the + /// new data supplied by each read call. + pub fn new(reader: R, f: F) -> InspectReader { + InspectReader { reader, f } + } + + /// Consumes the `InspectReader`, returning the wrapped reader + pub fn into_inner(self) -> R { + self.reader + } +} + +impl AsyncRead for InspectReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let me = self.project(); + let filled_length = buf.filled().len(); + ready!(me.reader.poll_read(cx, buf))?; + (me.f)(&buf.filled()[filled_length..]); + Poll::Ready(Ok(())) + } +} diff --git a/tokio-util/src/io/mod.rs b/tokio-util/src/io/mod.rs index eb48a21fb98..df0dd02fdc0 100644 --- a/tokio-util/src/io/mod.rs +++ b/tokio-util/src/io/mod.rs @@ -10,14 +10,17 @@ //! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html //! [`AsyncRead`]: tokio::io::AsyncRead +mod inspect; mod read_buf; mod reader_stream; mod stream_reader; + cfg_io_util! { mod sync_bridge; pub use self::sync_bridge::SyncIoBridge; } +pub use self::inspect::InspectReader; pub use self::read_buf::read_buf; pub use self::reader_stream::ReaderStream; pub use self::stream_reader::StreamReader; diff --git a/tokio-util/tests/io_inspect.rs b/tokio-util/tests/io_inspect.rs new file mode 100644 index 00000000000..aa907c0c148 --- /dev/null +++ b/tokio-util/tests/io_inspect.rs @@ -0,0 +1,43 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; +use tokio_util::io::InspectReader; + +/// An AsyncRead implementation that works byte-by-byte, to catch out callers +/// who don't allow for `buf` being part-filled before the call +struct SmallReader { + contents: Vec, +} + +impl Unpin for SmallReader {} + +impl AsyncRead for SmallReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(byte) = self.contents.pop() { + buf.put_slice(&[byte]) + } + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn read_tee() { + let contents = b"This could be really long, you know".to_vec(); + let reader = SmallReader { + contents: contents.clone(), + }; + let mut altout: Vec = Vec::new(); + let mut teeout = Vec::new(); + { + let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes)); + tee.read_to_end(&mut teeout).await.unwrap(); + } + assert_eq!(teeout, altout); + assert_eq!(altout.len(), contents.len()); +}