Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

io: add a copy_bidirectional utility #3572

Merged
merged 19 commits into from Apr 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion tokio/src/io/mod.rs
Expand Up @@ -246,7 +246,7 @@ cfg_io_util! {
pub(crate) mod seek;
pub(crate) mod util;
pub use util::{
copy, copy_buf, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
copy, copy_bidirectional, copy_buf, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
BufReader, BufStream, BufWriter, DuplexStream, Empty, Lines, Repeat, Sink, Split, Take,
};
}
Expand Down
128 changes: 78 additions & 50 deletions tokio/src/io/util/copy.rs
Expand Up @@ -5,18 +5,85 @@ use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

#[derive(Debug)]
pub(super) struct CopyBuffer {
read_done: bool,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
}

impl CopyBuffer {
pub(super) fn new() -> Self {
Self {
read_done: false,
pos: 0,
cap: 0,
amt: 0,
buf: vec![0; 2048].into_boxed_slice(),
}
}

pub(super) fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<u64>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
ready!(reader.as_mut().poll_read(cx, &mut buf))?;
let n = buf.filled().len();
if n == 0 {
self.read_done = true;
} else {
self.pos = 0;
self.cap = n;
}
}

// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let me = &mut *self;
let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?;
if i == 0 {
Comment on lines +55 to +58
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would make sense to try to read more data if the write returns Pending when me.cap < 2048.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will look into this

return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
)));
} else {
self.pos += i;
self.amt += i as u64;
}
}

// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
ready!(writer.as_mut().poll_flush(cx))?;
return Poll::Ready(Ok(self.amt));
}
}
}
}

/// A future that asynchronously copies the entire contents of a reader into a
/// writer.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct Copy<'a, R: ?Sized, W: ?Sized> {
reader: &'a mut R,
read_done: bool,
writer: &'a mut W,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
buf: CopyBuffer,
}

cfg_io_util! {
Expand All @@ -35,8 +102,8 @@ cfg_io_util! {
///
/// # Errors
///
/// The returned future will finish with an error will return an error
/// immediately if any call to `poll_read` or `poll_write` returns an error.
/// The returned future will return an error immediately if any call to
/// `poll_read` or `poll_write` returns an error.
///
/// # Examples
///
Expand All @@ -60,12 +127,8 @@ cfg_io_util! {
{
Copy {
reader,
read_done: false,
writer,
amt: 0,
pos: 0,
cap: 0,
buf: vec![0; 2048].into_boxed_slice(),
buf: CopyBuffer::new()
}.await
}
}
Expand All @@ -78,44 +141,9 @@ where
type Output = io::Result<u64>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut buf))?;
let n = buf.filled().len();
if n == 0 {
self.read_done = true;
} else {
self.pos = 0;
self.cap = n;
}
}
let me = &mut *self;

// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let me = &mut *self;
let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, &me.buf[me.pos..me.cap]))?;
if i == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
)));
} else {
self.pos += i;
self.amt += i as u64;
}
}

// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
let me = &mut *self;
ready!(Pin::new(&mut *me.writer).poll_flush(cx))?;
return Poll::Ready(Ok(self.amt));
}
}
me.buf
.poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer))
}
}
119 changes: 119 additions & 0 deletions tokio/src/io/util/copy_bidirectional.rs
@@ -0,0 +1,119 @@
use super::copy::CopyBuffer;

use crate::io::{AsyncRead, AsyncWrite};

use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

enum TransferState {
Running(CopyBuffer),
ShuttingDown(u64),
Done(u64),
}

struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
a: &'a mut A,
b: &'a mut B,
a_to_b: TransferState,
b_to_a: TransferState,
}

fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
r: &mut A,
w: &mut B,
) -> Poll<io::Result<u64>>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut r = Pin::new(r);
let mut w = Pin::new(w);

loop {
match state {
TransferState::Running(buf) => {
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(count);
}
TransferState::ShuttingDown(count) => {
ready!(w.as_mut().poll_shutdown(cx))?;

*state = TransferState::Done(*count);
}
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
}
}
}

impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
type Output = io::Result<(u64, u64)>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Unpack self into mut refs to each field to avoid borrow check issues.
let CopyBidirectional {
a,
b,
a_to_b,
b_to_a,
} = &mut *self;

let a_to_b = transfer_one_direction(cx, a_to_b, &mut *a, &mut *b)?;
let b_to_a = transfer_one_direction(cx, b_to_a, &mut *b, &mut *a)?;

// It is not a problem if ready! returns early because transfer_one_direction for the
// other direction will keep returning TransferState::Done(count) in future calls to poll
let a_to_b = ready!(a_to_b);
let b_to_a = ready!(b_to_a);

Poll::Ready(Ok((a_to_b, b_to_a)))
}
}

/// Copies data in both directions between `a` and `b`.
///
/// This function returns a future that will read from both streams,
/// writing any data read to the opposing stream.
/// This happens in both directions concurrently.
///
/// If an EOF is observed on one stream, [`shutdown()`] will be invoked on
/// the other, and reading from that stream will stop. Copying of data in
/// the other direction will continue.
///
/// The future will complete successfully once both directions of communication has been shut down.
/// A direction is shut down when the reader reports EOF,
/// at which point [`shutdown()`] is called on the corresponding writer. When finished,
/// it will return a tuple of the number of bytes copied from a to b
/// and the number of bytes copied from b to a, in that order.
///
/// [`shutdown()`]: crate::io::AsyncWriteExt::shutdown
///
/// # Errors
///
/// The future will immediately return an error if any IO operation on `a`
/// or `b` returns an error. Some data read from either stream may be lost (not
/// written to the other stream) in this case.
///
/// # Return value
///
/// Returns a tuple of bytes copied `a` to `b` and bytes copied `b` to `a`.
pub async fn copy_bidirectional<A, B>(a: &mut A, b: &mut B) -> Result<(u64, u64), std::io::Error>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
CopyBidirectional {
a,
b,
a_to_b: TransferState::Running(CopyBuffer::new()),
b_to_a: TransferState::Running(CopyBuffer::new()),
}
.await
}
3 changes: 3 additions & 0 deletions tokio/src/io/util/mod.rs
Expand Up @@ -27,6 +27,9 @@ cfg_io_util! {
mod copy;
pub use copy::copy;

mod copy_bidirectional;
pub use copy_bidirectional::copy_bidirectional;

mod copy_buf;
pub use copy_buf::copy_buf;

Expand Down