diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs index 3e7c9438ad9..14a4a6304dd 100644 --- a/tokio/src/io/mod.rs +++ b/tokio/src/io/mod.rs @@ -246,7 +246,7 @@ cfg_io_util! { pub(crate) mod seek; pub(crate) mod util; pub use util::{ - copy, copy_buf, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, + copy, copy_bidirectional, copy_buf, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, BufReader, BufStream, BufWriter, DuplexStream, Empty, Lines, Repeat, Sink, Split, Take, }; } diff --git a/tokio/src/io/util/copy.rs b/tokio/src/io/util/copy.rs index c5981cf9aa6..3cd425b348b 100644 --- a/tokio/src/io/util/copy.rs +++ b/tokio/src/io/util/copy.rs @@ -5,18 +5,85 @@ use std::io; use std::pin::Pin; use std::task::{Context, Poll}; +#[derive(Debug)] +pub(super) struct CopyBuffer { + read_done: bool, + pos: usize, + cap: usize, + amt: u64, + buf: Box<[u8]>, +} + +impl CopyBuffer { + pub(super) fn new() -> Self { + Self { + read_done: false, + pos: 0, + cap: 0, + amt: 0, + buf: vec![0; 2048].into_boxed_slice(), + } + } + + pub(super) fn poll_copy( + &mut self, + cx: &mut Context<'_>, + mut reader: Pin<&mut R>, + mut writer: Pin<&mut W>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, + { + loop { + // If our buffer is empty, then we need to read some data to + // continue. + if self.pos == self.cap && !self.read_done { + let me = &mut *self; + let mut buf = ReadBuf::new(&mut me.buf); + ready!(reader.as_mut().poll_read(cx, &mut buf))?; + let n = buf.filled().len(); + if n == 0 { + self.read_done = true; + } else { + self.pos = 0; + self.cap = n; + } + } + + // If our buffer has some data, let's write it out! + while self.pos < self.cap { + let me = &mut *self; + let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?; + if i == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "write zero byte into writer", + ))); + } else { + self.pos += i; + self.amt += i as u64; + } + } + + // If we've written all the data and we've seen EOF, flush out the + // data and finish the transfer. + if self.pos == self.cap && self.read_done { + ready!(writer.as_mut().poll_flush(cx))?; + return Poll::Ready(Ok(self.amt)); + } + } + } +} + /// A future that asynchronously copies the entire contents of a reader into a /// writer. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] struct Copy<'a, R: ?Sized, W: ?Sized> { reader: &'a mut R, - read_done: bool, writer: &'a mut W, - pos: usize, - cap: usize, - amt: u64, - buf: Box<[u8]>, + buf: CopyBuffer, } cfg_io_util! { @@ -35,8 +102,8 @@ cfg_io_util! { /// /// # Errors /// - /// The returned future will finish with an error will return an error - /// immediately if any call to `poll_read` or `poll_write` returns an error. + /// The returned future will return an error immediately if any call to + /// `poll_read` or `poll_write` returns an error. /// /// # Examples /// @@ -60,12 +127,8 @@ cfg_io_util! { { Copy { reader, - read_done: false, writer, - amt: 0, - pos: 0, - cap: 0, - buf: vec![0; 2048].into_boxed_slice(), + buf: CopyBuffer::new() }.await } } @@ -78,44 +141,9 @@ where type Output = io::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - // If our buffer is empty, then we need to read some data to - // continue. - if self.pos == self.cap && !self.read_done { - let me = &mut *self; - let mut buf = ReadBuf::new(&mut me.buf); - ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut buf))?; - let n = buf.filled().len(); - if n == 0 { - self.read_done = true; - } else { - self.pos = 0; - self.cap = n; - } - } + let me = &mut *self; - // If our buffer has some data, let's write it out! - while self.pos < self.cap { - let me = &mut *self; - let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, &me.buf[me.pos..me.cap]))?; - if i == 0 { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "write zero byte into writer", - ))); - } else { - self.pos += i; - self.amt += i as u64; - } - } - - // If we've written all the data and we've seen EOF, flush out the - // data and finish the transfer. - if self.pos == self.cap && self.read_done { - let me = &mut *self; - ready!(Pin::new(&mut *me.writer).poll_flush(cx))?; - return Poll::Ready(Ok(self.amt)); - } - } + me.buf + .poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer)) } } diff --git a/tokio/src/io/util/copy_bidirectional.rs b/tokio/src/io/util/copy_bidirectional.rs new file mode 100644 index 00000000000..cc43f0fd67d --- /dev/null +++ b/tokio/src/io/util/copy_bidirectional.rs @@ -0,0 +1,119 @@ +use super::copy::CopyBuffer; + +use crate::io::{AsyncRead, AsyncWrite}; + +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +enum TransferState { + Running(CopyBuffer), + ShuttingDown(u64), + Done(u64), +} + +struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> { + a: &'a mut A, + b: &'a mut B, + a_to_b: TransferState, + b_to_a: TransferState, +} + +fn transfer_one_direction( + cx: &mut Context<'_>, + state: &mut TransferState, + r: &mut A, + w: &mut B, +) -> Poll> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let mut r = Pin::new(r); + let mut w = Pin::new(w); + + loop { + match state { + TransferState::Running(buf) => { + let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?; + *state = TransferState::ShuttingDown(count); + } + TransferState::ShuttingDown(count) => { + ready!(w.as_mut().poll_shutdown(cx))?; + + *state = TransferState::Done(*count); + } + TransferState::Done(count) => return Poll::Ready(Ok(*count)), + } + } +} + +impl<'a, A, B> Future for CopyBidirectional<'a, A, B> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<(u64, u64)>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // Unpack self into mut refs to each field to avoid borrow check issues. + let CopyBidirectional { + a, + b, + a_to_b, + b_to_a, + } = &mut *self; + + let a_to_b = transfer_one_direction(cx, a_to_b, &mut *a, &mut *b)?; + let b_to_a = transfer_one_direction(cx, b_to_a, &mut *b, &mut *a)?; + + // It is not a problem if ready! returns early because transfer_one_direction for the + // other direction will keep returning TransferState::Done(count) in future calls to poll + let a_to_b = ready!(a_to_b); + let b_to_a = ready!(b_to_a); + + Poll::Ready(Ok((a_to_b, b_to_a))) + } +} + +/// Copies data in both directions between `a` and `b`. +/// +/// This function returns a future that will read from both streams, +/// writing any data read to the opposing stream. +/// This happens in both directions concurrently. +/// +/// If an EOF is observed on one stream, [`shutdown()`] will be invoked on +/// the other, and reading from that stream will stop. Copying of data in +/// the other direction will continue. +/// +/// The future will complete successfully once both directions of communication has been shut down. +/// A direction is shut down when the reader reports EOF, +/// at which point [`shutdown()`] is called on the corresponding writer. When finished, +/// it will return a tuple of the number of bytes copied from a to b +/// and the number of bytes copied from b to a, in that order. +/// +/// [`shutdown()`]: crate::io::AsyncWriteExt::shutdown +/// +/// # Errors +/// +/// The future will immediately return an error if any IO operation on `a` +/// or `b` returns an error. Some data read from either stream may be lost (not +/// written to the other stream) in this case. +/// +/// # Return value +/// +/// Returns a tuple of bytes copied `a` to `b` and bytes copied `b` to `a`. +pub async fn copy_bidirectional(a: &mut A, b: &mut B) -> Result<(u64, u64), std::io::Error> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + CopyBidirectional { + a, + b, + a_to_b: TransferState::Running(CopyBuffer::new()), + b_to_a: TransferState::Running(CopyBuffer::new()), + } + .await +} diff --git a/tokio/src/io/util/mod.rs b/tokio/src/io/util/mod.rs index 9ddb7758d4d..c39d6dcb8c5 100644 --- a/tokio/src/io/util/mod.rs +++ b/tokio/src/io/util/mod.rs @@ -27,6 +27,9 @@ cfg_io_util! { mod copy; pub use copy::copy; + mod copy_bidirectional; + pub use copy_bidirectional::copy_bidirectional; + mod copy_buf; pub use copy_buf::copy_buf; diff --git a/tokio/tests/io_copy_bidirectional.rs b/tokio/tests/io_copy_bidirectional.rs new file mode 100644 index 00000000000..17c059725c6 --- /dev/null +++ b/tokio/tests/io_copy_bidirectional.rs @@ -0,0 +1,128 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use std::time::Duration; +use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio::task::JoinHandle; + +async fn make_socketpair() -> (TcpStream, TcpStream) { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let connector = TcpStream::connect(addr); + let acceptor = listener.accept(); + + let (c1, c2) = tokio::join!(connector, acceptor); + + (c1.unwrap(), c2.unwrap().0) +} + +async fn block_write(s: &mut TcpStream) -> usize { + static BUF: [u8; 2048] = [0; 2048]; + + let mut copied = 0; + loop { + tokio::select! { + result = s.write(&BUF) => { + copied += result.expect("write error") + }, + _ = tokio::time::sleep(Duration::from_millis(100)) => { + break; + } + } + } + + copied +} + +async fn symmetric(mut cb: F) +where + F: FnMut(JoinHandle>, TcpStream, TcpStream) -> Fut, + Fut: std::future::Future, +{ + // We run the test twice, with streams passed to copy_bidirectional in + // different orders, in order to ensure that the two arguments are + // interchangable. + + let (a, mut a1) = make_socketpair().await; + let (b, mut b1) = make_socketpair().await; + + let handle = tokio::spawn(async move { copy_bidirectional(&mut a1, &mut b1).await }); + cb(handle, a, b).await; + + let (a, mut a1) = make_socketpair().await; + let (b, mut b1) = make_socketpair().await; + + let handle = tokio::spawn(async move { copy_bidirectional(&mut b1, &mut a1).await }); + + cb(handle, b, a).await; +} + +#[tokio::test] +async fn test_basic_transfer() { + symmetric(|_handle, mut a, mut b| async move { + a.write_all(b"test").await.unwrap(); + let mut tmp = [0; 4]; + b.read_exact(&mut tmp).await.unwrap(); + assert_eq!(&tmp[..], b"test"); + }) + .await +} + +#[tokio::test] +async fn test_transfer_after_close() { + symmetric(|handle, mut a, mut b| async move { + AsyncWriteExt::shutdown(&mut a).await.unwrap(); + b.read_to_end(&mut Vec::new()).await.unwrap(); + + b.write_all(b"quux").await.unwrap(); + let mut tmp = [0; 4]; + a.read_exact(&mut tmp).await.unwrap(); + assert_eq!(&tmp[..], b"quux"); + + // Once both are closed, we should have our handle back + drop(b); + + assert_eq!(handle.await.unwrap().unwrap(), (0, 4)); + }) + .await +} + +#[tokio::test] +async fn blocking_one_side_does_not_block_other() { + symmetric(|handle, mut a, mut b| async move { + block_write(&mut a).await; + + b.write_all(b"quux").await.unwrap(); + let mut tmp = [0; 4]; + a.read_exact(&mut tmp).await.unwrap(); + assert_eq!(&tmp[..], b"quux"); + + AsyncWriteExt::shutdown(&mut a).await.unwrap(); + + let mut buf = Vec::new(); + b.read_to_end(&mut buf).await.unwrap(); + + drop(b); + + assert_eq!(handle.await.unwrap().unwrap(), (buf.len() as u64, 4)); + }) + .await +} + +#[tokio::test] +async fn immediate_exit_on_error() { + symmetric(|handle, mut a, mut b| async move { + block_write(&mut a).await; + + // Fill up the b->copy->a path. We expect that this will _not_ drain + // before we exit the copy task. + let _bytes_written = block_write(&mut b).await; + + // Drop b. We should not wait for a to consume the data buffered in the + // copy loop, since b will be failing writes. + drop(b); + assert!(handle.await.unwrap().is_err()); + }) + .await +}