io: change AsyncRead to use a ReadBuf #2758

Merged
merged 11 commits on Aug 14, 2020
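The heart of this change: `AsyncRead::poll_read` now takes a `ReadBuf<'_>` instead of a plain `&mut [u8]`, and returns `Poll<io::Result<()>>` instead of a byte count; the count lives in the `ReadBuf` itself. Below is a minimal sketch of an implementor against the new signature, using only the `ReadBuf` methods that appear in this diff (`remaining`, `append`, `filled`; `append` may go by a different name in later releases). The `VecReader` type is hypothetical and not part of this PR.

```rust
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, ReadBuf};

// Hypothetical reader that serves bytes from an owned Vec, written against
// the post-PR trait.
struct VecReader {
    data: Vec<u8>,
    pos: usize,
}

impl AsyncRead for VecReader {
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Copy at most `buf.remaining()` bytes into the unfilled region.
        let n = std::cmp::min(self.data.len() - self.pos, buf.remaining());
        buf.append(&self.data[self.pos..self.pos + n]);
        self.pos += n;
        // Returning Ok(()) without growing `filled()` signals end of stream;
        // there is no Ok(0) any more.
        Poll::Ready(Ok(()))
    }
}
```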
42 changes: 24 additions & 18 deletions tokio-test/src/io.rs
@@ -18,7 +18,7 @@
//! [`AsyncRead`]: tokio::io::AsyncRead
//! [`AsyncWrite`]: tokio::io::AsyncWrite

use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc;
use tokio::time::{self, Delay, Duration, Instant};

@@ -204,20 +204,20 @@ impl Inner {
self.rx.poll_recv(cx)
}

fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> {
match self.action() {
Some(&mut Action::Read(ref mut data)) => {
// Figure out how much to copy
let n = cmp::min(dst.len(), data.len());
let n = cmp::min(dst.remaining(), data.len());

// Copy the data into the `dst` slice
(&mut dst[..n]).copy_from_slice(&data[..n]);
dst.append(&data[..n]);

// Drain the data from the source
data.drain(..n);

// Return the number of bytes read
Ok(n)
Ok(())
}
Some(&mut Action::ReadError(ref mut err)) => {
// As the
@@ -229,7 +229,7 @@ impl Inner {
// Either waiting or expecting a write
Err(io::ErrorKind::WouldBlock.into())
}
None => Ok(0),
None => Ok(()),
}
}

@@ -348,8 +348,8 @@ impl AsyncRead for Mock {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop {
if let Some(ref mut sleep) = self.inner.sleep {
ready!(Pin::new(sleep).poll(cx));
@@ -358,6 +358,9 @@ impl AsyncRead for Mock {
// If a sleep is set, it has already fired
self.inner.sleep = None;

// Capture 'filled' to monitor if it changed
let filled = buf.filled().len();

match self.inner.read(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Some(rem) = self.inner.remaining_wait() {
@@ -368,19 +371,22 @@
return Poll::Pending;
}
}
Ok(0) => {
// TODO: Extract
match ready!(self.inner.poll_action(cx)) {
Some(action) => {
self.inner.actions.push_back(action);
continue;
}
None => {
return Poll::Ready(Ok(0));
Ok(()) => {
if buf.filled().len() == filled {
match ready!(self.inner.poll_action(cx)) {
Some(action) => {
self.inner.actions.push_back(action);
continue;
}
None => {
return Poll::Ready(Ok(()));
}
}
} else {
return Poll::Ready(Ok(()));
}
}
ret => return Poll::Ready(ret),
Err(e) => return Poll::Ready(Err(e)),
}
}
}
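Because the byte count is gone from the return value, the mock above snapshots `buf.filled().len()` before calling `read` and compares it afterwards to tell "made progress" apart from "nothing left to hand out". A hedged, standalone sketch of that same check follows; the helper name is invented and not part of this diff.

```rust
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, ReadBuf};

// Snapshot the filled length, poll once, and report how many new bytes this
// call produced; Ok(()) with zero growth means end of stream.
fn poll_read_progress<R: AsyncRead + Unpin>(
    reader: &mut R,
    cx: &mut Context<'_>,
    buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<usize>> {
    let before = buf.filled().len();
    match Pin::new(reader).poll_read(cx, buf) {
        Poll::Pending => Poll::Pending,
        Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
        Poll::Ready(Ok(())) => Poll::Ready(Ok(buf.filled().len() - before)),
    }
}
```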
19 changes: 7 additions & 12 deletions tokio/src/fs/file.rs
@@ -5,7 +5,7 @@
use self::State::*;
use crate::fs::{asyncify, sys};
use crate::io::blocking::Buf;
use crate::io::{AsyncRead, AsyncSeek, AsyncWrite};
use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};

use std::fmt;
use std::fs::{Metadata, Permissions};
@@ -537,25 +537,20 @@ impl File {
}

impl AsyncRead for File {
unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
// https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/fs.rs#L668
false
}

fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
dst: &mut [u8],
) -> Poll<io::Result<usize>> {
dst: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop {
match self.state {
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();

if !buf.is_empty() {
let n = buf.copy_to(dst);
buf.copy_to(dst);
*buf_cell = Some(buf);
return Ready(Ok(n));
return Ready(Ok(()));
}

buf.ensure_capacity_for(dst);
@@ -571,9 +566,9 @@ impl AsyncRead for File {

match op {
Operation::Read(Ok(_)) => {
let n = buf.copy_to(dst);
buf.copy_to(dst);
self.state = Idle(Some(buf));
return Ready(Ok(n));
return Ready(Ok(()));
}
Operation::Read(Err(e)) => {
assert!(buf.is_empty());
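At the `tokio::fs::File` level the change is invisible to most users, because the `AsyncReadExt::read` convenience method still yields a byte count. A small usage sketch, assuming the relevant tokio features (`fs`, `macros`, `rt`) are enabled; the file name is arbitrary:

```rust
use tokio::fs::File;
use tokio::io::AsyncReadExt;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let mut file = File::open("Cargo.toml").await?;
    let mut chunk = [0u8; 4096];
    // `read` is provided by AsyncReadExt and still returns the byte count,
    // even though the underlying poll_read no longer does.
    let n = file.read(&mut chunk).await?;
    println!("read {} bytes", n);
    Ok(())
}
```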
118 changes: 37 additions & 81 deletions tokio/src/io/async_read.rs
@@ -1,6 +1,6 @@
use super::ReadBuf;
use bytes::BufMut;
use std::io;
use std::mem::MaybeUninit;
use std::ops::DerefMut;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -41,47 +41,6 @@ use std::task::{Context, Poll};
/// [`Read::read`]: std::io::Read::read
/// [`AsyncReadExt`]: crate::io::AsyncReadExt
pub trait AsyncRead {
/// Prepares an uninitialized buffer to be safe to pass to `read`. Returns
/// `true` if the supplied buffer was zeroed out.
///
/// While it would be highly unusual, implementations of [`io::Read`] are
/// able to read data from the buffer passed as an argument. Because of
/// this, the buffer passed to [`io::Read`] must be initialized memory. In
/// situations where large numbers of buffers are used, constantly having to
/// zero out buffers can be expensive.
///
/// This function does any necessary work to prepare an uninitialized buffer
/// to be safe to pass to `read`. If `read` guarantees to never attempt to
/// read data out of the supplied buffer, then `prepare_uninitialized_buffer`
/// doesn't need to do any work.
///
/// If this function returns `true`, then the memory has been zeroed out.
/// This allows implementations of `AsyncRead` which are composed of
/// multiple subimplementations to efficiently implement
/// `prepare_uninitialized_buffer`.
///
/// This function isn't actually `unsafe` to call but `unsafe` to implement.
/// The implementer must ensure that either the whole `buf` has been zeroed
/// or `poll_read_buf()` overwrites the buffer without reading it and returns
/// correct value.
///
/// This function is called from [`poll_read_buf`].
///
/// # Safety
///
/// Implementations that return `false` must never read from data slices
/// that they did not write to.
///
/// [`io::Read`]: std::io::Read
/// [`poll_read_buf`]: method@Self::poll_read_buf
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
for x in buf {
*x = MaybeUninit::new(0);
}

true
}

/// Attempts to read from the `AsyncRead` into `buf`.
///
/// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
@@ -93,8 +52,8 @@ pub trait AsyncRead {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>>;
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>>;

/// Pulls some bytes from this source into the specified `BufMut`, returning
/// how many bytes were read.
@@ -116,16 +75,10 @@

unsafe {
let n = {
let b = buf.bytes_mut();

self.prepare_uninitialized_buffer(b);

// Convert to `&mut [u8]`
let b = &mut *(b as *mut [MaybeUninit<u8>] as *mut [u8]);
let mut b = ReadBuf::uninit(buf.bytes_mut());

let n = ready!(self.poll_read(cx, b))?;
assert!(n <= b.len(), "Bad AsyncRead implementation, more bytes were reported as read than the buffer can hold");
n
ready!(self.poll_read(cx, &mut b))?;
b.filled().len()
};

buf.advance_mut(n);
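The new default `poll_read_buf` above no longer needs `prepare_uninitialized_buffer`: it wraps the target's spare capacity in `ReadBuf::uninit`, polls, and reads the byte count back off `filled()`. A hedged standalone sketch of that bridging pattern; the helper name is invented, and it assumes the `ReadBuf::uninit` constructor introduced by this PR:

```rust
use std::io;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, ReadBuf};

// Wrap spare, possibly uninitialized capacity in a ReadBuf, poll the reader,
// then report how many bytes landed in the filled region. No zeroing and no
// prepare_uninitialized_buffer step is required.
fn poll_read_into_spare<R: AsyncRead + Unpin>(
    reader: &mut R,
    cx: &mut Context<'_>,
    spare: &mut [MaybeUninit<u8>],
) -> Poll<io::Result<usize>> {
    let mut buf = ReadBuf::uninit(spare);
    match Pin::new(reader).poll_read(cx, &mut buf) {
        Poll::Pending => Poll::Pending,
        Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
        Poll::Ready(Ok(())) => Poll::Ready(Ok(buf.filled().len())),
    }
}
```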
@@ -136,15 +89,11 @@

macro_rules! deref_async_read {
() => {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
(**self).prepare_uninitialized_buffer(buf)
}

fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut **self).poll_read(cx, buf)
}
};
@@ -163,43 +112,50 @@ where
P: DerefMut + Unpin,
P::Target: AsyncRead,
{
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
(**self).prepare_uninitialized_buffer(buf)
}

fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.get_mut().as_mut().poll_read(cx, buf)
}
}

impl AsyncRead for &[u8] {
unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
false
}

fn poll_read(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(io::Read::read(self.get_mut(), buf))
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let amt = std::cmp::min(self.len(), buf.remaining());
let (a, b) = self.split_at(amt);
buf.append(a);
*self = b;
Poll::Ready(Ok(()))
}
}

impl<T: AsRef<[u8]> + Unpin> AsyncRead for io::Cursor<T> {
unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
false
}

fn poll_read(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(io::Read::read(self.get_mut(), buf))
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let pos = self.position();
let slice: &[u8] = (*self).get_ref().as_ref();

// The position could technically be out of bounds, so don't panic...
if pos > slice.len() as u64 {
return Poll::Ready(Ok(()));
}

let start = pos as usize;
let amt = std::cmp::min(slice.len() - start, buf.remaining());
// Add won't overflow because of pos check above.
let end = start + amt;
buf.append(&slice[start..end]);
self.set_position(end as u64);

Poll::Ready(Ok(()))
}
}
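The in-memory implementations above (`&[u8]` and `io::Cursor`) now copy directly into the `ReadBuf` instead of delegating to `io::Read`. From a caller's point of view nothing changes; a usage sketch exercising the slice implementation through the unchanged `AsyncReadExt` extension trait:

```rust
use tokio::io::AsyncReadExt;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let mut data: &[u8] = b"hello world";
    let mut first = [0u8; 5];
    // read_exact works over any AsyncRead; the slice reader advances itself
    // past whatever it hands out, just like std::io::Read for &[u8].
    data.read_exact(&mut first).await?;
    assert_eq!(&first, b"hello");
    assert_eq!(data, &b" world"[..]);
    Ok(())
}
```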
24 changes: 12 additions & 12 deletions tokio/src/io/blocking.rs
@@ -1,5 +1,5 @@
use crate::io::sys;
use crate::io::{AsyncRead, AsyncWrite};
use crate::io::{AsyncRead, AsyncWrite, ReadBuf};

use std::cmp;
use std::future::Future;
@@ -53,17 +53,17 @@ where
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
dst: &mut [u8],
) -> Poll<io::Result<usize>> {
dst: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop {
match self.state {
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();

if !buf.is_empty() {
let n = buf.copy_to(dst);
buf.copy_to(dst);
*buf_cell = Some(buf);
return Ready(Ok(n));
return Ready(Ok(()));
}

buf.ensure_capacity_for(dst);
@@ -80,9 +80,9 @@

match res {
Ok(_) => {
let n = buf.copy_to(dst);
buf.copy_to(dst);
self.state = Idle(Some(buf));
return Ready(Ok(n));
return Ready(Ok(()));
}
Err(e) => {
assert!(buf.is_empty());
@@ -203,9 +203,9 @@ impl Buf {
self.buf.len() - self.pos
}

pub(crate) fn copy_to(&mut self, dst: &mut [u8]) -> usize {
let n = cmp::min(self.len(), dst.len());
dst[..n].copy_from_slice(&self.bytes()[..n]);
pub(crate) fn copy_to(&mut self, dst: &mut ReadBuf<'_>) -> usize {
let n = cmp::min(self.len(), dst.remaining());
dst.append(&self.bytes()[..n]);
self.pos += n;

if self.pos == self.buf.len() {
@@ -229,10 +229,10 @@ impl Buf {
&self.buf[self.pos..]
}

pub(crate) fn ensure_capacity_for(&mut self, bytes: &[u8]) {
pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>) {
assert!(self.is_empty());

let len = cmp::min(bytes.len(), MAX_BUF);
let len = cmp::min(bytes.remaining(), MAX_BUF);

if self.buf.len() < len {
self.buf.reserve(len - self.buf.len());
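`ensure_capacity_for` now sizes the intermediate blocking buffer from `dst.remaining()` rather than from the length of a plain byte slice. A minimal sketch of that sizing rule, with `MAX_BUF` standing in for tokio's internal constant (the real value may differ):

```rust
use tokio::io::ReadBuf;

// Stand-in for tokio's internal cap on the blocking staging buffer.
const MAX_BUF: usize = 16 * 1024;

// The staging buffer only needs as much room as the caller's ReadBuf can
// still accept, capped at MAX_BUF.
fn staging_len_for(dst: &ReadBuf<'_>) -> usize {
    std::cmp::min(dst.remaining(), MAX_BUF)
}
```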