Skip to content

Commit

Permalink
tokio_util: add inspection wrapper for AsyncRead
Browse files Browse the repository at this point in the history
There are use cases like checking hashes of files that benefit from
being able to inspect bytes read as they come in, while still letting
the main code process the bytes as normal (e.g. deserializing into
objects, knowing that if there's a hash failure, you'll discard the
result).

As this is non-trivial to get right (e.g. handling a `buf` that's not
empty when passed to `poll_read`), add a wrapper `InspectReader`
that gets this right, passing all newly read bytes to a supplied `FnMut`
closure.

Fixes: tokio-rs#4584
  • Loading branch information
farnz committed Sep 20, 2022
1 parent d69e5be commit a0a9051
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 0 deletions.
46 changes: 46 additions & 0 deletions tokio-util/src/io/inspect.rs
@@ -0,0 +1,46 @@
use futures_core::ready;
use pin_project_lite::pin_project;
use std::{
io::Result,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, ReadBuf};

pin_project! {
    /// An adapter that lets you inspect the data that's being read.
    ///
    /// Every successful read on the wrapped reader hands the bytes that were
    /// newly placed in the caller's buffer to the stored closure, and the
    /// same bytes remain in the buffer for the caller. This is useful for
    /// things like hashing data as it's read in.
    ///
    /// The type parameters are intentionally unbounded here; the `AsyncRead`
    /// and `FnMut(&[u8])` requirements are stated on the `impl` blocks that
    /// actually need them.
    pub struct InspectReader<R, F> {
        // The wrapped reader; structurally pinned so it can be polled in place.
        #[pin]
        reader: R,
        // Callback invoked with the bytes appended by each successful read.
        f: F,
    }
}

impl<R: AsyncRead, F: FnMut(&[u8])> InspectReader<R, F> {
    /// Wraps `reader` so that `f` observes the data produced by every read.
    ///
    /// `f` receives only the bytes that a given read call appended to the
    /// caller's buffer, never bytes that were already in it.
    pub fn new(reader: R, f: F) -> Self {
        Self { reader, f }
    }

    /// Unwraps this `InspectReader`, handing back the underlying reader and
    /// dropping the inspection closure.
    pub fn into_inner(self) -> R {
        self.reader
    }
}

impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> {
    /// Polls the wrapped reader, then reports the freshly read bytes to the
    /// inspection closure.
    ///
    /// The closure is not called while the wrapped reader is pending or when
    /// it returns an error. On success it is called exactly once with the
    /// bytes appended by this call, which is an empty slice if nothing was
    /// appended.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        let this = self.project();
        // `buf` may already hold data from earlier reads; remember where the
        // new data will start so only that tail is passed to the closure.
        let start = buf.filled().len();
        match ready!(this.reader.poll_read(cx, buf)) {
            Ok(()) => {
                let fresh = &buf.filled()[start..];
                (this.f)(fresh);
                Poll::Ready(Ok(()))
            }
            Err(err) => Poll::Ready(Err(err)),
        }
    }
}
3 changes: 3 additions & 0 deletions tokio-util/src/io/mod.rs
Expand Up @@ -10,14 +10,17 @@
//! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html
//! [`AsyncRead`]: tokio::io::AsyncRead

mod inspect;
mod read_buf;
mod reader_stream;
mod stream_reader;

cfg_io_util! {
mod sync_bridge;
pub use self::sync_bridge::SyncIoBridge;
}

pub use self::inspect::InspectReader;
pub use self::read_buf::read_buf;
pub use self::reader_stream::ReaderStream;
pub use self::stream_reader::StreamReader;
Expand Down
43 changes: 43 additions & 0 deletions tokio-util/tests/io_inspect.rs
@@ -0,0 +1,43 @@
use std::{
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
use tokio_util::io::InspectReader;

/// An AsyncRead implementation that works byte-by-byte, to catch out callers
/// who don't allow for `buf` being part-filled before the call.
///
/// Its `poll_read` pops one byte off the end of `contents` per call, so the
/// bytes come out in the reverse of the order in which they are stored.
struct SmallReader {
// Bytes still to be handed out; consumed from the back.
contents: Vec<u8>,
}

// Explicitly marks the reader as `Unpin`. Its only field is a `Vec<u8>`, so
// the auto trait would apply anyway; the explicit impl documents the intent.
impl Unpin for SmallReader {}

impl AsyncRead for SmallReader {
    /// Appends at most one byte to `buf` per call and always completes
    /// immediately.
    ///
    /// Once `contents` is exhausted the call succeeds without adding
    /// anything to `buf`.
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // `SmallReader` is `Unpin`, so a plain mutable reference is available.
        let reader = self.get_mut();
        if let Some(next) = reader.contents.pop() {
            buf.put_slice(&[next]);
        }
        Poll::Ready(Ok(()))
    }
}

/// Reads a `SmallReader` to the end through an `InspectReader` and checks
/// that the inspection closure saw exactly the bytes the caller received,
/// and that nothing was dropped or duplicated along the way.
#[tokio::test]
async fn read_tee() {
    let contents = b"This could be really long, you know".to_vec();
    let expected_len = contents.len();

    let mut inspected: Vec<u8> = Vec::new();
    let mut collected = Vec::new();
    {
        // Scope the wrapper so its mutable borrow of `inspected` ends before
        // the assertions below.
        let source = SmallReader { contents };
        let mut tee = InspectReader::new(source, |chunk: &[u8]| {
            inspected.extend_from_slice(chunk)
        });
        tee.read_to_end(&mut collected).await.unwrap();
    }

    assert_eq!(collected, inspected);
    assert_eq!(inspected.len(), expected_len);
}

0 comments on commit a0a9051

Please sign in to comment.