diff --git a/tokio/src/io/util/mem.rs b/tokio/src/io/util/mem.rs index e91a9328b88..05655894816 100644 --- a/tokio/src/io/util/mem.rs +++ b/tokio/src/io/util/mem.rs @@ -16,6 +16,14 @@ use std::{ /// that can be used as in-memory IO types. Writing to one of the pairs will /// allow that data to be read from the other, and vice versa. /// +/// # Closing a `DuplexStream` +/// +/// If one end of the `DuplexStream` channel is dropped, any pending reads on +/// the other side will continue to read data until the buffer is drained, then +/// they will signal EOF by returning 0 bytes. Any writes to the other side, +/// including pending ones (that are waiting for free space in the buffer) will +/// return `Err(BrokenPipe)` immediately. +/// /// # Example /// /// ``` @@ -134,7 +142,8 @@ impl AsyncWrite for DuplexStream { impl Drop for DuplexStream { fn drop(&mut self) { // notify the other side of the closure - self.write.lock().close(); + self.write.lock().close_write(); + self.read.lock().close_read(); } } @@ -151,12 +160,21 @@ impl Pipe { } } - fn close(&mut self) { + fn close_write(&mut self) { self.is_closed = true; + // needs to notify any readers that no more data will come if let Some(waker) = self.read_waker.take() { waker.wake(); } } + + fn close_read(&mut self) { + self.is_closed = true; + // needs to notify any writers that they have to abort + if let Some(waker) = self.write_waker.take() { + waker.wake(); + } + } } impl AsyncRead for Pipe { @@ -217,7 +235,7 @@ impl AsyncWrite for Pipe { mut self: Pin<&mut Self>, _: &mut task::Context<'_>, ) -> Poll> { - self.close(); + self.close_write(); Poll::Ready(Ok(())) } } diff --git a/tokio/tests/io_mem_stream.rs b/tokio/tests/io_mem_stream.rs index 3335214cb9d..01baa5369c7 100644 --- a/tokio/tests/io_mem_stream.rs +++ b/tokio/tests/io_mem_stream.rs @@ -62,6 +62,25 @@ async fn disconnect() { t2.await.unwrap(); } +#[tokio::test] +async fn disconnect_reader() { + let (a, mut b) = duplex(2); + + let t1 = tokio::spawn(async move { + // this will block, as not all data fits into duplex + b.write_all(b"ping").await.unwrap_err(); + }); + + let t2 = tokio::spawn(async move { + // here we drop the reader side, and we expect the writer in the other + // task to exit with an error + drop(a); + }); + + t2.await.unwrap(); + t1.await.unwrap(); +} + #[tokio::test] async fn max_write_size() { let (mut a, mut b) = duplex(32); @@ -73,11 +92,11 @@ async fn max_write_size() { assert_eq!(n, 4); }); - let t2 = tokio::spawn(async move { - let mut buf = [0u8; 4]; - b.read_exact(&mut buf).await.unwrap(); - }); + let mut buf = [0u8; 4]; + b.read_exact(&mut buf).await.unwrap(); t1.await.unwrap(); - t2.await.unwrap(); + + // drop b only after task t1 finishes writing + drop(b); }