Skip to content

Commit

Permalink
io: wake pending writers on DuplexStream close (#3756)
Browse files Browse the repository at this point in the history
  • Loading branch information
PiMaker committed May 6, 2021
1 parent 177522c commit d4075a4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
24 changes: 21 additions & 3 deletions tokio/src/io/util/mem.rs
Expand Up @@ -16,6 +16,14 @@ use std::{
/// that can be used as in-memory IO types. Writing to one of the pairs will
/// allow that data to be read from the other, and vice versa.
///
/// # Closing a `DuplexStream`
///
/// If one end of the `DuplexStream` channel is dropped, any pending reads on
/// the other side will continue to read data until the buffer is drained, then
/// they will signal EOF by returning 0 bytes. Any writes to the other side,
/// including pending ones (that are waiting for free space in the buffer) will
/// return `Err(BrokenPipe)` immediately.
///
/// # Example
///
/// ```
Expand Down Expand Up @@ -134,7 +142,8 @@ impl AsyncWrite for DuplexStream {
impl Drop for DuplexStream {
fn drop(&mut self) {
// notify the other side of the closure
self.write.lock().close();
self.write.lock().close_write();
self.read.lock().close_read();
}
}

Expand All @@ -151,12 +160,21 @@ impl Pipe {
}
}

fn close(&mut self) {
fn close_write(&mut self) {
self.is_closed = true;
// needs to notify any readers that no more data will come
if let Some(waker) = self.read_waker.take() {
waker.wake();
}
}

fn close_read(&mut self) {
self.is_closed = true;
// needs to notify any writers that they have to abort
if let Some(waker) = self.write_waker.take() {
waker.wake();
}
}
}

impl AsyncRead for Pipe {
Expand Down Expand Up @@ -217,7 +235,7 @@ impl AsyncWrite for Pipe {
mut self: Pin<&mut Self>,
_: &mut task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.close();
self.close_write();
Poll::Ready(Ok(()))
}
}
29 changes: 24 additions & 5 deletions tokio/tests/io_mem_stream.rs
Expand Up @@ -62,6 +62,25 @@ async fn disconnect() {
t2.await.unwrap();
}

#[tokio::test]
async fn disconnect_reader() {
let (a, mut b) = duplex(2);

let t1 = tokio::spawn(async move {
// this will block, as not all data fits into duplex
b.write_all(b"ping").await.unwrap_err();
});

let t2 = tokio::spawn(async move {
// here we drop the reader side, and we expect the writer in the other
// task to exit with an error
drop(a);
});

t2.await.unwrap();
t1.await.unwrap();
}

#[tokio::test]
async fn max_write_size() {
let (mut a, mut b) = duplex(32);
Expand All @@ -73,11 +92,11 @@ async fn max_write_size() {
assert_eq!(n, 4);
});

let t2 = tokio::spawn(async move {
let mut buf = [0u8; 4];
b.read_exact(&mut buf).await.unwrap();
});
let mut buf = [0u8; 4];
b.read_exact(&mut buf).await.unwrap();

t1.await.unwrap();
t2.await.unwrap();

// drop b only after task t1 finishes writing
drop(b);
}

0 comments on commit d4075a4

Please sign in to comment.