diff --git a/tokio-util/src/io/inspect.rs b/tokio-util/src/io/inspect.rs index 8fbd50f3513..53b1d5504ad 100644 --- a/tokio-util/src/io/inspect.rs +++ b/tokio-util/src/io/inspect.rs @@ -1,11 +1,11 @@ use futures_core::ready; use pin_project_lite::pin_project; use std::{ - io::Result, + io::{IoSlice, Result}, pin::Pin, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pin_project! { /// An adapter that lets you inspect the data that's being read. @@ -44,3 +44,72 @@ impl AsyncRead for InspectReader { Poll::Ready(Ok(())) } } + +pin_project! { + /// An adapter that lets you inspect the data that's being written. + /// + /// This is useful for things like hashing data as it's written out. + pub struct InspectWriter { + #[pin] + writer: W, + f: F, + } +} + +impl InspectWriter { + /// Create a new InspectWriter, wrapping `write` and calling `f` for the + /// data successfully written by each write call. + pub fn new(writer: W, f: F) -> InspectWriter { + InspectWriter { writer, f } + } + + /// Consumes the `InspectWriter`, returning the wrapped writer + pub fn into_inner(self) -> W { + self.writer + } +} + +impl AsyncWrite for InspectWriter { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let me = self.project(); + let res = me.writer.poll_write(cx, buf); + if let Poll::Ready(Ok(count)) = res { + (me.f)(&buf[..count]); + } + res + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + me.writer.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + me.writer.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let me = self.project(); + let res = me.writer.poll_write_vectored(cx, bufs); + if let Poll::Ready(Ok(mut count)) = res { + for buf in bufs { + let size = count.min(buf.len()); + (me.f)(&buf[..size]); + count -= size; + if count == 0 { + break; + } + } + } + res + } + + fn is_write_vectored(&self) -> bool { + self.writer.is_write_vectored() + } +} diff --git a/tokio-util/src/io/mod.rs b/tokio-util/src/io/mod.rs index df0dd02fdc0..317d93b3640 100644 --- a/tokio-util/src/io/mod.rs +++ b/tokio-util/src/io/mod.rs @@ -20,7 +20,7 @@ cfg_io_util! { pub use self::sync_bridge::SyncIoBridge; } -pub use self::inspect::InspectReader; +pub use self::inspect::{InspectReader, InspectWriter}; pub use self::read_buf::read_buf; pub use self::reader_stream::ReaderStream; pub use self::stream_reader::StreamReader; diff --git a/tokio-util/tests/io_inspect.rs b/tokio-util/tests/io_inspect.rs index aa907c0c148..5da13f57cd6 100644 --- a/tokio-util/tests/io_inspect.rs +++ b/tokio-util/tests/io_inspect.rs @@ -1,9 +1,11 @@ +use futures::future::poll_fn; use std::{ + io::IoSlice, pin::Pin, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; -use tokio_util::io::InspectReader; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio_util::io::{InspectReader, InspectWriter}; /// An AsyncRead implementation that works byte-by-byte, to catch out callers /// who don't allow for `buf` being part-filled before the call @@ -41,3 +43,153 @@ async fn read_tee() { assert_eq!(teeout, altout); assert_eq!(altout.len(), contents.len()); } + +/// An AsyncWrite implementation that works byte-by-byte for poll_write, and +/// that reads the whole of the first buffer plus one byte from the second in +/// poll_write_vectored. +/// +/// This is designed to catch bugs in handling partially written buffers +#[derive(Debug)] +struct SmallWriter { + contents: Vec, +} + +impl Unpin for SmallWriter {} + +impl AsyncWrite for SmallWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + // Just write one byte at a time + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + self.contents.push(buf[0]); + Poll::Ready(Ok(1)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + // Write all of the first buffer, then one byte from the second buffer + // This should trip up anything that doesn't correctly handle multiple + // buffers. + if bufs.is_empty() { + return Poll::Ready(Ok(0)); + } + let mut written_len = bufs[0].len(); + self.contents.extend_from_slice(&bufs[0]); + + if bufs.len() > 1 { + let buf = bufs[1]; + if !buf.is_empty() { + written_len += 1; + self.contents.push(buf[0]); + } + } + Poll::Ready(Ok(written_len)) + } + + fn is_write_vectored(&self) -> bool { + true + } +} + +#[tokio::test] +async fn write_tee() { + let mut altout: Vec = Vec::new(); + let mut writeout = SmallWriter { + contents: Vec::new(), + }; + { + let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes)); + tee.write_all(b"A testing string, very testing") + .await + .unwrap(); + } + assert_eq!(altout, writeout.contents); +} + +// This is inefficient, but works well enough for test use. +// If you want something similar for real code, you'll want to avoid all the +// fun of manipulating `bufs` - ideally, by the time you read this, +// IoSlice::advance_slices will be stable, and you can use that. +async fn write_all_vectored( + mut writer: W, + mut bufs: Vec>, +) -> Result { + let mut res = 0; + while !bufs.is_empty() { + let mut written = poll_fn(|cx| { + let bufs: Vec = bufs.iter().map(|v| IoSlice::new(&v)).collect(); + Pin::new(&mut writer).poll_write_vectored(cx, &bufs) + }) + .await?; + res += written; + while written > 0 { + let buf_len = bufs[0].len(); + if buf_len <= written { + bufs.remove(0); + written -= buf_len; + } else { + let buf = &mut bufs[0]; + while written > 0 { + buf.remove(0); + written -= 1; + } + } + } + } + Ok(res) +} + +#[tokio::test] +async fn write_tee_vectored() { + let mut altout: Vec = Vec::new(); + let mut writeout = SmallWriter { + contents: Vec::new(), + }; + let original = b"A very long string split up"; + let bufs: Vec> = original + .split(|b| b.is_ascii_whitespace()) + .map(Vec::from) + .collect(); + assert!(bufs.len() > 1); + let expected: Vec = { + let mut out = Vec::new(); + for item in &bufs { + out.extend_from_slice(item) + } + out + }; + { + let mut bufcount = 0; + let tee = InspectWriter::new(&mut writeout, |bytes| { + bufcount += 1; + altout.extend(bytes) + }); + + assert!(tee.is_write_vectored()); + + write_all_vectored(tee, bufs.clone()).await.unwrap(); + + assert!(bufcount >= bufs.len()); + } + assert_eq!(altout, writeout.contents); + assert_eq!(writeout.contents, expected); +}