Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

net: restore TcpStream::{poll_read_ready, poll_write_ready} #2743

Merged
merged 5 commits into from Nov 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 22 additions & 0 deletions tokio/src/net/tcp/stream.rs
Expand Up @@ -356,6 +356,17 @@ impl TcpStream {
Ok(())
}

/// Polls for read readiness.
///
/// This function is intended for cases where creating and pinning a future
/// via [`readable`] is not feasible. Where possible, using [`readable`] is
/// preferred, as this supports polling from multiple tasks at once.
///
/// [`readable`]: method@Self::readable
pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.io.registration().poll_read_ready(cx).map_ok(|_| ())
}

/// Try to read data from the stream into the provided buffer, returning how
/// many bytes were read.
///
Expand Down Expand Up @@ -467,6 +478,17 @@ impl TcpStream {
Ok(())
}

/// Polls for write readiness.
///
/// This function is intended for cases where creating and pinning a future
/// via [`writable`] is not feasible. Where possible, using [`writable`] is
/// preferred, as this supports polling from multiple tasks at once.
///
/// [`writable`]: method@Self::writable
pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.io.registration().poll_write_ready(cx).map_ok(|_| ())
}

/// Try to write a buffer to the stream, returning how many bytes were
/// written.
///
Expand Down
112 changes: 110 additions & 2 deletions tokio/tests/tcp_stream.rs
@@ -1,12 +1,16 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::io::Interest;
use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest};
use tokio::net::{TcpListener, TcpStream};
use tokio::try_join;
use tokio_test::task;
use tokio_test::{assert_pending, assert_ready_ok};
use tokio_test::{assert_ok, assert_pending, assert_ready_ok};

use std::io;
use std::task::Poll;

use futures::future::poll_fn;

#[tokio::test]
async fn try_read_write() {
Expand Down Expand Up @@ -110,3 +114,107 @@ fn buffer_not_included_in_future() {
let n = mem::size_of_val(&fut);
assert!(n < 1000);
}

macro_rules! assert_readable_by_polling {
($stream:expr) => {
assert_ok!(poll_fn(|cx| $stream.poll_read_ready(cx)).await);
};
}

macro_rules! assert_not_readable_by_polling {
($stream:expr) => {
poll_fn(|cx| {
assert_pending!($stream.poll_read_ready(cx));
Poll::Ready(())
})
.await;
};
}

macro_rules! assert_writable_by_polling {
($stream:expr) => {
assert_ok!(poll_fn(|cx| $stream.poll_write_ready(cx)).await);
};
}

macro_rules! assert_not_writable_by_polling {
($stream:expr) => {
poll_fn(|cx| {
assert_pending!($stream.poll_write_ready(cx));
Poll::Ready(())
})
.await;
};
}

#[tokio::test]
async fn poll_read_ready() {
let (mut client, mut server) = create_pair().await;

// Initial state - not readable.
assert_not_readable_by_polling!(server);

// There is data in the buffer - readable.
assert_ok!(client.write_all(b"ping").await);
assert_readable_by_polling!(server);

// Readable until calls to `poll_read` return `Poll::Pending`.
let mut buf = [0u8; 4];
assert_ok!(server.read_exact(&mut buf).await);
assert_readable_by_polling!(server);
read_until_pending(&mut server);
assert_not_readable_by_polling!(server);

// Detect the client disconnect.
drop(client);
assert_readable_by_polling!(server);
}

#[tokio::test]
async fn poll_write_ready() {
let (mut client, server) = create_pair().await;

// Initial state - writable.
assert_writable_by_polling!(client);

// No space to write - not writable.
write_until_pending(&mut client);
assert_not_writable_by_polling!(client);

// Detect the server disconnect.
drop(server);
assert_writable_by_polling!(client);
}

async fn create_pair() -> (TcpStream, TcpStream) {
let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(listener.local_addr());
let (client, (server, _)) = assert_ok!(try_join!(TcpStream::connect(&addr), listener.accept()));
(client, server)
}

fn read_until_pending(stream: &mut TcpStream) {
let mut buf = vec![0u8; 1024 * 1024];
loop {
match stream.try_read(&mut buf) {
Ok(_) => (),
Err(err) => {
assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
break;
}
}
}
}

fn write_until_pending(stream: &mut TcpStream) {
let buf = vec![0u8; 1024 * 1024];
loop {
match stream.try_write(&buf) {
Ok(_) => (),
Err(err) => {
assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
break;
}
}
}
}