diff --git a/tokio/src/io/util/mem.rs b/tokio/src/io/util/mem.rs index 4eefe7b26f5..4019db56ff4 100644 --- a/tokio/src/io/util/mem.rs +++ b/tokio/src/io/util/mem.rs @@ -177,10 +177,8 @@ impl Pipe { waker.wake(); } } -} -impl AsyncRead for Pipe { - fn poll_read( + fn poll_read_internal( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, @@ -204,10 +202,8 @@ impl AsyncRead for Pipe { Poll::Pending } } -} -impl AsyncWrite for Pipe { - fn poll_write( + fn poll_write_internal( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], @@ -228,6 +224,62 @@ impl AsyncWrite for Pipe { } Poll::Ready(Ok(len)) } +} + +impl AsyncRead for Pipe { + cfg_coop! { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let coop = ready!(crate::coop::poll_proceed(cx)); + + let ret = self.poll_read_internal(cx, buf); + if ret.is_ready() { + coop.made_progress(); + } + ret + } + } + + cfg_not_coop! { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.poll_read_internal(cx, buf) + } + } +} + +impl AsyncWrite for Pipe { + cfg_coop! { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let coop = ready!(crate::coop::poll_proceed(cx)); + + let ret = self.poll_write_internal(cx, buf); + if ret.is_ready() { + coop.made_progress(); + } + ret + } + } + + cfg_not_coop! { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + self.poll_write_internal(cx, buf) + } + } fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll> { Poll::Ready(Ok(())) diff --git a/tokio/tests/io_mem_stream.rs b/tokio/tests/io_mem_stream.rs index 01baa5369c7..a2c2dadfc90 100644 --- a/tokio/tests/io_mem_stream.rs +++ b/tokio/tests/io_mem_stream.rs @@ -100,3 +100,22 @@ async fn max_write_size() { // drop b only after task t1 finishes writing drop(b); } + +#[tokio::test] +async fn duplex_is_cooperative() { + let (mut tx, mut rx) = tokio::io::duplex(1024 * 8); + + tokio::select! { + biased; + + _ = async { + loop { + let buf = [3u8; 4096]; + tx.write_all(&buf).await.unwrap(); + let mut buf = [0u8; 4096]; + rx.read(&mut buf).await.unwrap(); + } + } => {}, + _ = tokio::task::yield_now() => {} + } +}