diff --git a/tokio-util/src/io/inspect.rs b/tokio-util/src/io/inspect.rs new file mode 100644 index 00000000000..ec5bb97e61c --- /dev/null +++ b/tokio-util/src/io/inspect.rs @@ -0,0 +1,134 @@ +use futures_core::ready; +use pin_project_lite::pin_project; +use std::io::{IoSlice, Result}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pin_project! { + /// An adapter that lets you inspect the data that's being read. + /// + /// This is useful for things like hashing data as it's read in. + pub struct InspectReader { + #[pin] + reader: R, + f: F, + } +} + +impl InspectReader { + /// Create a new InspectReader, wrapping `reader` and calling `f` for the + /// new data supplied by each read call. + /// + /// The closure will only be called with an empty slice if the inner reader + /// returns without reading data into the buffer. This happens at EOF, or if + /// `poll_read` is called with a zero-size buffer. + pub fn new(reader: R, f: F) -> InspectReader + where + R: AsyncRead, + F: FnMut(&[u8]), + { + InspectReader { reader, f } + } + + /// Consumes the `InspectReader`, returning the wrapped reader + pub fn into_inner(self) -> R { + self.reader + } +} + +impl AsyncRead for InspectReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let me = self.project(); + let filled_length = buf.filled().len(); + ready!(me.reader.poll_read(cx, buf))?; + (me.f)(&buf.filled()[filled_length..]); + Poll::Ready(Ok(())) + } +} + +pin_project! { + /// An adapter that lets you inspect the data that's being written. + /// + /// This is useful for things like hashing data as it's written out. + pub struct InspectWriter { + #[pin] + writer: W, + f: F, + } +} + +impl InspectWriter { + /// Create a new InspectWriter, wrapping `write` and calling `f` for the + /// data successfully written by each write call. + /// + /// The closure `f` will never be called with an empty slice. A vectored + /// write can result in multiple calls to `f` - at most one call to `f` per + /// buffer supplied to `poll_write_vectored`. + pub fn new(writer: W, f: F) -> InspectWriter + where + W: AsyncWrite, + F: FnMut(&[u8]), + { + InspectWriter { writer, f } + } + + /// Consumes the `InspectWriter`, returning the wrapped writer + pub fn into_inner(self) -> W { + self.writer + } +} + +impl AsyncWrite for InspectWriter { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let me = self.project(); + let res = me.writer.poll_write(cx, buf); + if let Poll::Ready(Ok(count)) = res { + if count != 0 { + (me.f)(&buf[..count]); + } + } + res + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + me.writer.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + me.writer.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let me = self.project(); + let res = me.writer.poll_write_vectored(cx, bufs); + if let Poll::Ready(Ok(mut count)) = res { + for buf in bufs { + if count == 0 { + break; + } + let size = count.min(buf.len()); + if size != 0 { + (me.f)(&buf[..size]); + count -= size; + } + } + } + res + } + + fn is_write_vectored(&self) -> bool { + self.writer.is_write_vectored() + } +} diff --git a/tokio-util/src/io/mod.rs b/tokio-util/src/io/mod.rs index eb48a21fb98..317d93b3640 100644 --- a/tokio-util/src/io/mod.rs +++ b/tokio-util/src/io/mod.rs @@ -10,14 +10,17 @@ //! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html //! [`AsyncRead`]: tokio::io::AsyncRead +mod inspect; mod read_buf; mod reader_stream; mod stream_reader; + cfg_io_util! { mod sync_bridge; pub use self::sync_bridge::SyncIoBridge; } +pub use self::inspect::{InspectReader, InspectWriter}; pub use self::read_buf::read_buf; pub use self::reader_stream::ReaderStream; pub use self::stream_reader::StreamReader; diff --git a/tokio-util/tests/io_inspect.rs b/tokio-util/tests/io_inspect.rs new file mode 100644 index 00000000000..e6319afcf1b --- /dev/null +++ b/tokio-util/tests/io_inspect.rs @@ -0,0 +1,194 @@ +use futures::future::poll_fn; +use std::{ + io::IoSlice, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio_util::io::{InspectReader, InspectWriter}; + +/// An AsyncRead implementation that works byte-by-byte, to catch out callers +/// who don't allow for `buf` being part-filled before the call +struct SmallReader { + contents: Vec, +} + +impl Unpin for SmallReader {} + +impl AsyncRead for SmallReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(byte) = self.contents.pop() { + buf.put_slice(&[byte]) + } + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn read_tee() { + let contents = b"This could be really long, you know".to_vec(); + let reader = SmallReader { + contents: contents.clone(), + }; + let mut altout: Vec = Vec::new(); + let mut teeout = Vec::new(); + { + let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes)); + tee.read_to_end(&mut teeout).await.unwrap(); + } + assert_eq!(teeout, altout); + assert_eq!(altout.len(), contents.len()); +} + +/// An AsyncWrite implementation that works byte-by-byte for poll_write, and +/// that reads the whole of the first buffer plus one byte from the second in +/// poll_write_vectored. +/// +/// This is designed to catch bugs in handling partially written buffers +#[derive(Debug)] +struct SmallWriter { + contents: Vec, +} + +impl Unpin for SmallWriter {} + +impl AsyncWrite for SmallWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + // Just write one byte at a time + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + self.contents.push(buf[0]); + Poll::Ready(Ok(1)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + // Write all of the first buffer, then one byte from the second buffer + // This should trip up anything that doesn't correctly handle multiple + // buffers. + if bufs.is_empty() { + return Poll::Ready(Ok(0)); + } + let mut written_len = bufs[0].len(); + self.contents.extend_from_slice(&bufs[0]); + + if bufs.len() > 1 { + let buf = bufs[1]; + if !buf.is_empty() { + written_len += 1; + self.contents.push(buf[0]); + } + } + Poll::Ready(Ok(written_len)) + } + + fn is_write_vectored(&self) -> bool { + true + } +} + +#[tokio::test] +async fn write_tee() { + let mut altout: Vec = Vec::new(); + let mut writeout = SmallWriter { + contents: Vec::new(), + }; + { + let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes)); + tee.write_all(b"A testing string, very testing") + .await + .unwrap(); + } + assert_eq!(altout, writeout.contents); +} + +// This is inefficient, but works well enough for test use. +// If you want something similar for real code, you'll want to avoid all the +// fun of manipulating `bufs` - ideally, by the time you read this, +// IoSlice::advance_slices will be stable, and you can use that. +async fn write_all_vectored( + mut writer: W, + mut bufs: Vec>, +) -> Result { + let mut res = 0; + while !bufs.is_empty() { + let mut written = poll_fn(|cx| { + let bufs: Vec = bufs.iter().map(|v| IoSlice::new(v)).collect(); + Pin::new(&mut writer).poll_write_vectored(cx, &bufs) + }) + .await?; + res += written; + while written > 0 { + let buf_len = bufs[0].len(); + if buf_len <= written { + bufs.remove(0); + written -= buf_len; + } else { + let buf = &mut bufs[0]; + let drain_len = written.min(buf.len()); + buf.drain(..drain_len); + written -= drain_len; + } + } + } + Ok(res) +} + +#[tokio::test] +async fn write_tee_vectored() { + let mut altout: Vec = Vec::new(); + let mut writeout = SmallWriter { + contents: Vec::new(), + }; + let original = b"A very long string split up"; + let bufs: Vec> = original + .split(|b| b.is_ascii_whitespace()) + .map(Vec::from) + .collect(); + assert!(bufs.len() > 1); + let expected: Vec = { + let mut out = Vec::new(); + for item in &bufs { + out.extend_from_slice(item) + } + out + }; + { + let mut bufcount = 0; + let tee = InspectWriter::new(&mut writeout, |bytes| { + bufcount += 1; + altout.extend(bytes) + }); + + assert!(tee.is_write_vectored()); + + write_all_vectored(tee, bufs.clone()).await.unwrap(); + + assert!(bufcount >= bufs.len()); + } + assert_eq!(altout, writeout.contents); + assert_eq!(writeout.contents, expected); +}