Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

io: make duplex stream cooperative (#4470) #4478

Merged
merged 6 commits into from Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
64 changes: 58 additions & 6 deletions tokio/src/io/util/mem.rs
Expand Up @@ -177,10 +177,8 @@ impl Pipe {
waker.wake();
}
}
}

impl AsyncRead for Pipe {
fn poll_read(
fn poll_read_internal(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
Expand All @@ -204,10 +202,8 @@ impl AsyncRead for Pipe {
Poll::Pending
}
}
}

impl AsyncWrite for Pipe {
fn poll_write(
fn poll_write_internal(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
Expand All @@ -228,6 +224,62 @@ impl AsyncWrite for Pipe {
}
Poll::Ready(Ok(len))
}
}

impl AsyncRead for Pipe {
cfg_coop! {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let coop = ready!(crate::coop::poll_proceed(cx));

let ret = self.poll_read_internal(cx, buf);
if ret.is_ready() {
coop.made_progress();
}
ret
}
}

cfg_not_coop! {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
self.poll_read_internal(cx, buf)
}
}
}

impl AsyncWrite for Pipe {
cfg_coop! {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let coop = ready!(crate::coop::poll_proceed(cx));

let ret = self.poll_write_internal(cx, buf);
if ret.is_ready() {
coop.made_progress();
}
ret
}
}

cfg_not_coop! {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.poll_write_internal(cx, buf)
}
}

fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
Expand Down
19 changes: 19 additions & 0 deletions tokio/tests/io_mem_stream.rs
Expand Up @@ -100,3 +100,22 @@ async fn max_write_size() {
// drop b only after task t1 finishes writing
drop(b);
}

#[tokio::test]
async fn duplex_is_cooperative() {
let (mut tx, mut rx) = tokio::io::duplex(1024 * 8);

tokio::select! {
biased;

_ = async {
loop {
let buf = [3u8; 4096];
tx.write_all(&buf).await.unwrap();
let mut buf = [0u8; 4096];
rx.read(&mut buf).await.unwrap();
}
} => {},
_ = tokio::task::yield_now() => {}
}
}