Skip to content

Commit

Permalink
net: restore TcpStream::{poll_read_ready, poll_write_ready} (#2743)
Browse files Browse the repository at this point in the history
  • Loading branch information
masnagam committed Nov 16, 2020
1 parent 97c2c42 commit 4e39c9b
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 2 deletions.
22 changes: 22 additions & 0 deletions tokio/src/net/tcp/stream.rs
Expand Up @@ -356,6 +356,17 @@ impl TcpStream {
Ok(())
}

/// Polls for read readiness.
///
/// This function is intended for cases where creating and pinning a future
/// via [`readable`] is not feasible. Where possible, using [`readable`] is
/// preferred, as this supports polling from multiple tasks at once.
///
/// [`readable`]: method@Self::readable
pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.io.registration().poll_read_ready(cx).map_ok(|_| ())
}

/// Try to read data from the stream into the provided buffer, returning how
/// many bytes were read.
///
Expand Down Expand Up @@ -467,6 +478,17 @@ impl TcpStream {
Ok(())
}

/// Polls for write readiness.
///
/// This function is intended for cases where creating and pinning a future
/// via [`writable`] is not feasible. Where possible, using [`writable`] is
/// preferred, as this supports polling from multiple tasks at once.
///
/// [`writable`]: method@Self::writable
pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.io.registration().poll_write_ready(cx).map_ok(|_| ())
}

/// Try to write a buffer to the stream, returning how many bytes were
/// written.
///
Expand Down
112 changes: 110 additions & 2 deletions tokio/tests/tcp_stream.rs
@@ -1,12 +1,16 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::io::Interest;
use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest};
use tokio::net::{TcpListener, TcpStream};
use tokio::try_join;
use tokio_test::task;
use tokio_test::{assert_pending, assert_ready_ok};
use tokio_test::{assert_ok, assert_pending, assert_ready_ok};

use std::io;
use std::task::Poll;

use futures::future::poll_fn;

#[tokio::test]
async fn try_read_write() {
Expand Down Expand Up @@ -110,3 +114,107 @@ fn buffer_not_included_in_future() {
let n = mem::size_of_val(&fut);
assert!(n < 1000);
}

macro_rules! assert_readable_by_polling {
($stream:expr) => {
assert_ok!(poll_fn(|cx| $stream.poll_read_ready(cx)).await);
};
}

macro_rules! assert_not_readable_by_polling {
($stream:expr) => {
poll_fn(|cx| {
assert_pending!($stream.poll_read_ready(cx));
Poll::Ready(())
})
.await;
};
}

macro_rules! assert_writable_by_polling {
($stream:expr) => {
assert_ok!(poll_fn(|cx| $stream.poll_write_ready(cx)).await);
};
}

macro_rules! assert_not_writable_by_polling {
($stream:expr) => {
poll_fn(|cx| {
assert_pending!($stream.poll_write_ready(cx));
Poll::Ready(())
})
.await;
};
}

#[tokio::test]
async fn poll_read_ready() {
let (mut client, mut server) = create_pair().await;

// Initial state - not readable.
assert_not_readable_by_polling!(server);

// There is data in the buffer - readable.
assert_ok!(client.write_all(b"ping").await);
assert_readable_by_polling!(server);

// Readable until calls to `poll_read` return `Poll::Pending`.
let mut buf = [0u8; 4];
assert_ok!(server.read_exact(&mut buf).await);
assert_readable_by_polling!(server);
read_until_pending(&mut server);
assert_not_readable_by_polling!(server);

// Detect the client disconnect.
drop(client);
assert_readable_by_polling!(server);
}

#[tokio::test]
async fn poll_write_ready() {
let (mut client, server) = create_pair().await;

// Initial state - writable.
assert_writable_by_polling!(client);

// No space to write - not writable.
write_until_pending(&mut client);
assert_not_writable_by_polling!(client);

// Detect the server disconnect.
drop(server);
assert_writable_by_polling!(client);
}

async fn create_pair() -> (TcpStream, TcpStream) {
let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(listener.local_addr());
let (client, (server, _)) = assert_ok!(try_join!(TcpStream::connect(&addr), listener.accept()));
(client, server)
}

fn read_until_pending(stream: &mut TcpStream) {
let mut buf = vec![0u8; 1024 * 1024];
loop {
match stream.try_read(&mut buf) {
Ok(_) => (),
Err(err) => {
assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
break;
}
}
}
}

fn write_until_pending(stream: &mut TcpStream) {
let buf = vec![0u8; 1024 * 1024];
loop {
match stream.try_write(&buf) {
Ok(_) => (),
Err(err) => {
assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
break;
}
}
}
}

0 comments on commit 4e39c9b

Please sign in to comment.