diff --git a/tests-integration/Cargo.toml b/tests-integration/Cargo.toml index 82cdd4b24d6..5daeed08afd 100644 --- a/tests-integration/Cargo.toml +++ b/tests-integration/Cargo.toml @@ -58,3 +58,4 @@ tokio = { path = "../tokio" } tokio-test = { path = "../tokio-test", optional = true } doc-comment = "0.3.1" futures = { version = "0.3.0", features = ["async-await"] } +bytes = "1.0.0" diff --git a/tests-integration/tests/process_stdio.rs b/tests-integration/tests/process_stdio.rs index 3f8ebfbfaa9..3ccb69002d3 100644 --- a/tests-integration/tests/process_stdio.rs +++ b/tests-integration/tests/process_stdio.rs @@ -190,3 +190,54 @@ async fn pipe_from_one_command_to_another() { assert!(second_status.expect("second status").success()); assert!(third_status.expect("third status").success()); } + +#[tokio::test] +async fn vectored_writes() { + use bytes::{Buf, Bytes}; + use std::{io::IoSlice, pin::Pin}; + use tokio::io::AsyncWrite; + + let mut cat = cat().spawn().unwrap(); + let mut stdin = cat.stdin.take().unwrap(); + let are_writes_vectored = stdin.is_write_vectored(); + let mut stdout = cat.stdout.take().unwrap(); + + let write = async { + let mut input = Bytes::from_static(b"hello\n").chain(Bytes::from_static(b"world!\n")); + let mut writes_completed = 0; + + futures::future::poll_fn(|cx| loop { + let mut slices = [IoSlice::new(&[]); 2]; + let vectored = input.chunks_vectored(&mut slices); + if vectored == 0 { + return std::task::Poll::Ready(std::io::Result::Ok(())); + } + let n = futures::ready!(Pin::new(&mut stdin).poll_write_vectored(cx, &slices))?; + writes_completed += 1; + input.advance(n); + }) + .await?; + + drop(stdin); + + std::io::Result::Ok(writes_completed) + }; + + let read = async { + let mut buffer = Vec::with_capacity(6 + 7); + stdout.read_to_end(&mut buffer).await?; + std::io::Result::Ok(buffer) + }; + + let (write, read, status) = future::join3(write, read, cat.wait()).await; + + assert!(status.unwrap().success()); + + let writes_completed = write.unwrap(); + // on unix our small payload should always fit in whatever default sized pipe with a single + // syscall. if multiple are used, then the forwarding does not work, or we are on a platform + // for which the `std` does not support vectored writes. + assert_eq!(writes_completed == 1, are_writes_vectored); + + assert_eq!(&read.unwrap(), b"hello\nworld!\n"); +} diff --git a/tokio/src/io/poll_evented.rs b/tokio/src/io/poll_evented.rs index 240d0d4ad40..dfe9ae34cd3 100644 --- a/tokio/src/io/poll_evented.rs +++ b/tokio/src/io/poll_evented.rs @@ -208,7 +208,7 @@ feature! { } } - #[cfg(feature = "net")] + #[cfg(any(feature = "net", feature = "process"))] pub(crate) fn poll_write_vectored<'a>( &'a self, cx: &mut Context<'_>, diff --git a/tokio/src/process/mod.rs b/tokio/src/process/mod.rs index 7e1e75d3112..66e42127717 100644 --- a/tokio/src/process/mod.rs +++ b/tokio/src/process/mod.rs @@ -1329,6 +1329,18 @@ impl AsyncWrite for ChildStdin { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_shutdown(cx) } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } } impl AsyncRead for ChildStdout { diff --git a/tokio/src/process/unix/mod.rs b/tokio/src/process/unix/mod.rs index 9bbc813fe6d..0345083dc09 100644 --- a/tokio/src/process/unix/mod.rs +++ b/tokio/src/process/unix/mod.rs @@ -182,6 +182,10 @@ impl<'a> io::Write for &'a Pipe { fn flush(&mut self) -> io::Result<()> { (&self.fd).flush() } + + fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { + (&self.fd).write_vectored(bufs) + } } impl AsRawFd for Pipe { @@ -258,6 +262,18 @@ impl AsyncWrite for ChildStdio { fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.inner.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + true + } } impl AsyncRead for ChildStdio {