forked from tokio-rs/tokio
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
tokio_util: add inspection wrapper for
AsyncRead
There are use cases like checking hashes of files that benefit from being able to inspect bytes read as they come in, while still letting the main code process the bytes as normal (e.g. deserializing into objects, knowing that if there's a hash failure, you'll discard the result). As this is non-trivial to get right (e.g. handling a `buf` that's not empty when passed to `poll_read`, add a wrapper `InspectReader` that gets this right, passing all newly read bytes to a supplied `FnMut` closure. Fixes: tokio-rs#4584
- Loading branch information
Showing
3 changed files
with
92 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
use futures_core::ready; | ||
use pin_project_lite::pin_project; | ||
use std::{ | ||
io::Result, | ||
pin::Pin, | ||
task::{Context, Poll}, | ||
}; | ||
use tokio::io::{AsyncRead, ReadBuf}; | ||
|
||
pin_project! { | ||
/// An adapter that lets you inspect the data that's being read. | ||
/// | ||
/// This is useful for things like hashing data as it's read in. | ||
pub struct InspectReader<R: AsyncRead, F: FnMut(&[u8])> { | ||
#[pin] | ||
reader: R, | ||
f: F, | ||
} | ||
} | ||
|
||
impl<R: AsyncRead, F: FnMut(&[u8])> InspectReader<R, F> { | ||
/// Create a new InspectReader, wrapping `reader` and calling `f` for the | ||
/// new data supplied by each read call. | ||
pub fn new(reader: R, f: F) -> InspectReader<R, F> { | ||
InspectReader { reader, f } | ||
} | ||
|
||
/// Consumes the `InspectReader`, returning the wrapped reader | ||
pub fn into_inner(self) -> R { | ||
self.reader | ||
} | ||
} | ||
|
||
impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> { | ||
fn poll_read( | ||
self: Pin<&mut Self>, | ||
cx: &mut Context<'_>, | ||
buf: &mut ReadBuf<'_>, | ||
) -> Poll<Result<()>> { | ||
let me = self.project(); | ||
let filled_length = buf.filled().len(); | ||
ready!(me.reader.poll_read(cx, buf))?; | ||
(me.f)(&buf.filled()[filled_length..]); | ||
Poll::Ready(Ok(())) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
use std::{ | ||
pin::Pin, | ||
task::{Context, Poll}, | ||
}; | ||
use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; | ||
use tokio_util::io::InspectReader; | ||
|
||
/// An AsyncRead implementation that works byte-by-byte, to catch out callers | ||
/// who don't allow for `buf` being part-filled before the call | ||
struct SmallReader { | ||
contents: Vec<u8>, | ||
} | ||
|
||
impl Unpin for SmallReader {} | ||
|
||
impl AsyncRead for SmallReader { | ||
fn poll_read( | ||
mut self: Pin<&mut Self>, | ||
_cx: &mut Context<'_>, | ||
buf: &mut ReadBuf<'_>, | ||
) -> Poll<std::io::Result<()>> { | ||
if let Some(byte) = self.contents.pop() { | ||
buf.put_slice(&[byte]) | ||
} | ||
Poll::Ready(Ok(())) | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn read_tee() { | ||
let contents = b"This could be really long, you know".to_vec(); | ||
let reader = SmallReader { | ||
contents: contents.clone(), | ||
}; | ||
let mut altout: Vec<u8> = Vec::new(); | ||
let mut teeout = Vec::new(); | ||
{ | ||
let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes)); | ||
tee.read_to_end(&mut teeout).await.unwrap(); | ||
} | ||
assert_eq!(teeout, altout); | ||
assert_eq!(altout.len(), contents.len()); | ||
} |