Skip to content

Commit

Permalink
io: change AsyncRead to use a ReadBuf (#2758)
Browse files Browse the repository at this point in the history
Works towards #2716. Changes the argument to `AsyncRead::poll_read` to
take a `ReadBuf` struct that safely manages writes to uninitialized memory.
  • Loading branch information
seanmonstar committed Aug 14, 2020
1 parent 71da060 commit c393236
Show file tree
Hide file tree
Showing 40 changed files with 626 additions and 544 deletions.
43 changes: 24 additions & 19 deletions tokio-test/src/io.rs
Expand Up @@ -18,7 +18,7 @@
//! [`AsyncRead`]: tokio::io::AsyncRead
//! [`AsyncWrite`]: tokio::io::AsyncWrite

use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc;
use tokio::time::{self, Delay, Duration, Instant};

Expand Down Expand Up @@ -204,20 +204,19 @@ impl Inner {
self.rx.poll_recv(cx)
}

fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> {
match self.action() {
Some(&mut Action::Read(ref mut data)) => {
// Figure out how much to copy
let n = cmp::min(dst.len(), data.len());
let n = cmp::min(dst.remaining(), data.len());

// Copy the data into the `dst` slice
(&mut dst[..n]).copy_from_slice(&data[..n]);
dst.append(&data[..n]);

// Drain the data from the source
data.drain(..n);

// Return the number of bytes read
Ok(n)
Ok(())
}
Some(&mut Action::ReadError(ref mut err)) => {
// As the
Expand All @@ -229,7 +228,7 @@ impl Inner {
// Either waiting or expecting a write
Err(io::ErrorKind::WouldBlock.into())
}
None => Ok(0),
None => Ok(()),
}
}

Expand Down Expand Up @@ -348,8 +347,8 @@ impl AsyncRead for Mock {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop {
if let Some(ref mut sleep) = self.inner.sleep {
ready!(Pin::new(sleep).poll(cx));
Expand All @@ -358,6 +357,9 @@ impl AsyncRead for Mock {
// If a sleep is set, it has already fired
self.inner.sleep = None;

// Capture 'filled' to monitor if it changed
let filled = buf.filled().len();

match self.inner.read(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Some(rem) = self.inner.remaining_wait() {
Expand All @@ -368,19 +370,22 @@ impl AsyncRead for Mock {
return Poll::Pending;
}
}
Ok(0) => {
// TODO: Extract
match ready!(self.inner.poll_action(cx)) {
Some(action) => {
self.inner.actions.push_back(action);
continue;
}
None => {
return Poll::Ready(Ok(0));
Ok(()) => {
if buf.filled().len() == filled {
match ready!(self.inner.poll_action(cx)) {
Some(action) => {
self.inner.actions.push_back(action);
continue;
}
None => {
return Poll::Ready(Ok(()));
}
}
} else {
return Poll::Ready(Ok(()));
}
}
ret => return Poll::Ready(ret),
Err(e) => return Poll::Ready(Err(e)),
}
}
}
Expand Down
26 changes: 21 additions & 5 deletions tokio-util/src/compat.rs
@@ -1,5 +1,6 @@
//! Compatibility between the `tokio::io` and `futures-io` versions of the
//! `AsyncRead` and `AsyncWrite` traits.
use futures_core::ready;
use pin_project_lite::pin_project;
use std::io;
use std::pin::Pin;
Expand Down Expand Up @@ -107,9 +108,18 @@ where
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
futures_io::AsyncRead::poll_read(self.project().inner, cx, buf)
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
// We can't trust the inner type to not peak at the bytes,
// so we must defensively initialize the buffer.
let slice = buf.initialize_unfilled();
let n = ready!(futures_io::AsyncRead::poll_read(
self.project().inner,
cx,
slice
))?;
buf.add_filled(n);
Poll::Ready(Ok(()))
}
}

Expand All @@ -120,9 +130,15 @@ where
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
slice: &mut [u8],
) -> Poll<io::Result<usize>> {
tokio::io::AsyncRead::poll_read(self.project().inner, cx, buf)
let mut buf = tokio::io::ReadBuf::new(slice);
ready!(tokio::io::AsyncRead::poll_read(
self.project().inner,
cx,
&mut buf
))?;
Poll::Ready(Ok(buf.filled().len()))
}
}

Expand Down
4 changes: 2 additions & 2 deletions tokio-util/tests/framed.rs
Expand Up @@ -55,8 +55,8 @@ impl AsyncRead for DontReadIntoThis {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut [u8],
) -> Poll<io::Result<usize>> {
_buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
unreachable!()
}
}
Expand Down
18 changes: 9 additions & 9 deletions tokio-util/tests/framed_read.rs
@@ -1,6 +1,6 @@
#![warn(rust_2018_idioms)]

use tokio::io::AsyncRead;
use tokio::io::{AsyncRead, ReadBuf};
use tokio_test::assert_ready;
use tokio_test::task;
use tokio_util::codec::{Decoder, FramedRead};
Expand Down Expand Up @@ -264,19 +264,19 @@ impl AsyncRead for Mock {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
use io::ErrorKind::WouldBlock;

match self.calls.pop_front() {
Some(Ok(data)) => {
debug_assert!(buf.len() >= data.len());
buf[..data.len()].copy_from_slice(&data[..]);
Ready(Ok(data.len()))
debug_assert!(buf.remaining() >= data.len());
buf.append(&data);
Ready(Ok(()))
}
Some(Err(ref e)) if e.kind() == WouldBlock => Pending,
Some(Err(e)) => Ready(Err(e)),
None => Ready(Ok(0)),
None => Ready(Ok(())),
}
}
}
Expand All @@ -288,8 +288,8 @@ impl AsyncRead for Slice<'_> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
14 changes: 7 additions & 7 deletions tokio-util/tests/length_delimited.rs
@@ -1,6 +1,6 @@
#![warn(rust_2018_idioms)]

use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_test::task;
use tokio_test::{
assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok,
Expand Down Expand Up @@ -707,18 +707,18 @@ impl AsyncRead for Mock {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
dst: &mut [u8],
) -> Poll<io::Result<usize>> {
dst: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.calls.pop_front() {
Some(Ready(Ok(Op::Data(data)))) => {
debug_assert!(dst.len() >= data.len());
dst[..data.len()].copy_from_slice(&data[..]);
Ready(Ok(data.len()))
debug_assert!(dst.remaining() >= data.len());
dst.append(&data);
Ready(Ok(()))
}
Some(Ready(Ok(_))) => panic!(),
Some(Ready(Err(e))) => Ready(Err(e)),
Some(Pending) => Pending,
None => Ready(Ok(0)),
None => Ready(Ok(())),
}
}
}
Expand Down
19 changes: 7 additions & 12 deletions tokio/src/fs/file.rs
Expand Up @@ -5,7 +5,7 @@
use self::State::*;
use crate::fs::{asyncify, sys};
use crate::io::blocking::Buf;
use crate::io::{AsyncRead, AsyncSeek, AsyncWrite};
use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};

use std::fmt;
use std::fs::{Metadata, Permissions};
Expand Down Expand Up @@ -537,25 +537,20 @@ impl File {
}

impl AsyncRead for File {
unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
// https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/fs.rs#L668
false
}

fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
dst: &mut [u8],
) -> Poll<io::Result<usize>> {
dst: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop {
match self.state {
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();

if !buf.is_empty() {
let n = buf.copy_to(dst);
buf.copy_to(dst);
*buf_cell = Some(buf);
return Ready(Ok(n));
return Ready(Ok(()));
}

buf.ensure_capacity_for(dst);
Expand All @@ -571,9 +566,9 @@ impl AsyncRead for File {

match op {
Operation::Read(Ok(_)) => {
let n = buf.copy_to(dst);
buf.copy_to(dst);
self.state = Idle(Some(buf));
return Ready(Ok(n));
return Ready(Ok(()));
}
Operation::Read(Err(e)) => {
assert!(buf.is_empty());
Expand Down

0 comments on commit c393236

Please sign in to comment.