From e9520fe583372923a93c0b43068024b9cbc124bb Mon Sep 17 00:00:00 2001 From: Simon Farnsworth Date: Tue, 20 Sep 2022 10:53:47 +0100 Subject: [PATCH] tokio_util: add inspection wrapper for `AsyncWrite` When writing things out, it's useful to be able to inspect the bytes that are being written and do things like hash them as they go past. This isn't trivial to get right, due to partial writes and efficiently handling vectored writes (if used). Provide an `InspectWriter` wrapper that gets this right, giving a supplied `FnMut` closure a chance to inspect the buffers that have been successfully written out. Fixes: #4584 --- tokio-util/src/io/inspect.rs | 73 ++++++++++++++- tokio-util/src/io/mod.rs | 2 +- tokio-util/tests/io_inspect.rs | 156 ++++++++++++++++++++++++++++++++- 3 files changed, 226 insertions(+), 5 deletions(-) diff --git a/tokio-util/src/io/inspect.rs b/tokio-util/src/io/inspect.rs index 8fbd50f3513..53b1d5504ad 100644 --- a/tokio-util/src/io/inspect.rs +++ b/tokio-util/src/io/inspect.rs @@ -1,11 +1,11 @@ use futures_core::ready; use pin_project_lite::pin_project; use std::{ - io::Result, + io::{IoSlice, Result}, pin::Pin, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pin_project! { /// An adapter that lets you inspect the data that's being read. @@ -44,3 +44,72 @@ impl AsyncRead for InspectReader { Poll::Ready(Ok(())) } } + +pin_project! { + /// An adapter that lets you inspect the data that's being written. + /// + /// This is useful for things like hashing data as it's written out. + pub struct InspectWriter { + #[pin] + writer: W, + f: F, + } +} + +impl InspectWriter { + /// Create a new InspectWriter, wrapping `write` and calling `f` for the + /// data successfully written by each write call. + pub fn new(writer: W, f: F) -> InspectWriter { + InspectWriter { writer, f } + } + + /// Consumes the `InspectWriter`, returning the wrapped writer + pub fn into_inner(self) -> W { + self.writer + } +} + +impl AsyncWrite for InspectWriter { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let me = self.project(); + let res = me.writer.poll_write(cx, buf); + if let Poll::Ready(Ok(count)) = res { + (me.f)(&buf[..count]); + } + res + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + me.writer.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + me.writer.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let me = self.project(); + let res = me.writer.poll_write_vectored(cx, bufs); + if let Poll::Ready(Ok(mut count)) = res { + for buf in bufs { + let size = count.min(buf.len()); + (me.f)(&buf[..size]); + count -= size; + if count == 0 { + break; + } + } + } + res + } + + fn is_write_vectored(&self) -> bool { + self.writer.is_write_vectored() + } +} diff --git a/tokio-util/src/io/mod.rs b/tokio-util/src/io/mod.rs index df0dd02fdc0..317d93b3640 100644 --- a/tokio-util/src/io/mod.rs +++ b/tokio-util/src/io/mod.rs @@ -20,7 +20,7 @@ cfg_io_util! { pub use self::sync_bridge::SyncIoBridge; } -pub use self::inspect::InspectReader; +pub use self::inspect::{InspectReader, InspectWriter}; pub use self::read_buf::read_buf; pub use self::reader_stream::ReaderStream; pub use self::stream_reader::StreamReader; diff --git a/tokio-util/tests/io_inspect.rs b/tokio-util/tests/io_inspect.rs index aa907c0c148..5da13f57cd6 100644 --- a/tokio-util/tests/io_inspect.rs +++ b/tokio-util/tests/io_inspect.rs @@ -1,9 +1,11 @@ +use futures::future::poll_fn; use std::{ + io::IoSlice, pin::Pin, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; -use tokio_util::io::InspectReader; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio_util::io::{InspectReader, InspectWriter}; /// An AsyncRead implementation that works byte-by-byte, to catch out callers /// who don't allow for `buf` being part-filled before the call @@ -41,3 +43,153 @@ async fn read_tee() { assert_eq!(teeout, altout); assert_eq!(altout.len(), contents.len()); } + +/// An AsyncWrite implementation that works byte-by-byte for poll_write, and +/// that reads the whole of the first buffer plus one byte from the second in +/// poll_write_vectored. +/// +/// This is designed to catch bugs in handling partially written buffers +#[derive(Debug)] +struct SmallWriter { + contents: Vec, +} + +impl Unpin for SmallWriter {} + +impl AsyncWrite for SmallWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + // Just write one byte at a time + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + self.contents.push(buf[0]); + Poll::Ready(Ok(1)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + // Write all of the first buffer, then one byte from the second buffer + // This should trip up anything that doesn't correctly handle multiple + // buffers. + if bufs.is_empty() { + return Poll::Ready(Ok(0)); + } + let mut written_len = bufs[0].len(); + self.contents.extend_from_slice(&bufs[0]); + + if bufs.len() > 1 { + let buf = bufs[1]; + if !buf.is_empty() { + written_len += 1; + self.contents.push(buf[0]); + } + } + Poll::Ready(Ok(written_len)) + } + + fn is_write_vectored(&self) -> bool { + true + } +} + +#[tokio::test] +async fn write_tee() { + let mut altout: Vec = Vec::new(); + let mut writeout = SmallWriter { + contents: Vec::new(), + }; + { + let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes)); + tee.write_all(b"A testing string, very testing") + .await + .unwrap(); + } + assert_eq!(altout, writeout.contents); +} + +// This is inefficient, but works well enough for test use. +// If you want something similar for real code, you'll want to avoid all the +// fun of manipulating `bufs` - ideally, by the time you read this, +// IoSlice::advance_slices will be stable, and you can use that. +async fn write_all_vectored( + mut writer: W, + mut bufs: Vec>, +) -> Result { + let mut res = 0; + while !bufs.is_empty() { + let mut written = poll_fn(|cx| { + let bufs: Vec = bufs.iter().map(|v| IoSlice::new(&v)).collect(); + Pin::new(&mut writer).poll_write_vectored(cx, &bufs) + }) + .await?; + res += written; + while written > 0 { + let buf_len = bufs[0].len(); + if buf_len <= written { + bufs.remove(0); + written -= buf_len; + } else { + let buf = &mut bufs[0]; + while written > 0 { + buf.remove(0); + written -= 1; + } + } + } + } + Ok(res) +} + +#[tokio::test] +async fn write_tee_vectored() { + let mut altout: Vec = Vec::new(); + let mut writeout = SmallWriter { + contents: Vec::new(), + }; + let original = b"A very long string split up"; + let bufs: Vec> = original + .split(|b| b.is_ascii_whitespace()) + .map(Vec::from) + .collect(); + assert!(bufs.len() > 1); + let expected: Vec = { + let mut out = Vec::new(); + for item in &bufs { + out.extend_from_slice(item) + } + out + }; + { + let mut bufcount = 0; + let tee = InspectWriter::new(&mut writeout, |bytes| { + bufcount += 1; + altout.extend(bytes) + }); + + assert!(tee.is_write_vectored()); + + write_all_vectored(tee, bufs.clone()).await.unwrap(); + + assert!(bufcount >= bufs.len()); + } + assert_eq!(altout, writeout.contents); + assert_eq!(writeout.contents, expected); +}